From: @hwjiaorui Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -43,6 +43,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An | |||
| todos.push_back(node); | |||
| } | |||
| std::set<string> DynamicShapeConstInputToAttr = {kCastOpName, kExpandDimsOpName}; | |||
| for (auto &t : todos) { | |||
| CNodePtr cnode = t->cast<CNodePtr>(); | |||
| ConstInputToAttrInfoRegister reg; | |||
| @@ -61,7 +62,8 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An | |||
| continue; | |||
| } | |||
| } | |||
| if (AnfAlgo::IsDynamicShape(cnode)) { | |||
| if (AnfAlgo::IsDynamicShape(cnode) && | |||
| DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) { | |||
| MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope(); | |||
| continue; | |||
| } | |||
| @@ -223,7 +223,16 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv | |||
| AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list or dict. | |||
| @@ -96,5 +96,132 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| MS_EXCEPTION_IF_NULL(input_x->shape()); | |||
| auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(input_y); | |||
| MS_EXCEPTION_IF_NULL(input_y->shape()); | |||
| auto x_shape = input_x->shape()->shape(); | |||
| auto y_shape = input_y->shape()->shape(); | |||
| auto output_shape = BroadcastShape(x_shape, y_shape); | |||
| if (output_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," | |||
| << args_spec_list[1]->ToString(); | |||
| } | |||
| auto x_type = input_x->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(x_type); | |||
| MS_EXCEPTION_IF_NULL(x_type->cast<TensorTypePtr>()); | |||
| auto y_type = input_y->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(y_type); | |||
| MS_EXCEPTION_IF_NULL(y_type->cast<TensorTypePtr>()); | |||
| auto x_element = x_type->cast<TensorTypePtr>()->element(); | |||
| MS_EXCEPTION_IF_NULL(x_element); | |||
| auto y_element = y_type->cast<TensorTypePtr>()->element(); | |||
| MS_EXCEPTION_IF_NULL(y_element); | |||
| auto x_element_type = x_element->number_type(); | |||
| auto y_element_type = y_element->number_type(); | |||
| auto x_priority = type_priority_map.find(x_element_type); | |||
| if (x_priority == type_priority_map.end()) { | |||
| MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", it's not number type."; | |||
| } | |||
| auto y_priority = type_priority_map.find(y_element_type); | |||
| if (y_priority == type_priority_map.end()) { | |||
| MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", it's not number type."; | |||
| } | |||
| if (x_priority->second >= y_priority->second) { | |||
| return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(output_shape)); | |||
| } else { | |||
| return std::make_shared<AbstractTensor>(input_y->element(), std::make_shared<Shape>(output_shape)); | |||
| } | |||
| } | |||
| AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| MS_EXCEPTION_IF_NULL(input_x->shape()); | |||
| auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(input_y); | |||
| MS_EXCEPTION_IF_NULL(input_y->shape()); | |||
| auto x_shape = input_x->shape()->shape(); | |||
| auto y_shape = input_y->shape()->shape(); | |||
| auto out_shape = BroadcastShape(x_shape, y_shape); | |||
| if (out_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," | |||
| << args_spec_list[1]->ToString(); | |||
| } | |||
| auto output_type = std::make_shared<Bool>(); | |||
| auto ret = std::make_shared<AbstractTensor>(output_type, out_shape); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| MS_EXCEPTION_IF_NULL(input_x->shape()); | |||
| auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(input_y); | |||
| MS_EXCEPTION_IF_NULL(input_y->shape()); | |||
| auto x_shape = input_x->shape()->shape(); | |||
| auto y_shape = input_y->shape()->shape(); | |||
| auto output_shape = BroadcastShape(x_shape, y_shape); | |||
| if (output_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," | |||
| << args_spec_list[1]->ToString(); | |||
| } | |||
| auto x_type = input_x->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(x_type); | |||
| MS_EXCEPTION_IF_NULL(x_type->cast<TensorTypePtr>()); | |||
| auto y_type = input_y->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(y_type); | |||
| MS_EXCEPTION_IF_NULL(y_type->cast<TensorTypePtr>()); | |||
| auto x_element = x_type->cast<TensorTypePtr>()->element(); | |||
| MS_EXCEPTION_IF_NULL(x_element); | |||
| auto y_element = y_type->cast<TensorTypePtr>()->element(); | |||
| MS_EXCEPTION_IF_NULL(y_element); | |||
| auto x_element_type = x_element->number_type(); | |||
| auto y_element_type = y_element->number_type(); | |||
| auto x_priority = type_priority_map.find(x_element_type); | |||
| if (x_priority == type_priority_map.end()) { | |||
| MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", it's not number type."; | |||
| } | |||
| auto y_priority = type_priority_map.find(y_element_type); | |||
| if (y_priority == type_priority_map.end()) { | |||
| MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", it's not number type."; | |||
| } | |||
| if (x_priority->second >= y_priority->second) { | |||
| return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(output_shape)); | |||
| } else { | |||
| return std::make_shared<AbstractTensor>(input_y->element(), std::make_shared<Shape>(output_shape)); | |||
| } | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -440,5 +440,52 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP | |||
| MS_EXCEPTION_IF_NULL(x->shape()); | |||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape())); | |||
| } | |||
| AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| MS_EXCEPTION_IF_NULL(input_x->shape()); | |||
| auto input_type = primitive->GetAttr("dst_type")->cast<TypePtr>(); | |||
| auto ret = std::make_shared<AbstractTensor>(input_type, input_x->shape()->shape()); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| MS_EXCEPTION_IF_NULL(x->shape()); | |||
| auto axis = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(axis); | |||
| std::vector<int64_t> shape; | |||
| std::vector<int64_t> x_shape = x->shape()->shape(); | |||
| shape.insert(shape.end(), x_shape.begin(), x_shape.end()); | |||
| auto axis_value = axis->BuildValue(); | |||
| if (!axis_value->isa<tensor::Tensor>()) { | |||
| MS_LOG(EXCEPTION) << axis_value << " axis_value should be tensor, but got " << axis_value->type_name(); | |||
| } | |||
| auto axis_tensor = axis_value->cast<tensor::TensorPtr>(); | |||
| int value = *(static_cast<int *>(axis_tensor->data_c())); | |||
| if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) { | |||
| MS_LOG(EXCEPTION) << " axis value shoud be in range [-intput_x.dim-1,input_x.dim], but axis value is" << value | |||
| << " and input_x.dim is" << x_shape.size(); | |||
| } | |||
| if (value < 0) { | |||
| value = value + SizeToInt(x_shape.size()) + 1; | |||
| } | |||
| shape.insert(shape.begin() + value, 1); | |||
| auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape)); | |||
| return ret; | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -37,12 +37,13 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| // Maths | |||
| {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| {prim::kPrimMul, {InferImplMul, true}}, | |||
| {prim::kPrimTensorAdd, {InferImplTensorAdd, true}}, | |||
| {prim::kPrimSquare, {InferImplSquare, true}}, | |||
| {prim::kPrimSqrt, {InferImplSqrt, true}}, | |||
| {prim::kPrimSub, {InferImplSub, true}}, | |||
| {prim::kPrimEqual, {InferImplEqual, true}}, | |||
| {prim::kPrimMinimum, {InferImplMinimum, true}}, | |||
| // Array | |||
| {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, | |||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | |||
| @@ -128,6 +129,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimAllGather, {InferImplAllGather, true}}, | |||
| {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, | |||
| {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, | |||
| {prim::kPrimCast, {InferImplCast, true}}, | |||
| {prim::kPrimExpandDims, {InferImplExpandDims, true}}, | |||
| }; | |||
| return prim_eval_implement_map; | |||
| } | |||
| @@ -127,6 +127,11 @@ const std::unordered_map<TypeId, std::string> type_name_map = { | |||
| {kNumberTypeInt16, "int16"}, {kNumberTypeInt32, "int32"}, {kNumberTypeInt64, "int64"}, | |||
| {kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"}}; | |||
| const std::unordered_map<TypeId, int> type_priority_map = { | |||
| {kNumberTypeBool, 0}, {kNumberTypeUInt8, 1}, {kNumberTypeInt8, 2}, | |||
| {kNumberTypeInt16, 3}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 5}, | |||
| {kNumberTypeFloat16, 6}, {kNumberTypeFloat32, 7}, {kNumberTypeFloat64, 8}}; | |||
| std::ostream &operator<<(std::ostream &os, const TypePtrList &types); | |||
| } // namespace mindspore | |||
| @@ -50,6 +50,7 @@ from .batchnorm_grad import _batch_norm_grad_tbe | |||
| from .bias_add import _bias_add_tbe | |||
| from .bias_add_grad import _bias_add_grad_tbe | |||
| from .cast import _cast_tbe | |||
| from .cast_ds import _cast_ds_tbe | |||
| from .conv2d import _conv2d_tbe | |||
| from .conv2d_backprop_filter import _conv2d_backprop_filter_tbe | |||
| from .conv2d_backprop_input import _conv2d_backprop_input_tbe | |||
| @@ -80,6 +81,7 @@ from .trans_data import _trans_data_tbe | |||
| from .top_k import _top_k_tbe | |||
| from .matmul import _matmul_tbe | |||
| from .sub import _sub_tbe | |||
| from .sub_ds import _sub_ds_tbe | |||
| from .scatter_nd import _scatter_nd_tbe | |||
| from .scatter_nd_d import _scatter_nd_d_tbe | |||
| from .scatter_nd_add import _scatter_nd_add_tbe | |||
| @@ -121,6 +123,7 @@ from .npu_get_float_status import _npu_get_float_status_tbe | |||
| from .npu_alloc_float_status import _npu_alloc_float_status_tbe | |||
| from .one_hot import _one_hot_tbe | |||
| from .equal import _equal_tbe | |||
| from .equal_ds import _equal_ds_tbe | |||
| from .less import _less_tbe | |||
| from .less_equal import _less_equal_tbe | |||
| from .logical_and import _logical_and_tbe | |||
| @@ -159,6 +162,7 @@ from .select import _select_tbe | |||
| from .pow import _pow_tbe | |||
| from .maximum import _maximum_tbe | |||
| from .minimum import _minimum_tbe | |||
| from .minimum_ds import _minimum_ds_tbe | |||
| from .minimum_grad import _minimum_grad_tbe | |||
| from .maximum_grad import _maximum_grad_tbe | |||
| from .concat import _concat_tbe | |||
| @@ -0,0 +1,57 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Cast op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| cast_ds_op_info = TBERegOp("Cast") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("cast.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("cast") \ | |||
| .partial_flag(True) \ | |||
| .attr("dst_type", "required", "int", "all") \ | |||
| .dynamic_shape(True)\ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .dtype_format(DataType.BOOL_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.BOOL_None, DataType.U8_None) \ | |||
| .dtype_format(DataType.BOOL_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.BOOL_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.I8_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.I8_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.I8_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.U8_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.U8_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.U8_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.BOOL_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.I8_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.U8_None) \ | |||
| .dtype_format(DataType.F16_None, DataType.U8_None) \ | |||
| .dtype_format(DataType.F16_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.F16_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.I32_None) \ | |||
| .get_op_info() | |||
| @op_info_register(cast_ds_op_info) | |||
| def _cast_ds_tbe(): | |||
| """Cast TBE register""" | |||
| return | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Equal op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| equal_ds_op_info = TBERegOp("Equal") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("equal.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("equal") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True)\ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("broadcast") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(equal_ds_op_info) | |||
| def _equal_ds_tbe(): | |||
| """Equal TBE register""" | |||
| return | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Minimum op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| minimum_ds_op_info = TBERegOp("Minimum") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("minimum.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("minimum") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True)\ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("broadcast") \ | |||
| .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||
| .get_op_info() | |||
| @op_info_register(minimum_ds_op_info) | |||
| def _minimum_ds_tbe(): | |||
| """Minimum TBE register""" | |||
| return | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Sub op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| sub_ds_op_info = TBERegOp("Sub") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("sub.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("sub") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True)\ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("broadcast") \ | |||
| .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||
| .get_op_info() | |||
| @op_info_register(sub_ds_op_info) | |||
| def _sub_ds_tbe(): | |||
| """Add TBE register""" | |||
| return | |||