| @@ -61,6 +61,7 @@ class OpIOInfo { | |||||
| std::string name() const { return name_; } | std::string name() const { return name_; } | ||||
| bool need_compile() const { return need_compile_; } | bool need_compile() const { return need_compile_; } | ||||
| std::string param_type() const { return param_type_; } | std::string param_type() const { return param_type_; } | ||||
| std::string reshape_type() const { return reshape_type_; } | |||||
| std::string shape() const { return shape_; } | std::string shape() const { return shape_; } | ||||
| std::vector<std::string> dtypes() const { return dtypes_; } | std::vector<std::string> dtypes() const { return dtypes_; } | ||||
| std::vector<std::string> formats() const { return formats_; } | std::vector<std::string> formats() const { return formats_; } | ||||
| @@ -69,6 +70,7 @@ class OpIOInfo { | |||||
| void set_name(const std::string& name) { name_ = name; } | void set_name(const std::string& name) { name_ = name; } | ||||
| void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } | void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } | ||||
| void set_param_type(const std::string& param_type) { param_type_ = param_type; } | void set_param_type(const std::string& param_type) { param_type_ = param_type; } | ||||
| void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; } | |||||
| void set_shape(const std::string& shape) { shape_ = shape; } | void set_shape(const std::string& shape) { shape_ = shape; } | ||||
| void set_dtypes(const std::vector<std::string>& dtype) { dtypes_ = dtype; } | void set_dtypes(const std::vector<std::string>& dtype) { dtypes_ = dtype; } | ||||
| void set_formats(const std::vector<std::string>& formats) { formats_ = formats; } | void set_formats(const std::vector<std::string>& formats) { formats_ = formats; } | ||||
| @@ -78,6 +80,7 @@ class OpIOInfo { | |||||
| std::string name_; | std::string name_; | ||||
| bool need_compile_ = false; | bool need_compile_ = false; | ||||
| std::string param_type_; | std::string param_type_; | ||||
| std::string reshape_type_; | |||||
| std::string shape_; | std::string shape_; | ||||
| std::vector<std::string> dtypes_; | std::vector<std::string> dtypes_; | ||||
| std::vector<std::string> formats_; | std::vector<std::string> formats_; | ||||
| @@ -96,6 +99,8 @@ class OpInfo { | |||||
| int compute_cost() const { return compute_cost_; } | int compute_cost() const { return compute_cost_; } | ||||
| std::string kernel_name() const { return kernel_name_; } | std::string kernel_name() const { return kernel_name_; } | ||||
| bool partial_flag() const { return partial_flag_; } | bool partial_flag() const { return partial_flag_; } | ||||
| bool dynamic_format() const { return dynamic_format_; } | |||||
| std::string op_pattern() const { return op_pattern_; } | |||||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | ||||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | ||||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | ||||
| @@ -110,6 +115,8 @@ class OpInfo { | |||||
| void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } | void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } | ||||
| void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; } | void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; } | ||||
| void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | ||||
| void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } | |||||
| void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } | |||||
| void add_attrs_ptr(const std::shared_ptr<OpAttr>& attr) { attrs_ptr_.push_back(attr); } | void add_attrs_ptr(const std::shared_ptr<OpAttr>& attr) { attrs_ptr_.push_back(attr); } | ||||
| void add_inputs_ptr(const std::shared_ptr<OpIOInfo>& input) { inputs_ptr_.push_back(input); } | void add_inputs_ptr(const std::shared_ptr<OpIOInfo>& input) { inputs_ptr_.push_back(input); } | ||||
| void add_outputs_ptr(const std::shared_ptr<OpIOInfo>& output) { outputs_ptr_.push_back(output); } | void add_outputs_ptr(const std::shared_ptr<OpIOInfo>& output) { outputs_ptr_.push_back(output); } | ||||
| @@ -129,6 +136,8 @@ class OpInfo { | |||||
| int compute_cost_ = 0; | int compute_cost_ = 0; | ||||
| std::string kernel_name_; | std::string kernel_name_; | ||||
| bool partial_flag_ = false; | bool partial_flag_ = false; | ||||
| bool dynamic_format_ = false; | |||||
| std::string op_pattern_; | |||||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; | std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; | ||||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_; | std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_; | ||||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_; | std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_; | ||||
| @@ -26,18 +26,22 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| constexpr auto kImplyType = "imply_type"; | constexpr auto kImplyType = "imply_type"; | ||||
| constexpr auto kOpName = "op_name"; | constexpr auto kOpName = "op_name"; | ||||
| constexpr auto kTbe = "TBE"; | |||||
| constexpr auto kAkg = "akg"; | |||||
| constexpr auto kAutodiff = "AutoDiff"; | |||||
| constexpr auto kFusionType = "fusion_type"; | constexpr auto kFusionType = "fusion_type"; | ||||
| constexpr auto kAsyncFlag = "async_flag"; | constexpr auto kAsyncFlag = "async_flag"; | ||||
| constexpr auto kBinfileName = "binfile_name"; | constexpr auto kBinfileName = "binfile_name"; | ||||
| constexpr auto kComputeCost = "compute_cost"; | constexpr auto kComputeCost = "compute_cost"; | ||||
| constexpr auto kKernelName = "kernel_name"; | constexpr auto kKernelName = "kernel_name"; | ||||
| constexpr auto kPartialFlag = "partial_flag"; | constexpr auto kPartialFlag = "partial_flag"; | ||||
| constexpr auto kReshapeType = "reshape_type"; | |||||
| constexpr auto kOpPattern = "op_pattern"; | |||||
| constexpr auto kDynamicFormat = "dynamic_format"; | |||||
| constexpr auto kDtypeFormat = "dtype_format"; | |||||
| constexpr auto kAttr = "attr"; | constexpr auto kAttr = "attr"; | ||||
| constexpr auto kIputs = "inputs"; | constexpr auto kIputs = "inputs"; | ||||
| constexpr auto kOutputs = "outputs"; | constexpr auto kOutputs = "outputs"; | ||||
| constexpr auto kTbe = "TBE"; | |||||
| constexpr auto kAkg = "akg"; | |||||
| constexpr auto kAutodiff = "AutoDiff"; | |||||
| constexpr auto kName = "name"; | constexpr auto kName = "name"; | ||||
| constexpr auto kParamType = "param_type"; | constexpr auto kParamType = "param_type"; | ||||
| constexpr auto kDtype = "dtype"; | constexpr auto kDtype = "dtype"; | ||||
| @@ -89,8 +93,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI | |||||
| std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>(); | std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>(); | ||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| op_info->set_op_name(obj.at(kOpName)); | op_info->set_op_name(obj.at(kOpName)); | ||||
| op_info->set_imply_type(imply_type); | |||||
| op_info->set_impl_path(impl_path); | op_info->set_impl_path(impl_path); | ||||
| op_info->set_imply_type(imply_type); | |||||
| op_info->set_fusion_type(obj.at(kFusionType)); | op_info->set_fusion_type(obj.at(kFusionType)); | ||||
| if (imply_type == kTBE) { | if (imply_type == kTBE) { | ||||
| op_info->set_async_flag(obj.at(kAsyncFlag)); | op_info->set_async_flag(obj.at(kAsyncFlag)); | ||||
| @@ -98,6 +102,12 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI | |||||
| op_info->set_compute_cost(obj.at(kComputeCost)); | op_info->set_compute_cost(obj.at(kComputeCost)); | ||||
| op_info->set_kernel_name(obj.at(kKernelName)); | op_info->set_kernel_name(obj.at(kKernelName)); | ||||
| op_info->set_partial_flag(obj.at(kPartialFlag)); | op_info->set_partial_flag(obj.at(kPartialFlag)); | ||||
| if (obj.find(kOpPattern) != obj.end()) { | |||||
| op_info->set_op_pattern(obj.at(kOpPattern)); | |||||
| } | |||||
| if (obj.find(kDynamicFormat) != obj.end()) { | |||||
| op_info->set_dynamic_format(obj.at(kDynamicFormat)); | |||||
| } | |||||
| } | } | ||||
| auto attrs = obj.at(kAttr); | auto attrs = obj.at(kAttr); | ||||
| for (const auto& attr : attrs) { | for (const auto& attr : attrs) { | ||||
| @@ -106,16 +116,20 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| nlohmann::json dtype_format; | |||||
| if (obj.find(kDtypeFormat) != obj.end()) { | |||||
| dtype_format = obj.at(kDtypeFormat); | |||||
| } | |||||
| auto inputs = obj.at(kIputs); | auto inputs = obj.at(kIputs); | ||||
| for (const auto& input : inputs) { | for (const auto& input : inputs) { | ||||
| if (!DecodeInputOutput(input, imply_type, kInput, op_info)) { | |||||
| if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { | |||||
| MS_LOG(DEBUG) << "DecodeInputOutput Failed"; | MS_LOG(DEBUG) << "DecodeInputOutput Failed"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| auto outputs = obj.at(kOutputs); | auto outputs = obj.at(kOutputs); | ||||
| for (const auto& output : outputs) { | for (const auto& output : outputs) { | ||||
| if (!DecodeInputOutput(output, imply_type, kOutput, op_info)) { | |||||
| if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { | |||||
| MS_LOG(DEBUG) << "DecodeInputOutput Failed"; | MS_LOG(DEBUG) << "DecodeInputOutput Failed"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -156,16 +170,42 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io, | |||||
| size_t index) { | |||||
| bool ret = true; | |||||
| try { | |||||
| std::vector<std::string> dtype; | |||||
| std::vector<std::string> format; | |||||
| for (const auto& it : dtype_format) { | |||||
| dtype.emplace_back(it[index][0]); | |||||
| format.emplace_back(it[index][1]); | |||||
| } | |||||
| op_io->set_dtypes(dtype); | |||||
| op_io->set_formats(format); | |||||
| } catch (const std::exception& e) { | |||||
| MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); | |||||
| ret = false; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, | bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, | ||||
| const std::shared_ptr<OpInfo>& op_info) { | |||||
| const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format) { | |||||
| bool ret = true; | bool ret = true; | ||||
| try { | try { | ||||
| std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>(); | std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>(); | ||||
| MS_EXCEPTION_IF_NULL(op_io); | MS_EXCEPTION_IF_NULL(op_io); | ||||
| op_io->set_index(obj.at(kIndex)); | op_io->set_index(obj.at(kIndex)); | ||||
| op_io->set_name(obj.at(kName)); | op_io->set_name(obj.at(kName)); | ||||
| op_io->set_dtypes(obj.at(kDtype)); | |||||
| op_io->set_formats(obj.at(kFormat)); | |||||
| if (!dtype_format.empty()) { | |||||
| if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) { | |||||
| MS_LOG(ERROR) << "Decode dtype format failed"; | |||||
| return false; | |||||
| } | |||||
| } else { | |||||
| op_io->set_dtypes(obj.at(kDtype)); | |||||
| op_io->set_formats(obj.at(kFormat)); | |||||
| } | |||||
| if (op_io->dtypes().size() != op_io->formats().size()) { | if (op_io->dtypes().size() != op_io->formats().size()) { | ||||
| MS_LOG(DEBUG) << "op" << op_io->name() << "dtype size:" << op_io->dtypes() | MS_LOG(DEBUG) << "op" << op_io->name() << "dtype size:" << op_io->dtypes() | ||||
| << "is not equal to format size:" << op_io->formats(); | << "is not equal to format size:" << op_io->formats(); | ||||
| @@ -181,6 +221,9 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply | |||||
| if (obj.find(kShape) != obj.end()) { | if (obj.find(kShape) != obj.end()) { | ||||
| op_io->set_shape(obj.at(kShape)); | op_io->set_shape(obj.at(kShape)); | ||||
| } | } | ||||
| if (obj.find(kReshapeType) != obj.end()) { | |||||
| op_io->set_reshape_type(obj.at(kReshapeType)); | |||||
| } | |||||
| } | } | ||||
| if (io_type == kInput) { | if (io_type == kInput) { | ||||
| @@ -38,8 +38,10 @@ class OpLib { | |||||
| static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path); | static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path); | ||||
| static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | ||||
| const std::shared_ptr<OpInfo>& op_info); | const std::shared_ptr<OpInfo>& op_info); | ||||
| static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io, | |||||
| size_t index); | |||||
| static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, | static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, | ||||
| const std::shared_ptr<OpInfo>& op_info); | |||||
| const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format); | |||||
| static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info); | static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info); | ||||
| static bool CheckRepetition(const std::shared_ptr<OpInfo>& op_info); | static bool CheckRepetition(const std::shared_ptr<OpInfo>& op_info); | ||||
| }; | }; | ||||
| @@ -30,7 +30,7 @@ Note: | |||||
| from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | ||||
| from .op_info_register import op_info_register | |||||
| from .op_info_register import op_info_register, TBERegOp, DataType | |||||
| from .primitive import constexpr | from .primitive import constexpr | ||||
| from .._c_expression import signature_rw, signature_kind | from .._c_expression import signature_rw, signature_kind | ||||
| @@ -40,6 +40,6 @@ __primitive__ = [ | |||||
| ] | ] | ||||
| __all__ = ["get_vm_impl_fn", "vm_impl_registry", | __all__ = ["get_vm_impl_fn", "vm_impl_registry", | ||||
| "op_info_register", | |||||
| "op_info_register", "TBERegOp", "DataType", | |||||
| "constexpr"] | "constexpr"] | ||||
| __all__.extend(__primitive__) | __all__.extend(__primitive__) | ||||
| @@ -14,208 +14,41 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """AdamApplyOneWithDecay op""" | """AdamApplyOneWithDecay op""" | ||||
| from mindspore.ops.op_info_register import op_info_register | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| adam_apply_one_with_decay_op_info = TBERegOp("AdamApplyOneWithDecay") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("adam_apply_one_with_decay.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("adam_apply_one_with_decay") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "input0", False, "required", "all") \ | |||||
| .input(1, "input1", False, "required", "all") \ | |||||
| .input(2, "input2", False, "required", "all") \ | |||||
| .input(3, "input3", False, "required", "all") \ | |||||
| .input(4, "input4", False, "required", "all") \ | |||||
| .input(5, "mul0_x", False, "required", "all") \ | |||||
| .input(6, "mul1_x", False, "required", "all") \ | |||||
| .input(7, "mul2_x", False, "required", "all") \ | |||||
| .input(8, "mul3_x", False, "required", "all") \ | |||||
| .input(9, "mul4_x", False, "required", "all") \ | |||||
| .input(10, "add2_y", False, "required", "all") \ | |||||
| .output(0, "output0", False, "required", "all") \ | |||||
| .output(1, "output1", False, "required", "all") \ | |||||
| .output(2, "output2", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register("""{ | |||||
| "op_name": "AdamApplyOneWithDecay", | |||||
| "imply_type": "TBE", | |||||
| "fusion_type": "OPAQUE", | |||||
| "async_flag": false, | |||||
| "binfile_name": "adam_apply_one_with_decay.so", | |||||
| "compute_cost": 10, | |||||
| "kernel_name": "adam_apply_one_with_decay", | |||||
| "partial_flag": true, | |||||
| "attr": [ | |||||
| ], | |||||
| "inputs": [ | |||||
| { | |||||
| "index": 0, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "input0", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 1, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "input1", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 2, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "input2", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 3, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "input3", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 4, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "input4", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 5, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "mul0_x", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 6, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "mul1_x", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 7, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "mul2_x", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 8, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "mul3_x", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 9, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "mul4_x", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 10, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "add2_y", | |||||
| "need_compile": false, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| } | |||||
| ], | |||||
| "outputs": [ | |||||
| { | |||||
| "index": 0, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "output0", | |||||
| "need_compile": true, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 1, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "output1", | |||||
| "need_compile": true, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| }, | |||||
| { | |||||
| "index": 2, | |||||
| "dtype": [ | |||||
| "float16", "float" | |||||
| ], | |||||
| "format": [ | |||||
| "DefaultFormat", "DefaultFormat" | |||||
| ], | |||||
| "name": "output2", | |||||
| "need_compile": true, | |||||
| "param_type": "required", | |||||
| "shape": "all" | |||||
| } | |||||
| ] | |||||
| }""") | |||||
| @op_info_register(adam_apply_one_with_decay_op_info) | |||||
| def _adam_apply_one_with_decay_tbe(): | def _adam_apply_one_with_decay_tbe(): | ||||
| """AdamApplyOneWithDecay TBE register""" | """AdamApplyOneWithDecay TBE register""" | ||||
| return | return | ||||
| @@ -16,6 +16,7 @@ | |||||
| """Operators info register.""" | """Operators info register.""" | ||||
| import os | import os | ||||
| import json | |||||
| import inspect | import inspect | ||||
| from mindspore._c_expression import Oplib | from mindspore._c_expression import Oplib | ||||
| from mindspore._checkparam import ParamValidator as validator | from mindspore._checkparam import ParamValidator as validator | ||||
| @@ -32,21 +33,453 @@ def op_info_register(op_info): | |||||
| 'op_info' must be a str of json format represent the op info, the op info will be added into oplib. | 'op_info' must be a str of json format represent the op info, the op info will be added into oplib. | ||||
| Args: | Args: | ||||
| op_info (str): op info of json format. | |||||
| op_info (str or dict): op info of json format. | |||||
| Returns: | Returns: | ||||
| Function, returns a decorator for op info register. | Function, returns a decorator for op info register. | ||||
| """ | """ | ||||
| def register_decorator(func): | def register_decorator(func): | ||||
| validator.check_type("op_info", op_info, [str]) | |||||
| if isinstance(op_info, dict): | |||||
| op_info_real = json.dumps(op_info) | |||||
| else: | |||||
| op_info_real = op_info | |||||
| validator.check_type("op_info", op_info_real, [str]) | |||||
| op_lib = Oplib() | op_lib = Oplib() | ||||
| file_path = os.path.realpath(inspect.getfile(func)) | file_path = os.path.realpath(inspect.getfile(func)) | ||||
| # keep the path custom ops implementation. | # keep the path custom ops implementation. | ||||
| imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path | imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path | ||||
| if not op_lib.reg_op(op_info, imply_path): | |||||
| raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info)) | |||||
| if not op_lib.reg_op(op_info_real, imply_path): | |||||
| raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info_real)) | |||||
| def wrapped_function(*args, **kwargs): | def wrapped_function(*args, **kwargs): | ||||
| return func(*args, **kwargs) | return func(*args, **kwargs) | ||||
| return wrapped_function | return wrapped_function | ||||
| return register_decorator | return register_decorator | ||||
| class RegOp(): | |||||
| """ | |||||
| Base class for op info register. | |||||
| Args: | |||||
| op_name (str): Name of op. | |||||
| inputs (list): Inputs inoformation of the op. | |||||
| outputs (list): Outputs information of the op. | |||||
| attr_ (list): Attribute information of the op. | |||||
| dtype_format_ (list): Dtype and format information of the op. | |||||
| """ | |||||
| def __init__(self, op_name=""): | |||||
| if not isinstance(op_name, str): | |||||
| raise ValueError("op name value must be string") | |||||
| if not op_name.strip(): | |||||
| raise ValueError("op name is empty") | |||||
| self.op_name = op_name | |||||
| self.inputs = [] | |||||
| self.outputs = [] | |||||
| self.attr_ = [] | |||||
| self.dtype_format_ = [] | |||||
| def is_string(self, value): | |||||
| """ | |||||
| Check if the value is a str type. | |||||
| Args: | |||||
| value: Parameter to to check. | |||||
| Raises: | |||||
| TypeError: If the type of value is not a str. | |||||
| """ | |||||
| if not isinstance(value, str): | |||||
| raise TypeError("%s value must be str" % str(value)) | |||||
| def is_int(self, value): | |||||
| """ | |||||
| Check if the value is a int. | |||||
| Args: | |||||
| value: Parameter to to check. | |||||
| Raises: | |||||
| TypeError: If the type of value is not a int. | |||||
| """ | |||||
| if not isinstance(value, int): | |||||
| raise TypeError("%s value must be int" % str(value)) | |||||
| def is_bool(self, value): | |||||
| """ | |||||
| Check if the value is a bool. | |||||
| Args: | |||||
| value: Parameter to to check. | |||||
| Raises: | |||||
| TypeError: If the type of value is not a bool. | |||||
| """ | |||||
| if not isinstance(value, bool): | |||||
| raise TypeError("%s value must be bool" % str(value)) | |||||
| def dtype_format(self, *args): | |||||
| """ | |||||
| Register dtype and format. | |||||
| Args: | |||||
| args (tuple): Value of dtype and format. | |||||
| Raises: | |||||
| ValueError: If the size of args not equal to input size add output size. | |||||
| TypeError: If the type of args is not tuple. | |||||
| """ | |||||
| if len(self.inputs) + len(self.outputs) != len(args): | |||||
| raise ValueError("input size add output size must be equal to detype format size") | |||||
| dtype_format = [] | |||||
| for arg in args: | |||||
| if not isinstance(arg, tuple) or len(arg) != 2: | |||||
| raise ValueError("dtype and format value must be tuple of two elements") | |||||
| self.is_string(arg[0]) | |||||
| self.is_string(arg[1]) | |||||
| dtype_format.append(arg) | |||||
| self.dtype_format_.append(tuple(dtype_format)) | |||||
| return self | |||||
| def get_op_info(self): | |||||
| """ | |||||
| Return all registration information for this instance. | |||||
| The '_' character ending the key is removed here for compatibility with previous version. | |||||
| Key will be unified into an underlined form later. | |||||
| """ | |||||
| op_info = {} | |||||
| for key, value in self.__dict__.items(): | |||||
| if isinstance(key, str) and key.endswith('_'): | |||||
| op_info[key.rstrip('_')] = value | |||||
| else: | |||||
| op_info[key] = value | |||||
| return op_info | |||||
| class TBERegOp(RegOp): | |||||
| """Class for TBE op info register.""" | |||||
| def __init__(self, op_name=""): | |||||
| super(TBERegOp, self).__init__(op_name) | |||||
| self.imply_type = "TBE" | |||||
| self.fusion_type_ = '' | |||||
| self.async_flag_ = False | |||||
| self.binfile_name_ = '' | |||||
| self.compute_cost_ = 10 | |||||
| self.kernel_name_ = '' | |||||
| self.partial_flag_ = False | |||||
| self.reshape_type_ = '' | |||||
| self.dynamic_format_ = False | |||||
| self.op_pattern_ = "" | |||||
| def fusion_type(self, fusion_type): | |||||
| """ | |||||
| Register fusion type. | |||||
| Args: | |||||
| fusion_type (str): Value of fusion type. | |||||
| """ | |||||
| self.is_string(fusion_type) | |||||
| self.fusion_type_ = fusion_type | |||||
| return self | |||||
| def async_flag(self, async_flag): | |||||
| """ | |||||
| Register async flag. | |||||
| Args: | |||||
| async_flag (bool): Value of async flag. | |||||
| """ | |||||
| self.is_bool(async_flag) | |||||
| self.async_flag_ = async_flag | |||||
| return self | |||||
| def binfile_name(self, binfile_name): | |||||
| """ | |||||
| Register binfile name. | |||||
| Args: | |||||
| binfile_name (str): Name of op binfile. | |||||
| """ | |||||
| self.is_string(binfile_name) | |||||
| self.binfile_name_ = binfile_name | |||||
| return self | |||||
| def compute_cost(self, compute_cost): | |||||
| """ | |||||
| Register compute cost. | |||||
| Args: | |||||
| compute_cost (int): Value of compute cost. | |||||
| """ | |||||
| self.is_int(compute_cost) | |||||
| self.compute_cost_ = compute_cost | |||||
| return self | |||||
| def kernel_name(self, kernel_name): | |||||
| """ | |||||
| Register kernel name. | |||||
| Args: | |||||
| kernel_name (str): Name of op kernel. | |||||
| """ | |||||
| self.is_string(kernel_name) | |||||
| self.kernel_name_ = kernel_name | |||||
| return self | |||||
| def partial_flag(self, partial_flag): | |||||
| """ | |||||
| Register partial flag. | |||||
| Args: | |||||
| partial_flag (bool): Value of partial flag. | |||||
| """ | |||||
| self.is_bool(partial_flag) | |||||
| self.partial_flag_ = partial_flag | |||||
| return self | |||||
| def reshape_type(self, reshape_type): | |||||
| """ | |||||
| Register reshape type. | |||||
| Args: | |||||
| reshape_type (str): Value of reshape type. | |||||
| """ | |||||
| self.is_string(reshape_type) | |||||
| self.reshape_type_ = reshape_type | |||||
| return self | |||||
| def dynamic_format(self, dynamic_format): | |||||
| """ | |||||
| Register dynamic format. | |||||
| Args: | |||||
| reshape_type (bool): Value of dynamic format. | |||||
| """ | |||||
| self.is_bool(dynamic_format) | |||||
| self.dynamic_format_ = dynamic_format | |||||
| return self | |||||
| def op_pattern(self, pattern=None): | |||||
| """ | |||||
| Register op pattern information. | |||||
| Args: | |||||
| pattern (str): Value of op pattern. | |||||
| """ | |||||
| if pattern is not None and self.istring(pattern): | |||||
| self.op_pattern_ = pattern | |||||
| return self | |||||
| def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs): | |||||
| """ | |||||
| Register op attribute information. | |||||
| Args: | |||||
| name (str): Name of the attribute. Default: None. | |||||
| param_type (str): Param type of the attribute. Default: None. | |||||
| type (str): Type of the attribute. Default: None. | |||||
| value (str): Value of the attribute. Default: None. | |||||
| default_value (str): Default value of attribute. Default: None. | |||||
| kwargs (dict): Other information for the attribute. | |||||
| """ | |||||
| param_list = [name, param_type, value_type, value, default_value] | |||||
| attr_dict = {} | |||||
| for index, element in enumerate(param_list): | |||||
| if element is not None: | |||||
| self.is_string(element) | |||||
| if index == 0: | |||||
| attr_dict["name"] = element | |||||
| elif index == 1: | |||||
| attr_dict["param_type"] = element | |||||
| elif index == 2: | |||||
| attr_dict["type"] = element | |||||
| elif index == 3: | |||||
| attr_dict["value"] = element | |||||
| elif index == 4: | |||||
| attr_dict["default_value"] = element | |||||
| if kwargs: | |||||
| attr_dict = dict(attr_dict, **kwargs) | |||||
| self.attr_.append(attr_dict) | |||||
| return self | |||||
| def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs): | |||||
| """ | |||||
| Register op input information. | |||||
| Args: | |||||
| index (int): Order of the input. Default: None. | |||||
| name (str): Name of the input. Default: None. | |||||
| need_compile (bool): The input need compile whether or not. Default: None. | |||||
| param_type (str): Type of the input. Default: None. | |||||
| shape (str): Shape of the input. Default: None. | |||||
| kwargs (dict): Other information for the input. | |||||
| """ | |||||
| param_list = [index, name, need_compile, param_type, shape] | |||||
| input_dict = {} | |||||
| for idx, element in enumerate(param_list): | |||||
| if element is not None: | |||||
| if idx == 0: | |||||
| self.is_int(element) | |||||
| input_dict["index"] = element | |||||
| elif idx == 1: | |||||
| self.is_string(element) | |||||
| input_dict["name"] = element | |||||
| elif idx == 2: | |||||
| self.is_bool(element) | |||||
| input_dict["need_compile"] = element | |||||
| elif idx == 3: | |||||
| self.is_string(element) | |||||
| input_dict["param_type"] = element | |||||
| elif idx == 4: | |||||
| self.is_string(element) | |||||
| input_dict["shape"] = element | |||||
| if kwargs: | |||||
| input_dict = dict(input_dict, **kwargs) | |||||
| self.inputs.append(input_dict) | |||||
| return self | |||||
| def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs): | |||||
| """ | |||||
| Register op output information. | |||||
| Args: | |||||
| index (int): Order of the output. Default: None. | |||||
| name (str): Name of the output. Default: None. | |||||
| need_compile (bool): The output need compile whether or not. Default: None. | |||||
| param_type (str): Type of the output. Default: None. | |||||
| shape (str): Shape of the output. Default: None. | |||||
| kwargs (dict): Other information for the output. | |||||
| """ | |||||
| param_list = [index, name, need_compile, param_type, shape] | |||||
| output_dict = {} | |||||
| for idx, element in enumerate(param_list): | |||||
| if element is not None: | |||||
| if idx == 0: | |||||
| self.is_int(element) | |||||
| output_dict["index"] = element | |||||
| elif idx == 1: | |||||
| self.is_string(element) | |||||
| output_dict["name"] = element | |||||
| elif idx == 2: | |||||
| self.is_bool(element) | |||||
| output_dict["need_compile"] = element | |||||
| elif idx == 3: | |||||
| self.is_string(element) | |||||
| output_dict["param_type"] = element | |||||
| elif idx == 4: | |||||
| self.is_string(element) | |||||
| output_dict["shape"] = element | |||||
| if kwargs: | |||||
| output_dict = dict(output_dict, **kwargs) | |||||
| self.outputs.append(output_dict) | |||||
| return self | |||||
| class DataType(): | |||||
| """ | |||||
| Various combinations of dtype and formatself. | |||||
| The current list below maybe not completed. If necessary, please add it. | |||||
| """ | |||||
| BOOL_None = ("bool", "") | |||||
| BOOL_Default = ("bool", "DefaultFormat") | |||||
| BOOL_5HD = ("bool", "NC1HWC0") | |||||
| BOOL_NCHW = ("bool", "NCHW") | |||||
| BOOL_NHWC = ("bool", "NHWC") | |||||
| BOOL_HWCN = ("bool", "HWCN") | |||||
| I8_None = ("int8", "") | |||||
| I8_Default = ("int8", "DefaultFormat") | |||||
| I8_5HD = ("int8", "NC1HWC0") | |||||
| I8_FracZ = ("int8", "Fracz") | |||||
| I8_FracNZ = ("int8", "FRACTAL_NZ") | |||||
| I8_NCHW = ("int8", "NCHW") | |||||
| I8_NHWC = ("int8", "NHWC") | |||||
| I8_HWCN = ("int8", "HWCN") | |||||
| U8_None = ("uint8", "") | |||||
| U8_Default = ("uint8", "DefaultFormat") | |||||
| U8_5HD = ("uint8", "NC1HWC0") | |||||
| U8_FracZ = ("uint8", "Fracz") | |||||
| U8_FracNZ = ("uint8", "FRACTAL_NZ") | |||||
| U8_NCHW = ("uint8", "NCHW") | |||||
| U8_NHWC = ("uint8", "NHWC") | |||||
| U8_HWCN = ("uint8", "HWCN") | |||||
| I16_None = ("int16", "") | |||||
| I16_Default = ("int16", "DefaultFormat") | |||||
| I16_5HD = ("int16", "NC1HWC0") | |||||
| I16_FracZ = ("int16", "Fracz") | |||||
| I16_FracNZ = ("int16", "FRACTAL_NZ") | |||||
| I16_NCHW = ("int16", "NCHW") | |||||
| I16_NHWC = ("int16", "NHWC") | |||||
| I16_HWCN = ("int16", "HWCN") | |||||
| U16_None = ("uint16", "") | |||||
| U16_Default = ("uint16", "DefaultFormat") | |||||
| U16_5HD = ("uint16", "NC1HWC0") | |||||
| U16_FracZ = ("uint16", "Fracz") | |||||
| U16_FracNZ = ("uint16", "FRACTAL_NZ") | |||||
| U16_NCHW = ("uint16", "NCHW") | |||||
| U16_NHWC = ("uint16", "NHWC") | |||||
| U16_HWCN = ("uint16", "HWCN") | |||||
| I32_None = ("int32", "") | |||||
| I32_Default = ("int32", "DefaultFormat") | |||||
| I32_5HD = ("int32", "NC1HWC0") | |||||
| I32_FracZ = ("int32", "Fracz") | |||||
| I32_FracNZ = ("int32", "FRACTAL_NZ") | |||||
| I32_NCHW = ("int32", "NCHW") | |||||
| I32_NHWC = ("int32", "NHWC") | |||||
| I32_HWCN = ("int32", "HWCN") | |||||
| U32_None = ("uint32", "") | |||||
| U32_Default = ("uint32", "DefaultFormat") | |||||
| U32_5HD = ("uint32", "NC1HWC0") | |||||
| U32_FracZ = ("uint32", "Fracz") | |||||
| U32_FracNZ = ("uint32", "FRACTAL_NZ") | |||||
| U32_NCHW = ("uint32", "NCHW") | |||||
| U32_NHWC = ("uint32", "NHWC") | |||||
| U32_HWCN = ("uint32", "HWCN") | |||||
| I64_None = ("int64", "") | |||||
| I64_Default = ("int64", "DefaultFormat") | |||||
| I64_5HD = ("int64", "NC1HWC0") | |||||
| I64_FracZ = ("int64", "Fracz") | |||||
| I64_FracNZ = ("int64", "FRACTAL_NZ") | |||||
| I64_NCHW = ("int64", "NCHW") | |||||
| I64_NHWC = ("int64", "NHWC") | |||||
| I64_HWCN = ("int64", "HWCN") | |||||
| U64_None = ("uint64", "") | |||||
| U64_Default = ("uint64", "DefaultFormat") | |||||
| U64_5HD = ("uint64", "NC1HWC0") | |||||
| U64_FracZ = ("uint64", "Fracz") | |||||
| U64_FracNZ = ("uint64", "FRACTAL_NZ") | |||||
| U64_NCHW = ("uint64", "NCHW") | |||||
| U64_NHWC = ("uint64", "NHWC") | |||||
| U64_HWCN = ("uint64", "HWCN") | |||||
| F16_None = ("float16", "") | |||||
| F16_Default = ("float16", "DefaultFormat") | |||||
| F16_5HD = ("float16", "NC1HWC0") | |||||
| F16_FracZ = ("float16", "Fracz") | |||||
| F16_FracNZ = ("float16", "FRACTAL_NZ") | |||||
| F16_C1HWNCoC0 = ("float16", "C1HWNCoC0") | |||||
| F16_NCHW = ("float16", "NCHW") | |||||
| F16_NHWC = ("float16", "NHWC") | |||||
| F16_HWCN = ("float16", "HWCN") | |||||
| F32_None = ("float32", "") | |||||
| F32_Default = ("float32", "DefaultFormat") | |||||
| F32_5HD = ("float32", "NC1HWC0") | |||||
| F32_FracZ = ("float32", "Fracz") | |||||
| F32_FracNZ = ("float32", "FRACTAL_NZ") | |||||
| F32_C1HWNCoC0 = ("float32", "C1HWNCoC0") | |||||
| F32_NCHW = ("float32", "NCHW") | |||||
| F32_NHWC = ("float32", "NHWC") | |||||
| F32_HWCN = ("float32", "HWCN") | |||||