diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc index 6eea6833dd..640d8d9801 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc @@ -43,6 +43,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An todos.push_back(node); } + std::set DynamicShapeConstInputToAttr = {kCastOpName, kExpandDimsOpName}; for (auto &t : todos) { CNodePtr cnode = t->cast(); ConstInputToAttrInfoRegister reg; @@ -61,7 +62,8 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An continue; } } - if (AnfAlgo::IsDynamicShape(cnode)) { + if (AnfAlgo::IsDynamicShape(cnode) && + DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) { MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope(); continue; } diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 7bf40dbd33..a87a3ec4d8 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -223,7 +223,16 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); - +AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + 
const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); template AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { // Inputs: a tuple or list or dict. diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index e208224544..6619337d10 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -96,5 +96,132 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p CheckArgsSize(op_name, args_spec_list, 1); return args_spec_list[0]->Broaden(); } + +AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto input_x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(input_x); + MS_EXCEPTION_IF_NULL(input_x->shape()); + + auto input_y = CheckArg(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(input_y); + MS_EXCEPTION_IF_NULL(input_y->shape()); + + auto x_shape = input_x->shape()->shape(); + auto y_shape = input_y->shape()->shape(); + auto output_shape = BroadcastShape(x_shape, y_shape); + if (output_shape.empty()) { + MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," + << args_spec_list[1]->ToString(); + } + + auto x_type = input_x->BuildType(); + MS_EXCEPTION_IF_NULL(x_type); + MS_EXCEPTION_IF_NULL(x_type->cast()); + auto y_type = input_y->BuildType(); + MS_EXCEPTION_IF_NULL(y_type); + MS_EXCEPTION_IF_NULL(y_type->cast()); + + auto x_element = x_type->cast()->element(); + MS_EXCEPTION_IF_NULL(x_element); + auto y_element = y_type->cast()->element(); + MS_EXCEPTION_IF_NULL(y_element); + + auto x_element_type = x_element->number_type(); + auto y_element_type 
= y_element->number_type(); + + auto x_priority = type_priority_map.find(x_element_type); + if (x_priority == type_priority_map.end()) { + MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", it's not number type."; + } + auto y_priority = type_priority_map.find(y_element_type); + if (y_priority == type_priority_map.end()) { + MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", it's not number type."; + } + + if (x_priority->second >= y_priority->second) { + return std::make_shared(input_x->element(), std::make_shared(output_shape)); + } else { + return std::make_shared(input_y->element(), std::make_shared(output_shape)); + } +} + +AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto input_x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(input_x); + MS_EXCEPTION_IF_NULL(input_x->shape()); + + auto input_y = CheckArg(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(input_y); + MS_EXCEPTION_IF_NULL(input_y->shape()); + + auto x_shape = input_x->shape()->shape(); + auto y_shape = input_y->shape()->shape(); + auto out_shape = BroadcastShape(x_shape, y_shape); + if (out_shape.empty()) { + MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," + << args_spec_list[1]->ToString(); + } + + auto output_type = std::make_shared(); + auto ret = std::make_shared(output_type, out_shape); + return ret; +} + +AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto input_x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(input_x); + MS_EXCEPTION_IF_NULL(input_x->shape()); + + auto input_y = CheckArg(op_name, args_spec_list, 1); + 
MS_EXCEPTION_IF_NULL(input_y); + MS_EXCEPTION_IF_NULL(input_y->shape()); + + auto x_shape = input_x->shape()->shape(); + auto y_shape = input_y->shape()->shape(); + auto output_shape = BroadcastShape(x_shape, y_shape); + if (output_shape.empty()) { + MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," + << args_spec_list[1]->ToString(); + } + + auto x_type = input_x->BuildType(); + MS_EXCEPTION_IF_NULL(x_type); + MS_EXCEPTION_IF_NULL(x_type->cast()); + auto y_type = input_y->BuildType(); + MS_EXCEPTION_IF_NULL(y_type); + MS_EXCEPTION_IF_NULL(y_type->cast()); + + auto x_element = x_type->cast()->element(); + MS_EXCEPTION_IF_NULL(x_element); + auto y_element = y_type->cast()->element(); + MS_EXCEPTION_IF_NULL(y_element); + + auto x_element_type = x_element->number_type(); + auto y_element_type = y_element->number_type(); + + auto x_priority = type_priority_map.find(x_element_type); + if (x_priority == type_priority_map.end()) { + MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", it's not number type."; + } + auto y_priority = type_priority_map.find(y_element_type); + if (y_priority == type_priority_map.end()) { + MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", it's not number type."; + } + + if (x_priority->second >= y_priority->second) { + return std::make_shared(input_x->element(), std::make_shared(output_shape)); + } else { + return std::make_shared(input_y->element(), std::make_shared(output_shape)); + } +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 549c6085a3..9aa7406bae 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -440,5 +440,52 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP MS_EXCEPTION_IF_NULL(x->shape()); return std::make_shared(x->element(), std::make_shared(x->shape()->shape())); } + +AbstractBasePtr 
InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto input_x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(input_x); + MS_EXCEPTION_IF_NULL(input_x->shape()); + auto input_type = primitive->GetAttr("dst_type")->cast(); + auto ret = std::make_shared(input_type, input_x->shape()->shape()); + return ret; +} + +AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + + auto axis = CheckArg(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(axis); + + std::vector shape; + std::vector x_shape = x->shape()->shape(); + shape.insert(shape.end(), x_shape.begin(), x_shape.end()); + + auto axis_value = axis->BuildValue(); + if (!axis_value->isa()) { + MS_LOG(EXCEPTION) << axis_value << " axis_value should be tensor, but got " << axis_value->type_name(); + } + auto axis_tensor = axis_value->cast(); + int value = *(static_cast(axis_tensor->data_c())); + if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) { + MS_LOG(EXCEPTION) << " axis value should be in range [-input_x.dim-1,input_x.dim], but axis value is " << value + << " and input_x.dim is " << x_shape.size(); + } + if (value < 0) { + value = value + SizeToInt(x_shape.size()) + 1; + } + shape.insert(shape.begin() + value, 1); + + auto ret = std::make_shared(x->element(), std::make_shared(shape)); + return ret; +} + } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index acde4cfc71..ae18479c56 100644 --- 
a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -37,12 +37,13 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { // Maths {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, - {prim::kPrimMul, {InferImplMul, true}}, {prim::kPrimTensorAdd, {InferImplTensorAdd, true}}, {prim::kPrimSquare, {InferImplSquare, true}}, - {prim::kPrimSqrt, {InferImplSqrt, true}}, + {prim::kPrimSub, {InferImplSub, true}}, + {prim::kPrimEqual, {InferImplEqual, true}}, + {prim::kPrimMinimum, {InferImplMinimum, true}}, // Array {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, @@ -128,6 +129,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimAllGather, {InferImplAllGather, true}}, {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, + {prim::kPrimCast, {InferImplCast, true}}, + {prim::kPrimExpandDims, {InferImplExpandDims, true}}, }; return prim_eval_implement_map; } diff --git a/mindspore/core/ir/dtype/type.h b/mindspore/core/ir/dtype/type.h index 4e6ff01c1d..52ba22f300 100644 --- a/mindspore/core/ir/dtype/type.h +++ b/mindspore/core/ir/dtype/type.h @@ -127,6 +127,11 @@ const std::unordered_map type_name_map = { {kNumberTypeInt16, "int16"}, {kNumberTypeInt32, "int32"}, {kNumberTypeInt64, "int64"}, {kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"}}; +const std::unordered_map type_priority_map = { + {kNumberTypeBool, 0}, {kNumberTypeUInt8, 1}, {kNumberTypeInt8, 2}, + {kNumberTypeInt16, 3}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 5}, + {kNumberTypeFloat16, 6}, {kNumberTypeFloat32, 7}, {kNumberTypeFloat64, 8}}; + std::ostream &operator<<(std::ostream &os, const TypePtrList &types); } // namespace mindspore diff --git a/mindspore/ops/_op_impl/tbe/__init__.py 
b/mindspore/ops/_op_impl/tbe/__init__.py index 8042f10026..3ce89abed3 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -50,6 +50,7 @@ from .batchnorm_grad import _batch_norm_grad_tbe from .bias_add import _bias_add_tbe from .bias_add_grad import _bias_add_grad_tbe from .cast import _cast_tbe +from .cast_ds import _cast_ds_tbe from .conv2d import _conv2d_tbe from .conv2d_backprop_filter import _conv2d_backprop_filter_tbe from .conv2d_backprop_input import _conv2d_backprop_input_tbe @@ -80,6 +81,7 @@ from .trans_data import _trans_data_tbe from .top_k import _top_k_tbe from .matmul import _matmul_tbe from .sub import _sub_tbe +from .sub_ds import _sub_ds_tbe from .scatter_nd import _scatter_nd_tbe from .scatter_nd_d import _scatter_nd_d_tbe from .scatter_nd_add import _scatter_nd_add_tbe @@ -121,6 +123,7 @@ from .npu_get_float_status import _npu_get_float_status_tbe from .npu_alloc_float_status import _npu_alloc_float_status_tbe from .one_hot import _one_hot_tbe from .equal import _equal_tbe +from .equal_ds import _equal_ds_tbe from .less import _less_tbe from .less_equal import _less_equal_tbe from .logical_and import _logical_and_tbe @@ -159,6 +162,7 @@ from .select import _select_tbe from .pow import _pow_tbe from .maximum import _maximum_tbe from .minimum import _minimum_tbe +from .minimum_ds import _minimum_ds_tbe from .minimum_grad import _minimum_grad_tbe from .maximum_grad import _maximum_grad_tbe from .concat import _concat_tbe diff --git a/mindspore/ops/_op_impl/tbe/cast_ds.py b/mindspore/ops/_op_impl/tbe/cast_ds.py new file mode 100644 index 0000000000..bb9c472a07 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/cast_ds.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Cast op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +cast_ds_op_info = TBERegOp("Cast") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("cast.so") \ + .compute_cost(10) \ + .kernel_name("cast") \ + .partial_flag(True) \ + .attr("dst_type", "required", "int", "all") \ + .dynamic_shape(True)\ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("formatAgnostic") \ + .dtype_format(DataType.BOOL_None, DataType.F16_None) \ + .dtype_format(DataType.BOOL_None, DataType.U8_None) \ + .dtype_format(DataType.BOOL_None, DataType.F32_None) \ + .dtype_format(DataType.BOOL_None, DataType.I32_None) \ + .dtype_format(DataType.I8_None, DataType.F16_None) \ + .dtype_format(DataType.I8_None, DataType.F32_None) \ + .dtype_format(DataType.I8_None, DataType.I32_None) \ + .dtype_format(DataType.U8_None, DataType.F16_None) \ + .dtype_format(DataType.U8_None, DataType.F32_None) \ + .dtype_format(DataType.U8_None, DataType.I32_None) \ + .dtype_format(DataType.I32_None, DataType.BOOL_None) \ + .dtype_format(DataType.I32_None, DataType.F16_None) \ + .dtype_format(DataType.I32_None, DataType.F32_None) \ + .dtype_format(DataType.I32_None, DataType.I8_None) \ + .dtype_format(DataType.I32_None, DataType.U8_None) \ + .dtype_format(DataType.F16_None, DataType.U8_None) \ + .dtype_format(DataType.F16_None, DataType.F32_None) \ + .dtype_format(DataType.F16_None, DataType.I32_None) \ + 
.dtype_format(DataType.F32_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.I32_None) \ + .get_op_info() + + +@op_info_register(cast_ds_op_info) +def _cast_ds_tbe(): + """Cast TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/equal_ds.py b/mindspore/ops/_op_impl/tbe/equal_ds.py new file mode 100644 index 0000000000..76a218c14f --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/equal_ds.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Equal op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +equal_ds_op_info = TBERegOp("Equal") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("equal.so") \ + .compute_cost(10) \ + .kernel_name("equal") \ + .partial_flag(True) \ + .dynamic_shape(True)\ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ + .get_op_info() + + +@op_info_register(equal_ds_op_info) +def _equal_ds_tbe(): + """Equal TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/minimum_ds.py b/mindspore/ops/_op_impl/tbe/minimum_ds.py new file mode 100644 index 0000000000..3dbe6a710d --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/minimum_ds.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + + +"""Minimum op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +minimum_ds_op_info = TBERegOp("Minimum") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("minimum.so") \ + .compute_cost(10) \ + .kernel_name("minimum") \ + .partial_flag(True) \ + .dynamic_shape(True)\ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(minimum_ds_op_info) +def _minimum_ds_tbe(): + """Minimum TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/sub_ds.py b/mindspore/ops/_op_impl/tbe/sub_ds.py new file mode 100644 index 0000000000..b38fdb098d --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/sub_ds.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Sub op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sub_ds_op_info = TBERegOp("Sub") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("sub.so") \ + .compute_cost(10) \ + .kernel_name("sub") \ + .partial_flag(True) \ + .dynamic_shape(True)\ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(sub_ds_op_info) +def _sub_ds_tbe(): + """Sub TBE register""" + return