| @@ -27,11 +27,11 @@ accumulate_n_v2_op_info = TBERegOp("AccumulateNV2") \ | |||||
| .input(0, "x", False, "dynamic", "all") \ | .input(0, "x", False, "dynamic", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("broadcast") \ | .op_pattern("broadcast") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.I32_None) \ | |||||
| .dtype_format(DataType.I8_None, DataType.I8_None) \ | |||||
| .dtype_format(DataType.U8_None, DataType.U8_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -28,10 +28,8 @@ approximate_equal_op_info = TBERegOp("ApproximateEqual") \ | |||||
| .input(0, "x1", False, "required", "all") \ | .input(0, "x1", False, "required", "all") \ | ||||
| .input(1, "x2", False, "required", "all") \ | .input(1, "x2", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.BOOL_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.BOOL_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -28,10 +28,8 @@ binary_cross_entropy_op_info = TBERegOp("BinaryCrossEntropy") \ | |||||
| .input(1, "y", False, "required", "all") \ | .input(1, "y", False, "required", "all") \ | ||||
| .input(2, "weight", False, "optional", "all") \ | .input(2, "weight", False, "optional", "all") \ | ||||
| .output(0, "output", False, "required", "all") \ | .output(0, "output", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -29,8 +29,8 @@ lin_space_op_info = TBERegOp("LinSpace") \ | |||||
| .input(2, "stop", False, "required", "all") \ | .input(2, "stop", False, "required", "all") \ | ||||
| .input(3, "num", False, "required", "all") \ | .input(3, "num", False, "required", "all") \ | ||||
| .output(0, "output", False, "required", "all") \ | .output(0, "output", False, "required", "all") \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, | |||||
| DataType.F32_Default,) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.I32_None, | |||||
| DataType.F32_None,) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -26,16 +26,12 @@ mod_op_info = TBERegOp("Mod") \ | |||||
| .input(0, "x1", False, "required", "all") \ | .input(0, "x1", False, "required", "all") \ | ||||
| .input(1, "x2", False, "required", "all") \ | .input(1, "x2", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .op_pattern("broadcast") \ | |||||
| .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ | |||||
| .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -27,10 +27,11 @@ reduce_mean_d_op_info = TBERegOp("ReduceMeanD") \ | |||||
| .attr("keep_dims", "optional", "bool", "all") \ | .attr("keep_dims", "optional", "bool", "all") \ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .op_pattern("reduce") \ | |||||
| .dtype_format(DataType.I8_None, DataType.I8_None) \ | |||||
| .dtype_format(DataType.U8_None, DataType.U8_None) \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -26,8 +26,8 @@ softsign_op_info = TBERegOp("Softsign") \ | |||||
| .op_pattern("formatAgnostic") \ | .op_pattern("formatAgnostic") \ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -29,28 +29,7 @@ split_v_op_info = TBERegOp("SplitV") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "dynamic", "all") \ | .output(0, "y", False, "dynamic", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | .op_pattern("dynamicFormat") \ | ||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.None_None, DataType.None_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -3055,7 +3055,7 @@ class InplaceUpdate(PrimitiveWithInfer): | |||||
| raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') | raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') | ||||
| x_rank = len(x_shape) | x_rank = len(x_shape) | ||||
| for idx in range(x_rank)[1:]: | for idx in range(x_rank)[1:]: | ||||
| validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name) | |||||
| validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name) | |||||
| return x_shape | return x_shape | ||||
| @@ -947,7 +947,7 @@ class InplaceAdd(PrimitiveWithInfer): | |||||
| raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') | raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') | ||||
| x_rank = len(x_shape) | x_rank = len(x_shape) | ||||
| for idx in range(x_rank)[1:]: | for idx in range(x_rank)[1:]: | ||||
| validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name) | |||||
| validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name) | |||||
| return x_shape | return x_shape | ||||
| @@ -1005,7 +1005,7 @@ class InplaceSub(PrimitiveWithInfer): | |||||
| raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') | raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') | ||||
| x_rank = len(x_shape) | x_rank = len(x_shape) | ||||
| for idx in range(x_rank)[1:]: | for idx in range(x_rank)[1:]: | ||||
| validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name) | |||||
| validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name) | |||||
| return x_shape | return x_shape | ||||