| @@ -39,45 +39,7 @@ namespace mindspore { | |||
| namespace kernel { | |||
| using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>; | |||
| const std::vector<std::string> local_framework_op_vec = {kInitDataSetQueue, kGetNext, kDropoutGenMask, kPrint}; | |||
| void InitDataSetQueueAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(proto); | |||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); | |||
| MS_EXCEPTION_IF_NULL(node_attr); | |||
| std::string channel_name = AnfAlgo::GetNodeAttr<std::string>(anf_node, kQueueName); | |||
| (*node_attr)[kChannelName].set_s(channel_name); | |||
| } | |||
| void GetNextAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(proto); | |||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); | |||
| MS_EXCEPTION_IF_NULL(node_attr); | |||
| std::string shared_name = AnfAlgo::GetNodeAttr<std::string>(anf_node, kSharedName); | |||
| (*node_attr)[kChannelName].set_s(shared_name); | |||
| } | |||
| void DropoutGenMaskAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(proto); | |||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); | |||
| MS_EXCEPTION_IF_NULL(node_attr); | |||
| int seed = AnfAlgo::GetNodeAttr<int>(anf_node, kSeed); | |||
| int seed2 = AnfAlgo::GetNodeAttr<int>(anf_node, kSeed2); | |||
| (*node_attr)["seed"].set_i(seed); | |||
| (*node_attr)["seed2"].set_i(seed2); | |||
| } | |||
| void CreateAttrFuncMap(std::map<std::string, FNodeAttrHandle> *mOpAttrFuncMap) { | |||
| (void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kInitDataSetQueue, InitDataSetQueueAttr)); | |||
| (void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kGetNext, GetNextAttr)); | |||
| (void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kDropoutGenMask, DropoutGenMaskAttr)); | |||
| } | |||
| const std::vector<std::string> local_framework_op_vec = {kInitData, kGetNext, kDropoutGenMask, kPrint}; | |||
| bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num, | |||
| std::vector<size_t> *input_size_list) { | |||
| @@ -147,24 +109,74 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A | |||
| return true; | |||
| } | |||
| void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value, | |||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) { | |||
| MS_EXCEPTION_IF_NULL(node_attr); | |||
| if (type == "int") { | |||
| auto attr_value = GetValue<int>(value); | |||
| (*node_attr)[attr_name].set_i(attr_value); | |||
| } else if (type == "str") { | |||
| auto attr_value = GetValue<std::string>(value); | |||
| (*node_attr)[attr_name].set_s(attr_value); | |||
| } else if (type == "bool") { | |||
| auto attr_value = GetValue<bool>(value); | |||
| (*node_attr)[attr_name].set_b(attr_value); | |||
| } else if (type == "float") { | |||
| auto attr_value = GetValue<float>(value); | |||
| (*node_attr)[attr_name].set_f(attr_value); | |||
| } else if (type == "listInt") { | |||
| std::vector<int> attr_value; | |||
| auto value_type = value->type(); | |||
| MS_EXCEPTION_IF_NULL(value_type); | |||
| auto value_type_str = value_type->ToString(); | |||
| if (value_type_str == "Int32") { | |||
| int data = GetValue<int>(value); | |||
| attr_value.push_back(data); | |||
| } else { | |||
| attr_value = GetValue<std::vector<int>>(value); | |||
| } | |||
| mindspore::AttrValue input_shape_attr; | |||
| mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array(); | |||
| MS_EXCEPTION_IF_NULL(input_shape_attr_list); | |||
| for (const auto shape : attr_value) { | |||
| input_shape_attr_list->add_i(shape); | |||
| } | |||
| (*node_attr)[attr_name] = input_shape_attr; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "type: " << type << "not support"; | |||
| } | |||
| } | |||
| void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||
| if (op_name == "InitDataSetQueue") { | |||
| op_name = "InitData"; | |||
| if (op_name == kInitDataSetQueue) { | |||
| op_name = kInitData; | |||
| } | |||
| if (op_name == "Print") { | |||
| if (op_name == kPrint) { | |||
| return; | |||
| } | |||
| std::map<std::string, FNodeAttrHandle> mOpAttrFuncMap; | |||
| CreateAttrFuncMap(&mOpAttrFuncMap); | |||
| FNodeAttrHandle func_ptr = nullptr; | |||
| auto iter = mOpAttrFuncMap.find(op_name); | |||
| if (iter != mOpAttrFuncMap.end()) { | |||
| func_ptr = iter->second; | |||
| MS_EXCEPTION_IF_NULL(func_ptr); | |||
| func_ptr(anf_node, proto); | |||
| } else { | |||
| MS_LOG(ERROR) << "Don't support node [" << op_name << "] to set nodedef of attr"; | |||
| auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); | |||
| MS_EXCEPTION_IF_NULL(op_info_ptr); | |||
| auto attrs_ptr = op_info_ptr->attrs_ptr(); | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); | |||
| for (const auto &attr_ptr : attrs_ptr) { | |||
| std::string attr_name = attr_ptr->name(); | |||
| std::string real_name; | |||
| auto value = primitive->GetAttr(attr_name); | |||
| if (value != nullptr) { | |||
| if (attr_name == kQueueName || attr_name == kSharedName) { | |||
| real_name = kChannelName; | |||
| } else if (attr_name == kSeed) { | |||
| real_name = "seed"; | |||
| } else if (attr_name == kSeed2) { | |||
| real_name = "seed2"; | |||
| } | |||
| std::string type = attr_ptr->type(); | |||
| ParseAttrValue(type, real_name, value, node_attr); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Set node attr end!"; | |||
| } | |||
| @@ -17,68 +17,27 @@ | |||
| #include "kernel/aicpu/aicpu_kernel_metadata.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "kernel/common_utils.h" | |||
| #include "kernel/aicpu/aicpu_util.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr auto kInitDataSetQueueOpName = "InitDataSetQueue"; | |||
| constexpr auto kGetNext = "GetNext"; | |||
| constexpr auto kDropoutGenMask = "DropoutGenMask"; | |||
| constexpr auto kPrint = "Print"; | |||
| const std::vector<std::string> AICPU_OPS = {kInitDataSetQueueOpName, kGetNext, kDropoutGenMask, kPrint}; | |||
| std::shared_ptr<KernelBuildInfo> CreateKernelInfo(const std::vector<std::string> &inputs_format, | |||
| const std::vector<TypeId> &inputs_device_type, | |||
| const std::vector<std::string> &outputs_format, | |||
| const std::vector<TypeId> &outputs_device_type) { | |||
| auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); | |||
| builder.SetInputsFormat(inputs_format); | |||
| builder.SetInputsDeviceType(inputs_device_type); | |||
| builder.SetOutputsFormat(outputs_format); | |||
| builder.SetOutputsDeviceType(outputs_device_type); | |||
| builder.SetProcessor(AICPU); | |||
| builder.SetKernelType(AICPU_KERNEL); | |||
| builder.SetFusionType(OPAQUE); | |||
| return builder.Build(); | |||
| } | |||
| bool CheckIfExistAicpuMeta(const std::string &op_name) { | |||
| if (std::find(AICPU_OPS.begin(), AICPU_OPS.end(), op_name) != AICPU_OPS.end()) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) { | |||
| MS_LOG(INFO) << "AicpuMetadataInfo."; | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (CheckIfExistAicpuMeta(op_name)) { | |||
| MS_LOG(DEBUG) << "Aicpu doesn't have metadata of op [" << op_name << "]."; | |||
| return; | |||
| } | |||
| if (op_name == kInitDataSetQueueOpName) { | |||
| kernel_info_list->push_back(CreateKernelInfo({}, {}, {}, {})); | |||
| if (op_name == kInitDataSetQueue) { | |||
| op_name = kInitData; | |||
| } | |||
| if (op_name == kGetNext) { | |||
| std::vector<std::string> outputs_format; | |||
| std::vector<TypeId> outputs_type; | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||
| outputs_format.emplace_back(kOpFormat_DEFAULT); | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | |||
| } | |||
| kernel_info_list->push_back(CreateKernelInfo({}, {}, outputs_format, outputs_type)); | |||
| } | |||
| if (op_name == kDropoutGenMask) { | |||
| kernel_info_list->push_back(CreateKernelInfo({kOpFormat_NCHW, kOpFormat_NCHW}, | |||
| {kInt32->type_id(), kFloat16->type_id()}, {kOpFormat_NCHW}, | |||
| {kUInt8->type_id()})); | |||
| auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); | |||
| if (op_info_ptr == nullptr) { | |||
| MS_LOG(WARNING) << "Aicpu doestn't have metadata of op [" << op_name << "]"; | |||
| return; | |||
| } | |||
| // For compatibility with the current framework | |||
| if (op_name == kPrint) { | |||
| std::vector<std::string> inputs_format; | |||
| std::vector<TypeId> inputs_type; | |||
| @@ -92,11 +51,20 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr< | |||
| outputs_format.emplace_back(kOpFormat_DEFAULT); | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | |||
| } | |||
| kernel_info_list->push_back(CreateKernelInfo(inputs_format, inputs_type, outputs_format, outputs_type)); | |||
| auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); | |||
| builder.SetInputsFormat(inputs_format); | |||
| builder.SetInputsDeviceType(inputs_type); | |||
| builder.SetOutputsFormat(outputs_format); | |||
| builder.SetOutputsDeviceType(outputs_type); | |||
| builder.SetProcessor(AICPU); | |||
| builder.SetKernelType(AICPU_KERNEL); | |||
| builder.SetFusionType(OPAQUE); | |||
| kernel_info_list->push_back(builder.Build()); | |||
| return; | |||
| } | |||
| if (kernel_info_list->empty()) { | |||
| MS_LOG(INFO) << "Aicpu dose not has metadata of op[ " << op_name << "]."; | |||
| if (!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) { | |||
| MS_LOG(WARNING) << "Aicpu parsed metadata op [" << op_name << "] failed"; | |||
| return; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| @@ -24,7 +24,8 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr auto kInitDataSetQueue = "InitData"; | |||
| constexpr auto kInitDataSetQueue = "InitDataSetQueue"; | |||
| constexpr auto kInitData = "InitData"; | |||
| constexpr auto kGetNext = "GetNext"; | |||
| constexpr auto kDropoutGenMask = "DropoutGenMask"; | |||
| constexpr auto kPrint = "Print"; | |||
| @@ -417,6 +417,8 @@ void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBu | |||
| if (imply_type == kAKG) { | |||
| builder->SetKernelType(AUTO_DIFF_KERNEL); | |||
| } else if (imply_type == kAICPU) { | |||
| builder->SetKernelType(AICPU_KERNEL); | |||
| } else { | |||
| builder->SetKernelType(TBE_KERNEL); | |||
| } | |||
| @@ -471,6 +473,13 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn | |||
| return false; | |||
| } | |||
| kernel_info_list->push_back(builder->Build()); | |||
| } | |||
| } else { | |||
| if (processor == AICPU) { | |||
| auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| MS_EXCEPTION_IF_NULL(builder); | |||
| SetKernelBuildInfo(builder, processor, op_info_ptr); | |||
| kernel_info_list->push_back(builder->Build()); | |||
| } | |||
| } | |||
| @@ -24,7 +24,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| enum OpImplyType { kAKG = 0, kTBE }; | |||
| enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU }; | |||
| enum OpIOType { kInput = 0, kOutput }; | |||
| class OpAttr { | |||
| @@ -39,6 +39,7 @@ constexpr auto kDtypeFormat = "dtype_format"; | |||
| constexpr auto kAttr = "attr"; | |||
| constexpr auto kIputs = "inputs"; | |||
| constexpr auto kOutputs = "outputs"; | |||
| constexpr auto kAiCPU = "AiCPU"; | |||
| constexpr auto kTbe = "TBE"; | |||
| constexpr auto kAkg = "akg"; | |||
| constexpr auto kAutodiff = "AutoDiff"; | |||
| @@ -60,6 +61,8 @@ std::string ImplTypeToStr(OpImplyType impl_type) { | |||
| return kTbe; | |||
| case kAKG: | |||
| return kAkg; | |||
| case kAICPU: | |||
| return kAiCPU; | |||
| default: | |||
| return "unknow"; | |||
| } | |||
| @@ -76,6 +79,9 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) | |||
| } else if (imply_type_string == kAutodiff) { | |||
| OpImplyType imply_type = kAKG; | |||
| ret = DecodeOpInfo(op_json, imply_type, impl_path); | |||
| } else if (imply_type_string == kAiCPU) { | |||
| OpImplyType imply_type = kAICPU; | |||
| ret = DecodeOpInfo(op_json, imply_type, impl_path); | |||
| } else { | |||
| MS_LOG(DEBUG) << "Not support imply_type"; | |||
| } | |||
| @@ -154,7 +160,9 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | |||
| std::shared_ptr<OpAttr> op_attr = std::make_shared<OpAttr>(); | |||
| MS_EXCEPTION_IF_NULL(op_attr); | |||
| op_attr->set_name(obj.at(kName)); | |||
| op_attr->set_param_type(obj.at(kParamType)); | |||
| if (imply_type != kAICPU) { | |||
| op_attr->set_param_type(obj.at(kParamType)); | |||
| } | |||
| op_attr->set_type(obj.at(kType)); | |||
| if (imply_type == kTBE) { | |||
| op_attr->set_value(obj.at(kValue)); | |||
| @@ -242,9 +250,10 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool is_gpu = (context->device_target() == kGPUDevice); | |||
| if ((is_gpu && imply_type == kTBE) || (!is_gpu && imply_type != kTBE)) { | |||
| MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << "imply_type:" << ImplTypeToStr(imply_type) | |||
| << "current op num:" << op_info_.size(); | |||
| if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) || | |||
| (!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) { | |||
| MS_LOG(ERROR) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type) | |||
| << ", current op num:" << op_info_.size(); | |||
| return nullptr; | |||
| } | |||
| for (const auto& op_info : op_info_) { | |||
| @@ -253,8 +262,8 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im | |||
| return op_info; | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << "imply_type:" << ImplTypeToStr(imply_type) | |||
| << "current op num:" << op_info_.size(); | |||
| MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type) | |||
| << ", current op num:" << op_info_.size(); | |||
| return nullptr; | |||
| } | |||
| @@ -30,7 +30,7 @@ Note: | |||
| from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register | |||
| from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | |||
| from .op_info_register import op_info_register, TBERegOp, DataType | |||
| from .op_info_register import op_info_register, AiCPURegOp, TBERegOp, DataType | |||
| from .primitive import constexpr | |||
| from .._c_expression import signature_rw, signature_kind | |||
| @@ -40,6 +40,6 @@ __primitive__ = [ | |||
| ] | |||
| __all__ = ["get_vm_impl_fn", "vm_impl_registry", | |||
| "op_info_register", "TBERegOp", "DataType", | |||
| "op_info_register", "AiCPURegOp", "TBERegOp", "DataType", | |||
| "constexpr"] | |||
| __all__.extend(__primitive__) | |||
| @@ -16,5 +16,6 @@ | |||
| from .akg.gpu import * | |||
| from .tbe import * | |||
| from .aicpu import * | |||
| __all__ = [] | |||
| @@ -0,0 +1,19 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """aicpu ops""" | |||
| from .init_data_set_queue import _init_data_set_queue_aicpu | |||
| from .dropout_genmask import _dropout_genmask_aicpu | |||
| from .get_next import _get_next_aicpu | |||
| from .print_tensor import _print_aicpu | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """InitDataSetQueue op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| dropout_genmask_op_info = AiCPURegOp("DropoutGenMask") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "x1", "required") \ | |||
| .input(1, "x2", "required") \ | |||
| .output(0, "y", "required") \ | |||
| .attr("Seed0", "int") \ | |||
| .attr("Seed1", "int") \ | |||
| .dtype_format(DataType.I32_NCHW, DataType.F16_NCHW, DataType.U8_NCHW) \ | |||
| .get_op_info() | |||
| @op_info_register(dropout_genmask_op_info) | |||
| def _dropout_genmask_aicpu(): | |||
| """Dropout AiCPU register""" | |||
| return | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """InitDataSetQueue op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| get_next_op_info = AiCPURegOp("GetNext") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .output(0, "y", "dynamic") \ | |||
| .attr("shared_name", "str") \ | |||
| .dtype_format(DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I8_Default) \ | |||
| .dtype_format(DataType.I16_Default) \ | |||
| .dtype_format(DataType.I32_Default) \ | |||
| .dtype_format(DataType.I64_Default) \ | |||
| .dtype_format(DataType.F16_Default) \ | |||
| .dtype_format(DataType.U8_Default) \ | |||
| .dtype_format(DataType.U16_Default) \ | |||
| .dtype_format(DataType.U32_Default) \ | |||
| .dtype_format(DataType.U64_Default) \ | |||
| .dtype_format(DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(get_next_op_info) | |||
| def _get_next_aicpu(): | |||
| """GetNext AiCPU register""" | |||
| return | |||
| @@ -0,0 +1,27 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """InitDataSetQueue op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp | |||
| init_data_set_queue_op_info = AiCPURegOp("InitData") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .attr("queue_name", "str") \ | |||
| .get_op_info() | |||
| @op_info_register(init_data_set_queue_op_info) | |||
| def _init_data_set_queue_aicpu(): | |||
| """InitDataSetQueue AiCPU register""" | |||
| return | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """InitDataSetQueue op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| print_op_info = AiCPURegOp("Print") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "x", "dynamic") \ | |||
| .output(0, "y", "required") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(print_op_info) | |||
| def _print_aicpu(): | |||
| """Print AiCPU register""" | |||
| return | |||
| @@ -78,14 +78,15 @@ class RegOp(): | |||
| self.inputs = [] | |||
| self.outputs = [] | |||
| self.attr_ = [] | |||
| self.fusion_type_ = '' | |||
| self.dtype_format_ = [] | |||
| def is_string(self, value): | |||
| def _is_string(self, value): | |||
| """ | |||
| Check if the value is a str type. | |||
| Args: | |||
| value: Parameter to to check. | |||
| value: Parameter to be checked. | |||
| Raises: | |||
| TypeError: If the type of value is not a str. | |||
| @@ -93,12 +94,12 @@ class RegOp(): | |||
| if not isinstance(value, str): | |||
| raise TypeError("%s value must be str" % str(value)) | |||
| def is_int(self, value): | |||
| def _is_int(self, value): | |||
| """ | |||
| Check if the value is a int. | |||
| Args: | |||
| value: Parameter to to check. | |||
| value: Parameter to be checked. | |||
| Raises: | |||
| TypeError: If the type of value is not a int. | |||
| @@ -106,12 +107,12 @@ class RegOp(): | |||
| if not isinstance(value, int): | |||
| raise TypeError("%s value must be int" % str(value)) | |||
| def is_bool(self, value): | |||
| def _is_bool(self, value): | |||
| """ | |||
| Check if the value is a bool. | |||
| Args: | |||
| value: Parameter to to check. | |||
| value: Parameter to be checked. | |||
| Raises: | |||
| TypeError: If the type of value is not a bool. | |||
| @@ -119,6 +120,51 @@ class RegOp(): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("%s value must be bool" % str(value)) | |||
| def _check_param(self, param_list, key_list, fn_list, kwargs): | |||
| """ | |||
| Check if the parameter type is correct. | |||
| Args: | |||
| param_list (list): Parameter list to be checked. | |||
| key_list (list): The keys of output dict. | |||
| fn_list (list): Function used for parameter checking. If the function list has only one element, | |||
| all parameters will use the same function. | |||
| kwargs (dict): Other parameter information. | |||
| Raises: | |||
| TypeError: If the type of value is not list. | |||
| ValueError: If the size of param list is not equal to the size of key list, or | |||
| the size of param list is not equal to the size of funtion list. | |||
| """ | |||
| for i in [param_list, key_list, fn_list]: | |||
| if not isinstance(i, list): | |||
| raise TypeError("%s value must be list type" % str(i)) | |||
| if len(param_list) != len(key_list) or (len(fn_list) != 1 and len(param_list) != len(fn_list)): | |||
| raise ValueError("param_list size {}, key_list size {}, must be equal.And fn_list size {}.". | |||
| format(len(param_list), len(key_list), len(fn_list))) | |||
| out_dict = {} | |||
| for idx, element in enumerate(param_list): | |||
| if element is not None: | |||
| if len(fn_list) == 1: | |||
| fn_list[0](element) | |||
| else: | |||
| fn_list[idx](element) | |||
| out_dict[key_list[idx]] = element | |||
| if kwargs: | |||
| out_dict = dict(out_dict, kwargs) | |||
| return out_dict | |||
| def fusion_type(self, fusion_type): | |||
| """ | |||
| Register fusion type. | |||
| Args: | |||
| fusion_type (str): Value of fusion type. | |||
| """ | |||
| self._is_string(fusion_type) | |||
| self.fusion_type_ = fusion_type | |||
| return self | |||
| def dtype_format(self, *args): | |||
| """ | |||
| Register dtype and format. | |||
| @@ -136,8 +182,8 @@ class RegOp(): | |||
| for arg in args: | |||
| if not isinstance(arg, tuple) or len(arg) != 2: | |||
| raise ValueError("dtype and format value must be tuple of two elements") | |||
| self.is_string(arg[0]) | |||
| self.is_string(arg[1]) | |||
| self._is_string(arg[0]) | |||
| self._is_string(arg[1]) | |||
| dtype_format.append(arg) | |||
| self.dtype_format_.append(tuple(dtype_format)) | |||
| return self | |||
| @@ -159,13 +205,71 @@ class RegOp(): | |||
| return op_info | |||
| class AiCPURegOp(RegOp): | |||
| """Class for AiCPU op info register""" | |||
| def __init__(self, op_name): | |||
| super(AiCPURegOp, self).__init__(op_name) | |||
| self.imply_type = "AiCPU" | |||
| def input(self, index=None, name=None, param_type=None, **kwargs): | |||
| """ | |||
| Register AiCPU op input information. | |||
| Args: | |||
| index (int): Order of the input. Default: None. | |||
| name (str): Name of the input. Default: None. | |||
| param_type (str): Param type of the input. Default: None. | |||
| kwargs (dict): Other information for the input. | |||
| """ | |||
| param_list = [index, name, param_type] | |||
| key_list = ["index", "name", "param_type"] | |||
| fn_list = [self._is_int, self._is_string, self._is_string] | |||
| input_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||
| self.inputs.append(input_dict) | |||
| return self | |||
| def output(self, index=None, name=None, param_type=None, **kwargs): | |||
| """ | |||
| Register AiCPU op output information. | |||
| Args: | |||
| index (int): Order of the output. Default: None. | |||
| name (str): Name of the output. Default: None. | |||
| param_type (str): Param type of the output. Default: None. | |||
| kwargs (dict): Other information for the output. | |||
| """ | |||
| param_list = [index, name, param_type] | |||
| key_list = ["index", "name", "param_type"] | |||
| fn_list = [self._is_int, self._is_string, self._is_string] | |||
| output_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||
| self.outputs.append(output_dict) | |||
| return self | |||
| def attr(self, name=None, value_type=None, value=None, **kwargs): | |||
| """ | |||
| Register AiCPU op attribute information. | |||
| Args: | |||
| name (str): Name of the attribute. Default: None. | |||
| value_type (str): Value type of the attribute. Default: None. | |||
| value (str): Value type of the attribute. Default: None. | |||
| kwargs (dict): Other information for the attribute. | |||
| """ | |||
| param_list = [name, value_type, value] | |||
| key_list = ["name", "type", "value"] | |||
| fn_list = [self._is_string] | |||
| attr_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||
| self.attr_.append(attr_dict) | |||
| return self | |||
| class TBERegOp(RegOp): | |||
| """Class for TBE op info register.""" | |||
| def __init__(self, op_name=""): | |||
| super(TBERegOp, self).__init__(op_name) | |||
| self.imply_type = "TBE" | |||
| self.fusion_type_ = '' | |||
| self.async_flag_ = False | |||
| self.binfile_name_ = '' | |||
| self.compute_cost_ = 10 | |||
| @@ -175,17 +279,6 @@ class TBERegOp(RegOp): | |||
| self.dynamic_format_ = False | |||
| self.op_pattern_ = "" | |||
| def fusion_type(self, fusion_type): | |||
| """ | |||
| Register fusion type. | |||
| Args: | |||
| fusion_type (str): Value of fusion type. | |||
| """ | |||
| self.is_string(fusion_type) | |||
| self.fusion_type_ = fusion_type | |||
| return self | |||
| def async_flag(self, async_flag): | |||
| """ | |||
| Register async flag. | |||
| @@ -193,7 +286,7 @@ class TBERegOp(RegOp): | |||
| Args: | |||
| async_flag (bool): Value of async flag. | |||
| """ | |||
| self.is_bool(async_flag) | |||
| self._is_bool(async_flag) | |||
| self.async_flag_ = async_flag | |||
| return self | |||
| @@ -204,7 +297,7 @@ class TBERegOp(RegOp): | |||
| Args: | |||
| binfile_name (str): Name of op binfile. | |||
| """ | |||
| self.is_string(binfile_name) | |||
| self._is_string(binfile_name) | |||
| self.binfile_name_ = binfile_name | |||
| return self | |||
| @@ -215,7 +308,7 @@ class TBERegOp(RegOp): | |||
| Args: | |||
| compute_cost (int): Value of compute cost. | |||
| """ | |||
| self.is_int(compute_cost) | |||
| self._is_int(compute_cost) | |||
| self.compute_cost_ = compute_cost | |||
| return self | |||
| @@ -226,7 +319,7 @@ class TBERegOp(RegOp): | |||
| Args: | |||
| kernel_name (str): Name of op kernel. | |||
| """ | |||
| self.is_string(kernel_name) | |||
| self._is_string(kernel_name) | |||
| self.kernel_name_ = kernel_name | |||
| return self | |||
| @@ -237,7 +330,7 @@ class TBERegOp(RegOp): | |||
| Args: | |||
| partial_flag (bool): Value of partial flag. | |||
| """ | |||
| self.is_bool(partial_flag) | |||
| self._is_bool(partial_flag) | |||
| self.partial_flag_ = partial_flag | |||
| return self | |||
| @@ -248,7 +341,7 @@ class TBERegOp(RegOp): | |||
| Args: | |||
| reshape_type (str): Value of reshape type. | |||
| """ | |||
| self.is_string(reshape_type) | |||
| self._is_string(reshape_type) | |||
| self.reshape_type_ = reshape_type | |||
| return self | |||
| @@ -259,56 +352,43 @@ class TBERegOp(RegOp): | |||
| Args: | |||
| reshape_type (bool): Value of dynamic format. | |||
| """ | |||
| self.is_bool(dynamic_format) | |||
| self._is_bool(dynamic_format) | |||
| self.dynamic_format_ = dynamic_format | |||
| return self | |||
| def op_pattern(self, pattern=None): | |||
| """ | |||
| Register op pattern information. | |||
| Register TBE op pattern information. | |||
| Args: | |||
| pattern (str): Value of op pattern. | |||
| """ | |||
| if pattern is not None and self.istring(pattern): | |||
| if pattern is not None and self._is_string(pattern): | |||
| self.op_pattern_ = pattern | |||
| return self | |||
| def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs): | |||
| """ | |||
| Register op attribute information. | |||
| Register TBE op attribute information. | |||
| Args: | |||
| name (str): Name of the attribute. Default: None. | |||
| param_type (str): Param type of the attribute. Default: None. | |||
| type (str): Type of the attribute. Default: None. | |||
| value_type (str): Type of the attribute. Default: None. | |||
| value (str): Value of the attribute. Default: None. | |||
| default_value (str): Default value of attribute. Default: None. | |||
| kwargs (dict): Other information for the attribute. | |||
| """ | |||
| param_list = [name, param_type, value_type, value, default_value] | |||
| attr_dict = {} | |||
| for index, element in enumerate(param_list): | |||
| if element is not None: | |||
| self.is_string(element) | |||
| if index == 0: | |||
| attr_dict["name"] = element | |||
| elif index == 1: | |||
| attr_dict["param_type"] = element | |||
| elif index == 2: | |||
| attr_dict["type"] = element | |||
| elif index == 3: | |||
| attr_dict["value"] = element | |||
| elif index == 4: | |||
| attr_dict["default_value"] = element | |||
| if kwargs: | |||
| attr_dict = dict(attr_dict, **kwargs) | |||
| key_list = ["name", "param_type", "type", "value", "default_value"] | |||
| fn_list = [self._is_string] | |||
| attr_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||
| self.attr_.append(attr_dict) | |||
| return self | |||
| def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs): | |||
| """ | |||
| Register op input information. | |||
| Register TBE op input information. | |||
| Args: | |||
| index (int): Order of the input. Default: None. | |||
| @@ -319,32 +399,15 @@ class TBERegOp(RegOp): | |||
| kwargs (dict): Other information for the input. | |||
| """ | |||
| param_list = [index, name, need_compile, param_type, shape] | |||
| input_dict = {} | |||
| for idx, element in enumerate(param_list): | |||
| if element is not None: | |||
| if idx == 0: | |||
| self.is_int(element) | |||
| input_dict["index"] = element | |||
| elif idx == 1: | |||
| self.is_string(element) | |||
| input_dict["name"] = element | |||
| elif idx == 2: | |||
| self.is_bool(element) | |||
| input_dict["need_compile"] = element | |||
| elif idx == 3: | |||
| self.is_string(element) | |||
| input_dict["param_type"] = element | |||
| elif idx == 4: | |||
| self.is_string(element) | |||
| input_dict["shape"] = element | |||
| if kwargs: | |||
| input_dict = dict(input_dict, **kwargs) | |||
| key_list = ["index", "name", "need_compile", "param_type", "shape"] | |||
| fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string] | |||
| input_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||
| self.inputs.append(input_dict) | |||
| return self | |||
| def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs): | |||
| """ | |||
| Register op output information. | |||
| Register TBE op output information. | |||
| Args: | |||
| index (int): Order of the output. Default: None. | |||
| @@ -355,29 +418,13 @@ class TBERegOp(RegOp): | |||
| kwargs (dict): Other information for the output. | |||
| """ | |||
| param_list = [index, name, need_compile, param_type, shape] | |||
| output_dict = {} | |||
| for idx, element in enumerate(param_list): | |||
| if element is not None: | |||
| if idx == 0: | |||
| self.is_int(element) | |||
| output_dict["index"] = element | |||
| elif idx == 1: | |||
| self.is_string(element) | |||
| output_dict["name"] = element | |||
| elif idx == 2: | |||
| self.is_bool(element) | |||
| output_dict["need_compile"] = element | |||
| elif idx == 3: | |||
| self.is_string(element) | |||
| output_dict["param_type"] = element | |||
| elif idx == 4: | |||
| self.is_string(element) | |||
| output_dict["shape"] = element | |||
| if kwargs: | |||
| output_dict = dict(output_dict, **kwargs) | |||
| key_list = ["index", "name", "need_compile", "param_type", "shape"] | |||
| fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string] | |||
| output_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||
| self.outputs.append(output_dict) | |||
| return self | |||
| class DataType(): | |||
| """ | |||
| Various combinations of dtype and formatself. | |||