# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: select"""
import akg.topi
import akg.tvm
import akg.lang.cce
from akg.utils import validation_check as vc_util
from akg.utils.format_transform import get_shape
from akg.utils import kernel_exec as utils

VALUE_ONE = 1


def select_compute(condition, x1, x2):
    """
    Build the compute expression for select.

    Args:
        condition (tvm.tensor.Tensor): 0/1 mask; broadcast to x1's shape when
            its shape differs.
        x1 (tvm.tensor.Tensor): Values taken where condition is 1.
        x2 (tvm.tensor.Tensor): Values taken where condition is 0.

    Returns:
        tvm.tensor.Tensor, with the dtype of x1.
    """
    out_shape = get_shape(x1)
    cond_shape = get_shape(condition)
    value_dtype = x1.dtype
    cond_dtype = condition.dtype

    # int8/uint8 operands are computed in float32 and cast back at the end.
    needs_widening = value_dtype in ("int8", "uint8")
    compute_dtype = "float32" if needs_widening else value_dtype

    ones = akg.lang.cce.broadcast(akg.tvm.const(VALUE_ONE, dtype=compute_dtype),
                                  out_shape, output_dtype=compute_dtype)
    if needs_widening:
        x1 = akg.topi.cast(x1, compute_dtype)
        x2 = akg.topi.cast(x2, compute_dtype)

    # Bring the condition to the compute dtype.
    if compute_dtype != "int32":
        mask = akg.topi.cast(condition, compute_dtype)
    elif cond_dtype == "int8":
        # NOTE(review): the int8 -> int32 conversion goes through ceil instead
        # of a plain cast; presumably a direct cast is not available - confirm.
        mask = akg.lang.cce.ceil(condition)
    else:
        # The condition is already int32, use it as is.
        mask = condition

    if list(cond_shape) != list(out_shape):
        mask = akg.lang.cce.broadcast(mask, out_shape)

    if utils.product_is_mini():
        vinsn_support_dtype = ("float16", )
    else:
        vinsn_support_dtype = ("float16", "float32")

    if value_dtype in vinsn_support_dtype:
        # Float inputs never need the narrowing cast below, so return directly.
        return akg.topi.where(mask, x1, x2)

    # Remaining dtypes use the arithmetic form mask * x1 + (1 - mask) * x2,
    # which equals select when the mask only holds 0 and 1.
    inverse_mask = akg.lang.cce.vsub(ones, mask)
    picked_x1 = akg.lang.cce.vmul(x1, mask)
    picked_x2 = akg.lang.cce.vmul(x2, inverse_mask)
    res = akg.lang.cce.vadd(picked_x1, picked_x2)

    if needs_widening:
        res = akg.topi.cast(res, value_dtype)
    return res


@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
def select(condition, x1, x2):
    """
    Selects elements from x1 or x2, depending on condition.

    Note:
        Every parameter's shape must be legal; broadcasting of condition's
        shape is supported.

    Args:
        condition (tvm.tensor.Tensor): Tensor of type int8, int32, must be 0 or 1.
        x1 (tvm.tensor.Tensor): Tensor of type float16, float32, int8, int32, uint8.
        x2 (tvm.tensor.Tensor): Tensor of type float16, float32, int8, int32, uint8.

    Returns:
        tvm.tensor.Tensor, has the same type and shape as x1.
    """
    x1_shape = get_shape(x1)
    x2_shape = get_shape(x2)
    cond_shape = get_shape(condition)

    # x1 and x2 must agree in shape and dtype.
    vc_util.elemwise_shape_check(x1_shape, x2_shape)
    value_dtypes = [
        vc_util.DtypeForDavinci.ALL_FLOAT,
        vc_util.DtypeForDavinci.INT8,
        vc_util.DtypeForDavinci.INT32,
        vc_util.DtypeForDavinci.UINT8,
    ]
    vc_util.elemwise_dtype_check(x1.dtype, x2.dtype, value_dtypes)

    # condition is restricted to int8/int32 and must be broadcastable to x1.
    condition_dtypes = [vc_util.DtypeForDavinci.INT8, vc_util.DtypeForDavinci.INT32]
    vc_util.ops_dtype_check(condition.dtype, condition_dtypes)
    vc_util.auto_broadcast_check(cond_shape, x1_shape)

    return select_compute(condition, x1, x2)