|
|
|
@@ -13,13 +13,13 @@ |
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
"""operator dsl function: equal""" |
|
|
|
import akg.tvm |
|
|
|
import akg.topi |
|
|
|
from akg.utils.dsl_create import produce_shapes |
|
|
|
from akg.utils import validation_check as vc_util |
|
|
|
import _akg.tvm |
|
|
|
import _akg.topi |
|
|
|
from _akg.utils.dsl_create import produce_shapes |
|
|
|
from _akg.utils import validation_check as vc_util |
|
|
|
|
|
|
|
|
|
|
|
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor) |
|
|
|
@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) |
|
|
|
def equal(input1, input2): |
|
|
|
""" |
|
|
|
check whether input1 equals to input2. |
|
|
|
@@ -42,13 +42,13 @@ def equal(input1, input2): |
|
|
|
dtype = input1.dtype |
|
|
|
|
|
|
|
# get equal compute |
|
|
|
t_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(1, dtype), "T") |
|
|
|
f_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(0, dtype), "F") |
|
|
|
|
|
|
|
input1_bro = akg.topi.broadcast_to(input1, shape) |
|
|
|
input2_bro = akg.topi.broadcast_to(input2, shape) |
|
|
|
c_out = akg.tvm.compute(shape, lambda *indice: akg.tvm.expr.Select(input1_bro[indice] == input2_bro[indice], |
|
|
|
t_value[indice], f_value[indice]), name="C") |
|
|
|
res = akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res") |
|
|
|
t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T") |
|
|
|
f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F") |
|
|
|
|
|
|
|
input1_bro = _akg.topi.broadcast_to(input1, shape) |
|
|
|
input2_bro = _akg.topi.broadcast_to(input2, shape) |
|
|
|
c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] == input2_bro[indice], |
|
|
|
t_value[indice], f_value[indice]), name="C") |
|
|
|
res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res") |
|
|
|
|
|
|
|
return res |