| @@ -39,45 +39,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>; | using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>; | ||||
| const std::vector<std::string> local_framework_op_vec = {kInitDataSetQueue, kGetNext, kDropoutGenMask, kPrint}; | |||||
| void InitDataSetQueueAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(proto); | |||||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); | |||||
| MS_EXCEPTION_IF_NULL(node_attr); | |||||
| std::string channel_name = AnfAlgo::GetNodeAttr<std::string>(anf_node, kQueueName); | |||||
| (*node_attr)[kChannelName].set_s(channel_name); | |||||
| } | |||||
| void GetNextAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(proto); | |||||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); | |||||
| MS_EXCEPTION_IF_NULL(node_attr); | |||||
| std::string shared_name = AnfAlgo::GetNodeAttr<std::string>(anf_node, kSharedName); | |||||
| (*node_attr)[kChannelName].set_s(shared_name); | |||||
| } | |||||
| void DropoutGenMaskAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(proto); | |||||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); | |||||
| MS_EXCEPTION_IF_NULL(node_attr); | |||||
| int seed = AnfAlgo::GetNodeAttr<int>(anf_node, kSeed); | |||||
| int seed2 = AnfAlgo::GetNodeAttr<int>(anf_node, kSeed2); | |||||
| (*node_attr)["seed"].set_i(seed); | |||||
| (*node_attr)["seed2"].set_i(seed2); | |||||
| } | |||||
| void CreateAttrFuncMap(std::map<std::string, FNodeAttrHandle> *mOpAttrFuncMap) { | |||||
| (void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kInitDataSetQueue, InitDataSetQueueAttr)); | |||||
| (void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kGetNext, GetNextAttr)); | |||||
| (void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kDropoutGenMask, DropoutGenMaskAttr)); | |||||
| } | |||||
| const std::vector<std::string> local_framework_op_vec = {kInitData, kGetNext, kDropoutGenMask, kPrint}; | |||||
| bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num, | bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num, | ||||
| std::vector<size_t> *input_size_list) { | std::vector<size_t> *input_size_list) { | ||||
| @@ -147,24 +109,74 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A | |||||
| return true; | return true; | ||||
| } | } | ||||
| void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value, | |||||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) { | |||||
| MS_EXCEPTION_IF_NULL(node_attr); | |||||
| if (type == "int") { | |||||
| auto attr_value = GetValue<int>(value); | |||||
| (*node_attr)[attr_name].set_i(attr_value); | |||||
| } else if (type == "str") { | |||||
| auto attr_value = GetValue<std::string>(value); | |||||
| (*node_attr)[attr_name].set_s(attr_value); | |||||
| } else if (type == "bool") { | |||||
| auto attr_value = GetValue<bool>(value); | |||||
| (*node_attr)[attr_name].set_b(attr_value); | |||||
| } else if (type == "float") { | |||||
| auto attr_value = GetValue<float>(value); | |||||
| (*node_attr)[attr_name].set_f(attr_value); | |||||
| } else if (type == "listInt") { | |||||
| std::vector<int> attr_value; | |||||
| auto value_type = value->type(); | |||||
| MS_EXCEPTION_IF_NULL(value_type); | |||||
| auto value_type_str = value_type->ToString(); | |||||
| if (value_type_str == "Int32") { | |||||
| int data = GetValue<int>(value); | |||||
| attr_value.push_back(data); | |||||
| } else { | |||||
| attr_value = GetValue<std::vector<int>>(value); | |||||
| } | |||||
| mindspore::AttrValue input_shape_attr; | |||||
| mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array(); | |||||
| MS_EXCEPTION_IF_NULL(input_shape_attr_list); | |||||
| for (const auto shape : attr_value) { | |||||
| input_shape_attr_list->add_i(shape); | |||||
| } | |||||
| (*node_attr)[attr_name] = input_shape_attr; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "type: " << type << "not support"; | |||||
| } | |||||
| } | |||||
| void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | ||||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | std::string op_name = AnfAlgo::GetCNodeName(anf_node); | ||||
| if (op_name == "InitDataSetQueue") { | |||||
| op_name = "InitData"; | |||||
| if (op_name == kInitDataSetQueue) { | |||||
| op_name = kInitData; | |||||
| } | } | ||||
| if (op_name == "Print") { | |||||
| if (op_name == kPrint) { | |||||
| return; | return; | ||||
| } | } | ||||
| std::map<std::string, FNodeAttrHandle> mOpAttrFuncMap; | |||||
| CreateAttrFuncMap(&mOpAttrFuncMap); | |||||
| FNodeAttrHandle func_ptr = nullptr; | |||||
| auto iter = mOpAttrFuncMap.find(op_name); | |||||
| if (iter != mOpAttrFuncMap.end()) { | |||||
| func_ptr = iter->second; | |||||
| MS_EXCEPTION_IF_NULL(func_ptr); | |||||
| func_ptr(anf_node, proto); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Don't support node [" << op_name << "] to set nodedef of attr"; | |||||
| auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); | |||||
| MS_EXCEPTION_IF_NULL(op_info_ptr); | |||||
| auto attrs_ptr = op_info_ptr->attrs_ptr(); | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); | |||||
| for (const auto &attr_ptr : attrs_ptr) { | |||||
| std::string attr_name = attr_ptr->name(); | |||||
| std::string real_name; | |||||
| auto value = primitive->GetAttr(attr_name); | |||||
| if (value != nullptr) { | |||||
| if (attr_name == kQueueName || attr_name == kSharedName) { | |||||
| real_name = kChannelName; | |||||
| } else if (attr_name == kSeed) { | |||||
| real_name = "seed"; | |||||
| } else if (attr_name == kSeed2) { | |||||
| real_name = "seed2"; | |||||
| } | |||||
| std::string type = attr_ptr->type(); | |||||
| ParseAttrValue(type, real_name, value, node_attr); | |||||
| } | |||||
| } | } | ||||
| MS_LOG(INFO) << "Set node attr end!"; | MS_LOG(INFO) << "Set node attr end!"; | ||||
| } | } | ||||
| @@ -17,68 +17,27 @@ | |||||
| #include "kernel/aicpu/aicpu_kernel_metadata.h" | #include "kernel/aicpu/aicpu_kernel_metadata.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include "kernel/oplib/oplib.h" | |||||
| #include "kernel/common_utils.h" | |||||
| #include "kernel/aicpu/aicpu_util.h" | |||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| constexpr auto kInitDataSetQueueOpName = "InitDataSetQueue"; | |||||
| constexpr auto kGetNext = "GetNext"; | |||||
| constexpr auto kDropoutGenMask = "DropoutGenMask"; | |||||
| constexpr auto kPrint = "Print"; | |||||
| const std::vector<std::string> AICPU_OPS = {kInitDataSetQueueOpName, kGetNext, kDropoutGenMask, kPrint}; | |||||
| std::shared_ptr<KernelBuildInfo> CreateKernelInfo(const std::vector<std::string> &inputs_format, | |||||
| const std::vector<TypeId> &inputs_device_type, | |||||
| const std::vector<std::string> &outputs_format, | |||||
| const std::vector<TypeId> &outputs_device_type) { | |||||
| auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); | |||||
| builder.SetInputsFormat(inputs_format); | |||||
| builder.SetInputsDeviceType(inputs_device_type); | |||||
| builder.SetOutputsFormat(outputs_format); | |||||
| builder.SetOutputsDeviceType(outputs_device_type); | |||||
| builder.SetProcessor(AICPU); | |||||
| builder.SetKernelType(AICPU_KERNEL); | |||||
| builder.SetFusionType(OPAQUE); | |||||
| return builder.Build(); | |||||
| } | |||||
| bool CheckIfExistAicpuMeta(const std::string &op_name) { | |||||
| if (std::find(AICPU_OPS.begin(), AICPU_OPS.end(), op_name) != AICPU_OPS.end()) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) { | void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) { | ||||
| MS_LOG(INFO) << "AicpuMetadataInfo."; | MS_LOG(INFO) << "AicpuMetadataInfo."; | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | MS_EXCEPTION_IF_NULL(kernel_info_list); | ||||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | ||||
| if (CheckIfExistAicpuMeta(op_name)) { | |||||
| MS_LOG(DEBUG) << "Aicpu doesn't have metadata of op [" << op_name << "]."; | |||||
| return; | |||||
| } | |||||
| if (op_name == kInitDataSetQueueOpName) { | |||||
| kernel_info_list->push_back(CreateKernelInfo({}, {}, {}, {})); | |||||
| if (op_name == kInitDataSetQueue) { | |||||
| op_name = kInitData; | |||||
| } | } | ||||
| if (op_name == kGetNext) { | |||||
| std::vector<std::string> outputs_format; | |||||
| std::vector<TypeId> outputs_type; | |||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||||
| outputs_format.emplace_back(kOpFormat_DEFAULT); | |||||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | |||||
| } | |||||
| kernel_info_list->push_back(CreateKernelInfo({}, {}, outputs_format, outputs_type)); | |||||
| } | |||||
| if (op_name == kDropoutGenMask) { | |||||
| kernel_info_list->push_back(CreateKernelInfo({kOpFormat_NCHW, kOpFormat_NCHW}, | |||||
| {kInt32->type_id(), kFloat16->type_id()}, {kOpFormat_NCHW}, | |||||
| {kUInt8->type_id()})); | |||||
| auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); | |||||
| if (op_info_ptr == nullptr) { | |||||
| MS_LOG(WARNING) << "Aicpu doestn't have metadata of op [" << op_name << "]"; | |||||
| return; | |||||
| } | } | ||||
| // For compatibility with the current framework | |||||
| if (op_name == kPrint) { | if (op_name == kPrint) { | ||||
| std::vector<std::string> inputs_format; | std::vector<std::string> inputs_format; | ||||
| std::vector<TypeId> inputs_type; | std::vector<TypeId> inputs_type; | ||||
| @@ -92,11 +51,20 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr< | |||||
| outputs_format.emplace_back(kOpFormat_DEFAULT); | outputs_format.emplace_back(kOpFormat_DEFAULT); | ||||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | ||||
| } | } | ||||
| kernel_info_list->push_back(CreateKernelInfo(inputs_format, inputs_type, outputs_format, outputs_type)); | |||||
| auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); | |||||
| builder.SetInputsFormat(inputs_format); | |||||
| builder.SetInputsDeviceType(inputs_type); | |||||
| builder.SetOutputsFormat(outputs_format); | |||||
| builder.SetOutputsDeviceType(outputs_type); | |||||
| builder.SetProcessor(AICPU); | |||||
| builder.SetKernelType(AICPU_KERNEL); | |||||
| builder.SetFusionType(OPAQUE); | |||||
| kernel_info_list->push_back(builder.Build()); | |||||
| return; | |||||
| } | } | ||||
| if (kernel_info_list->empty()) { | |||||
| MS_LOG(INFO) << "Aicpu dose not has metadata of op[ " << op_name << "]."; | |||||
| if (!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) { | |||||
| MS_LOG(WARNING) << "Aicpu parsed metadata op [" << op_name << "] failed"; | |||||
| return; | |||||
| } | } | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -24,7 +24,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| constexpr auto kInitDataSetQueue = "InitData"; | |||||
| constexpr auto kInitDataSetQueue = "InitDataSetQueue"; | |||||
| constexpr auto kInitData = "InitData"; | |||||
| constexpr auto kGetNext = "GetNext"; | constexpr auto kGetNext = "GetNext"; | ||||
| constexpr auto kDropoutGenMask = "DropoutGenMask"; | constexpr auto kDropoutGenMask = "DropoutGenMask"; | ||||
| constexpr auto kPrint = "Print"; | constexpr auto kPrint = "Print"; | ||||
| @@ -417,6 +417,8 @@ void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBu | |||||
| if (imply_type == kAKG) { | if (imply_type == kAKG) { | ||||
| builder->SetKernelType(AUTO_DIFF_KERNEL); | builder->SetKernelType(AUTO_DIFF_KERNEL); | ||||
| } else if (imply_type == kAICPU) { | |||||
| builder->SetKernelType(AICPU_KERNEL); | |||||
| } else { | } else { | ||||
| builder->SetKernelType(TBE_KERNEL); | builder->SetKernelType(TBE_KERNEL); | ||||
| } | } | ||||
| @@ -471,6 +473,13 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn | |||||
| return false; | return false; | ||||
| } | } | ||||
| kernel_info_list->push_back(builder->Build()); | |||||
| } | |||||
| } else { | |||||
| if (processor == AICPU) { | |||||
| auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); | |||||
| MS_EXCEPTION_IF_NULL(builder); | |||||
| SetKernelBuildInfo(builder, processor, op_info_ptr); | |||||
| kernel_info_list->push_back(builder->Build()); | kernel_info_list->push_back(builder->Build()); | ||||
| } | } | ||||
| } | } | ||||
| @@ -24,7 +24,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| enum OpImplyType { kAKG = 0, kTBE }; | |||||
| enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU }; | |||||
| enum OpIOType { kInput = 0, kOutput }; | enum OpIOType { kInput = 0, kOutput }; | ||||
| class OpAttr { | class OpAttr { | ||||
| @@ -39,6 +39,7 @@ constexpr auto kDtypeFormat = "dtype_format"; | |||||
| constexpr auto kAttr = "attr"; | constexpr auto kAttr = "attr"; | ||||
| constexpr auto kIputs = "inputs"; | constexpr auto kIputs = "inputs"; | ||||
| constexpr auto kOutputs = "outputs"; | constexpr auto kOutputs = "outputs"; | ||||
| constexpr auto kAiCPU = "AiCPU"; | |||||
| constexpr auto kTbe = "TBE"; | constexpr auto kTbe = "TBE"; | ||||
| constexpr auto kAkg = "akg"; | constexpr auto kAkg = "akg"; | ||||
| constexpr auto kAutodiff = "AutoDiff"; | constexpr auto kAutodiff = "AutoDiff"; | ||||
| @@ -60,6 +61,8 @@ std::string ImplTypeToStr(OpImplyType impl_type) { | |||||
| return kTbe; | return kTbe; | ||||
| case kAKG: | case kAKG: | ||||
| return kAkg; | return kAkg; | ||||
| case kAICPU: | |||||
| return kAiCPU; | |||||
| default: | default: | ||||
| return "unknow"; | return "unknow"; | ||||
| } | } | ||||
| @@ -76,6 +79,9 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) | |||||
| } else if (imply_type_string == kAutodiff) { | } else if (imply_type_string == kAutodiff) { | ||||
| OpImplyType imply_type = kAKG; | OpImplyType imply_type = kAKG; | ||||
| ret = DecodeOpInfo(op_json, imply_type, impl_path); | ret = DecodeOpInfo(op_json, imply_type, impl_path); | ||||
| } else if (imply_type_string == kAiCPU) { | |||||
| OpImplyType imply_type = kAICPU; | |||||
| ret = DecodeOpInfo(op_json, imply_type, impl_path); | |||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "Not support imply_type"; | MS_LOG(DEBUG) << "Not support imply_type"; | ||||
| } | } | ||||
| @@ -154,7 +160,9 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | |||||
| std::shared_ptr<OpAttr> op_attr = std::make_shared<OpAttr>(); | std::shared_ptr<OpAttr> op_attr = std::make_shared<OpAttr>(); | ||||
| MS_EXCEPTION_IF_NULL(op_attr); | MS_EXCEPTION_IF_NULL(op_attr); | ||||
| op_attr->set_name(obj.at(kName)); | op_attr->set_name(obj.at(kName)); | ||||
| op_attr->set_param_type(obj.at(kParamType)); | |||||
| if (imply_type != kAICPU) { | |||||
| op_attr->set_param_type(obj.at(kParamType)); | |||||
| } | |||||
| op_attr->set_type(obj.at(kType)); | op_attr->set_type(obj.at(kType)); | ||||
| if (imply_type == kTBE) { | if (imply_type == kTBE) { | ||||
| op_attr->set_value(obj.at(kValue)); | op_attr->set_value(obj.at(kValue)); | ||||
| @@ -242,9 +250,10 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im | |||||
| auto context = MsContext::GetInstance(); | auto context = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context); | MS_EXCEPTION_IF_NULL(context); | ||||
| bool is_gpu = (context->device_target() == kGPUDevice); | bool is_gpu = (context->device_target() == kGPUDevice); | ||||
| if ((is_gpu && imply_type == kTBE) || (!is_gpu && imply_type != kTBE)) { | |||||
| MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << "imply_type:" << ImplTypeToStr(imply_type) | |||||
| << "current op num:" << op_info_.size(); | |||||
| if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) || | |||||
| (!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) { | |||||
| MS_LOG(ERROR) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type) | |||||
| << ", current op num:" << op_info_.size(); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| for (const auto& op_info : op_info_) { | for (const auto& op_info : op_info_) { | ||||
| @@ -253,8 +262,8 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im | |||||
| return op_info; | return op_info; | ||||
| } | } | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << "imply_type:" << ImplTypeToStr(imply_type) | |||||
| << "current op num:" << op_info_.size(); | |||||
| MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type) | |||||
| << ", current op num:" << op_info_.size(); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -30,7 +30,7 @@ Note: | |||||
| from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | ||||
| from .op_info_register import op_info_register, TBERegOp, DataType | |||||
| from .op_info_register import op_info_register, AiCPURegOp, TBERegOp, DataType | |||||
| from .primitive import constexpr | from .primitive import constexpr | ||||
| from .._c_expression import signature_rw, signature_kind | from .._c_expression import signature_rw, signature_kind | ||||
| @@ -40,6 +40,6 @@ __primitive__ = [ | |||||
| ] | ] | ||||
| __all__ = ["get_vm_impl_fn", "vm_impl_registry", | __all__ = ["get_vm_impl_fn", "vm_impl_registry", | ||||
| "op_info_register", "TBERegOp", "DataType", | |||||
| "op_info_register", "AiCPURegOp", "TBERegOp", "DataType", | |||||
| "constexpr"] | "constexpr"] | ||||
| __all__.extend(__primitive__) | __all__.extend(__primitive__) | ||||
| @@ -16,5 +16,6 @@ | |||||
| from .akg.gpu import * | from .akg.gpu import * | ||||
| from .tbe import * | from .tbe import * | ||||
| from .aicpu import * | |||||
| __all__ = [] | __all__ = [] | ||||
| @@ -0,0 +1,19 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """aicpu ops""" | |||||
| from .init_data_set_queue import _init_data_set_queue_aicpu | |||||
| from .dropout_genmask import _dropout_genmask_aicpu | |||||
| from .get_next import _get_next_aicpu | |||||
| from .print_tensor import _print_aicpu | |||||
| @@ -0,0 +1,32 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """InitDataSetQueue op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| dropout_genmask_op_info = AiCPURegOp("DropoutGenMask") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x1", "required") \ | |||||
| .input(1, "x2", "required") \ | |||||
| .output(0, "y", "required") \ | |||||
| .attr("Seed0", "int") \ | |||||
| .attr("Seed1", "int") \ | |||||
| .dtype_format(DataType.I32_NCHW, DataType.F16_NCHW, DataType.U8_NCHW) \ | |||||
| .get_op_info() | |||||
| @op_info_register(dropout_genmask_op_info) | |||||
| def _dropout_genmask_aicpu(): | |||||
| """Dropout AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,39 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """InitDataSetQueue op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| get_next_op_info = AiCPURegOp("GetNext") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .output(0, "y", "dynamic") \ | |||||
| .attr("shared_name", "str") \ | |||||
| .dtype_format(DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default) \ | |||||
| .dtype_format(DataType.F16_Default) \ | |||||
| .dtype_format(DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(get_next_op_info) | |||||
| def _get_next_aicpu(): | |||||
| """GetNext AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,27 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """InitDataSetQueue op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp | |||||
| init_data_set_queue_op_info = AiCPURegOp("InitData") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .attr("queue_name", "str") \ | |||||
| .get_op_info() | |||||
| @op_info_register(init_data_set_queue_op_info) | |||||
| def _init_data_set_queue_aicpu(): | |||||
| """InitDataSetQueue AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,39 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """InitDataSetQueue op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| print_op_info = AiCPURegOp("Print") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "dynamic") \ | |||||
| .output(0, "y", "required") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(print_op_info) | |||||
| def _print_aicpu(): | |||||
| """Print AiCPU register""" | |||||
| return | |||||
| @@ -78,14 +78,15 @@ class RegOp(): | |||||
| self.inputs = [] | self.inputs = [] | ||||
| self.outputs = [] | self.outputs = [] | ||||
| self.attr_ = [] | self.attr_ = [] | ||||
| self.fusion_type_ = '' | |||||
| self.dtype_format_ = [] | self.dtype_format_ = [] | ||||
| def is_string(self, value): | |||||
| def _is_string(self, value): | |||||
| """ | """ | ||||
| Check if the value is a str type. | Check if the value is a str type. | ||||
| Args: | Args: | ||||
| value: Parameter to to check. | |||||
| value: Parameter to be checked. | |||||
| Raises: | Raises: | ||||
| TypeError: If the type of value is not a str. | TypeError: If the type of value is not a str. | ||||
| @@ -93,12 +94,12 @@ class RegOp(): | |||||
| if not isinstance(value, str): | if not isinstance(value, str): | ||||
| raise TypeError("%s value must be str" % str(value)) | raise TypeError("%s value must be str" % str(value)) | ||||
| def is_int(self, value): | |||||
| def _is_int(self, value): | |||||
| """ | """ | ||||
| Check if the value is a int. | Check if the value is a int. | ||||
| Args: | Args: | ||||
| value: Parameter to to check. | |||||
| value: Parameter to be checked. | |||||
| Raises: | Raises: | ||||
| TypeError: If the type of value is not a int. | TypeError: If the type of value is not a int. | ||||
| @@ -106,12 +107,12 @@ class RegOp(): | |||||
| if not isinstance(value, int): | if not isinstance(value, int): | ||||
| raise TypeError("%s value must be int" % str(value)) | raise TypeError("%s value must be int" % str(value)) | ||||
| def is_bool(self, value): | |||||
| def _is_bool(self, value): | |||||
| """ | """ | ||||
| Check if the value is a bool. | Check if the value is a bool. | ||||
| Args: | Args: | ||||
| value: Parameter to to check. | |||||
| value: Parameter to be checked. | |||||
| Raises: | Raises: | ||||
| TypeError: If the type of value is not a bool. | TypeError: If the type of value is not a bool. | ||||
| @@ -119,6 +120,51 @@ class RegOp(): | |||||
| if not isinstance(value, bool): | if not isinstance(value, bool): | ||||
| raise TypeError("%s value must be bool" % str(value)) | raise TypeError("%s value must be bool" % str(value)) | ||||
| def _check_param(self, param_list, key_list, fn_list, kwargs): | |||||
| """ | |||||
| Check if the parameter type is correct. | |||||
| Args: | |||||
| param_list (list): Parameter list to be checked. | |||||
| key_list (list): The keys of output dict. | |||||
| fn_list (list): Function used for parameter checking. If the function list has only one element, | |||||
| all parameters will use the same function. | |||||
| kwargs (dict): Other parameter information. | |||||
| Raises: | |||||
| TypeError: If the type of value is not list. | |||||
| ValueError: If the size of param list is not equal to the size of key list, or | |||||
| the size of param list is not equal to the size of funtion list. | |||||
| """ | |||||
| for i in [param_list, key_list, fn_list]: | |||||
| if not isinstance(i, list): | |||||
| raise TypeError("%s value must be list type" % str(i)) | |||||
| if len(param_list) != len(key_list) or (len(fn_list) != 1 and len(param_list) != len(fn_list)): | |||||
| raise ValueError("param_list size {}, key_list size {}, must be equal.And fn_list size {}.". | |||||
| format(len(param_list), len(key_list), len(fn_list))) | |||||
| out_dict = {} | |||||
| for idx, element in enumerate(param_list): | |||||
| if element is not None: | |||||
| if len(fn_list) == 1: | |||||
| fn_list[0](element) | |||||
| else: | |||||
| fn_list[idx](element) | |||||
| out_dict[key_list[idx]] = element | |||||
| if kwargs: | |||||
| out_dict = dict(out_dict, kwargs) | |||||
| return out_dict | |||||
| def fusion_type(self, fusion_type): | |||||
| """ | |||||
| Register fusion type. | |||||
| Args: | |||||
| fusion_type (str): Value of fusion type. | |||||
| """ | |||||
| self._is_string(fusion_type) | |||||
| self.fusion_type_ = fusion_type | |||||
| return self | |||||
| def dtype_format(self, *args): | def dtype_format(self, *args): | ||||
| """ | """ | ||||
| Register dtype and format. | Register dtype and format. | ||||
| @@ -136,8 +182,8 @@ class RegOp(): | |||||
| for arg in args: | for arg in args: | ||||
| if not isinstance(arg, tuple) or len(arg) != 2: | if not isinstance(arg, tuple) or len(arg) != 2: | ||||
| raise ValueError("dtype and format value must be tuple of two elements") | raise ValueError("dtype and format value must be tuple of two elements") | ||||
| self.is_string(arg[0]) | |||||
| self.is_string(arg[1]) | |||||
| self._is_string(arg[0]) | |||||
| self._is_string(arg[1]) | |||||
| dtype_format.append(arg) | dtype_format.append(arg) | ||||
| self.dtype_format_.append(tuple(dtype_format)) | self.dtype_format_.append(tuple(dtype_format)) | ||||
| return self | return self | ||||
| @@ -159,13 +205,71 @@ class RegOp(): | |||||
| return op_info | return op_info | ||||
| class AiCPURegOp(RegOp): | |||||
| """Class for AiCPU op info register""" | |||||
| def __init__(self, op_name): | |||||
| super(AiCPURegOp, self).__init__(op_name) | |||||
| self.imply_type = "AiCPU" | |||||
| def input(self, index=None, name=None, param_type=None, **kwargs): | |||||
| """ | |||||
| Register AiCPU op input information. | |||||
| Args: | |||||
| index (int): Order of the input. Default: None. | |||||
| name (str): Name of the input. Default: None. | |||||
| param_type (str): Param type of the input. Default: None. | |||||
| kwargs (dict): Other information for the input. | |||||
| """ | |||||
| param_list = [index, name, param_type] | |||||
| key_list = ["index", "name", "param_type"] | |||||
| fn_list = [self._is_int, self._is_string, self._is_string] | |||||
| input_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||||
| self.inputs.append(input_dict) | |||||
| return self | |||||
| def output(self, index=None, name=None, param_type=None, **kwargs): | |||||
| """ | |||||
| Register AiCPU op output information. | |||||
| Args: | |||||
| index (int): Order of the output. Default: None. | |||||
| name (str): Name of the output. Default: None. | |||||
| param_type (str): Param type of the output. Default: None. | |||||
| kwargs (dict): Other information for the output. | |||||
| """ | |||||
| param_list = [index, name, param_type] | |||||
| key_list = ["index", "name", "param_type"] | |||||
| fn_list = [self._is_int, self._is_string, self._is_string] | |||||
| output_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||||
| self.outputs.append(output_dict) | |||||
| return self | |||||
| def attr(self, name=None, value_type=None, value=None, **kwargs): | |||||
| """ | |||||
| Register AiCPU op attribute information. | |||||
| Args: | |||||
| name (str): Name of the attribute. Default: None. | |||||
| value_type (str): Value type of the attribute. Default: None. | |||||
| value (str): Value type of the attribute. Default: None. | |||||
| kwargs (dict): Other information for the attribute. | |||||
| """ | |||||
| param_list = [name, value_type, value] | |||||
| key_list = ["name", "type", "value"] | |||||
| fn_list = [self._is_string] | |||||
| attr_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||||
| self.attr_.append(attr_dict) | |||||
| return self | |||||
| class TBERegOp(RegOp): | class TBERegOp(RegOp): | ||||
| """Class for TBE op info register.""" | """Class for TBE op info register.""" | ||||
| def __init__(self, op_name=""): | def __init__(self, op_name=""): | ||||
| super(TBERegOp, self).__init__(op_name) | super(TBERegOp, self).__init__(op_name) | ||||
| self.imply_type = "TBE" | self.imply_type = "TBE" | ||||
| self.fusion_type_ = '' | |||||
| self.async_flag_ = False | self.async_flag_ = False | ||||
| self.binfile_name_ = '' | self.binfile_name_ = '' | ||||
| self.compute_cost_ = 10 | self.compute_cost_ = 10 | ||||
| @@ -175,17 +279,6 @@ class TBERegOp(RegOp): | |||||
| self.dynamic_format_ = False | self.dynamic_format_ = False | ||||
| self.op_pattern_ = "" | self.op_pattern_ = "" | ||||
| def fusion_type(self, fusion_type): | |||||
| """ | |||||
| Register fusion type. | |||||
| Args: | |||||
| fusion_type (str): Value of fusion type. | |||||
| """ | |||||
| self.is_string(fusion_type) | |||||
| self.fusion_type_ = fusion_type | |||||
| return self | |||||
| def async_flag(self, async_flag): | def async_flag(self, async_flag): | ||||
| """ | """ | ||||
| Register async flag. | Register async flag. | ||||
| @@ -193,7 +286,7 @@ class TBERegOp(RegOp): | |||||
| Args: | Args: | ||||
| async_flag (bool): Value of async flag. | async_flag (bool): Value of async flag. | ||||
| """ | """ | ||||
| self.is_bool(async_flag) | |||||
| self._is_bool(async_flag) | |||||
| self.async_flag_ = async_flag | self.async_flag_ = async_flag | ||||
| return self | return self | ||||
| @@ -204,7 +297,7 @@ class TBERegOp(RegOp): | |||||
| Args: | Args: | ||||
| binfile_name (str): Name of op binfile. | binfile_name (str): Name of op binfile. | ||||
| """ | """ | ||||
| self.is_string(binfile_name) | |||||
| self._is_string(binfile_name) | |||||
| self.binfile_name_ = binfile_name | self.binfile_name_ = binfile_name | ||||
| return self | return self | ||||
| @@ -215,7 +308,7 @@ class TBERegOp(RegOp): | |||||
| Args: | Args: | ||||
| compute_cost (int): Value of compute cost. | compute_cost (int): Value of compute cost. | ||||
| """ | """ | ||||
| self.is_int(compute_cost) | |||||
| self._is_int(compute_cost) | |||||
| self.compute_cost_ = compute_cost | self.compute_cost_ = compute_cost | ||||
| return self | return self | ||||
| @@ -226,7 +319,7 @@ class TBERegOp(RegOp): | |||||
| Args: | Args: | ||||
| kernel_name (str): Name of op kernel. | kernel_name (str): Name of op kernel. | ||||
| """ | """ | ||||
| self.is_string(kernel_name) | |||||
| self._is_string(kernel_name) | |||||
| self.kernel_name_ = kernel_name | self.kernel_name_ = kernel_name | ||||
| return self | return self | ||||
| @@ -237,7 +330,7 @@ class TBERegOp(RegOp): | |||||
| Args: | Args: | ||||
| partial_flag (bool): Value of partial flag. | partial_flag (bool): Value of partial flag. | ||||
| """ | """ | ||||
| self.is_bool(partial_flag) | |||||
| self._is_bool(partial_flag) | |||||
| self.partial_flag_ = partial_flag | self.partial_flag_ = partial_flag | ||||
| return self | return self | ||||
| @@ -248,7 +341,7 @@ class TBERegOp(RegOp): | |||||
| Args: | Args: | ||||
| reshape_type (str): Value of reshape type. | reshape_type (str): Value of reshape type. | ||||
| """ | """ | ||||
| self.is_string(reshape_type) | |||||
| self._is_string(reshape_type) | |||||
| self.reshape_type_ = reshape_type | self.reshape_type_ = reshape_type | ||||
| return self | return self | ||||
| @@ -259,56 +352,43 @@ class TBERegOp(RegOp): | |||||
| Args: | Args: | ||||
| reshape_type (bool): Value of dynamic format. | reshape_type (bool): Value of dynamic format. | ||||
| """ | """ | ||||
| self.is_bool(dynamic_format) | |||||
| self._is_bool(dynamic_format) | |||||
| self.dynamic_format_ = dynamic_format | self.dynamic_format_ = dynamic_format | ||||
| return self | return self | ||||
| def op_pattern(self, pattern=None): | def op_pattern(self, pattern=None): | ||||
| """ | """ | ||||
| Register op pattern information. | |||||
| Register TBE op pattern information. | |||||
| Args: | Args: | ||||
| pattern (str): Value of op pattern. | pattern (str): Value of op pattern. | ||||
| """ | """ | ||||
| if pattern is not None and self.istring(pattern): | |||||
| if pattern is not None and self._is_string(pattern): | |||||
| self.op_pattern_ = pattern | self.op_pattern_ = pattern | ||||
| return self | return self | ||||
| def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs): | def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs): | ||||
| """ | """ | ||||
| Register op attribute information. | |||||
| Register TBE op attribute information. | |||||
| Args: | Args: | ||||
| name (str): Name of the attribute. Default: None. | name (str): Name of the attribute. Default: None. | ||||
| param_type (str): Param type of the attribute. Default: None. | param_type (str): Param type of the attribute. Default: None. | ||||
| type (str): Type of the attribute. Default: None. | |||||
| value_type (str): Type of the attribute. Default: None. | |||||
| value (str): Value of the attribute. Default: None. | value (str): Value of the attribute. Default: None. | ||||
| default_value (str): Default value of attribute. Default: None. | default_value (str): Default value of attribute. Default: None. | ||||
| kwargs (dict): Other information for the attribute. | kwargs (dict): Other information for the attribute. | ||||
| """ | """ | ||||
| param_list = [name, param_type, value_type, value, default_value] | param_list = [name, param_type, value_type, value, default_value] | ||||
| attr_dict = {} | |||||
| for index, element in enumerate(param_list): | |||||
| if element is not None: | |||||
| self.is_string(element) | |||||
| if index == 0: | |||||
| attr_dict["name"] = element | |||||
| elif index == 1: | |||||
| attr_dict["param_type"] = element | |||||
| elif index == 2: | |||||
| attr_dict["type"] = element | |||||
| elif index == 3: | |||||
| attr_dict["value"] = element | |||||
| elif index == 4: | |||||
| attr_dict["default_value"] = element | |||||
| if kwargs: | |||||
| attr_dict = dict(attr_dict, **kwargs) | |||||
| key_list = ["name", "param_type", "type", "value", "default_value"] | |||||
| fn_list = [self._is_string] | |||||
| attr_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||||
| self.attr_.append(attr_dict) | self.attr_.append(attr_dict) | ||||
| return self | return self | ||||
| def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs): | def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs): | ||||
| """ | """ | ||||
| Register op input information. | |||||
| Register TBE op input information. | |||||
| Args: | Args: | ||||
| index (int): Order of the input. Default: None. | index (int): Order of the input. Default: None. | ||||
| @@ -319,32 +399,15 @@ class TBERegOp(RegOp): | |||||
| kwargs (dict): Other information for the input. | kwargs (dict): Other information for the input. | ||||
| """ | """ | ||||
| param_list = [index, name, need_compile, param_type, shape] | param_list = [index, name, need_compile, param_type, shape] | ||||
| input_dict = {} | |||||
| for idx, element in enumerate(param_list): | |||||
| if element is not None: | |||||
| if idx == 0: | |||||
| self.is_int(element) | |||||
| input_dict["index"] = element | |||||
| elif idx == 1: | |||||
| self.is_string(element) | |||||
| input_dict["name"] = element | |||||
| elif idx == 2: | |||||
| self.is_bool(element) | |||||
| input_dict["need_compile"] = element | |||||
| elif idx == 3: | |||||
| self.is_string(element) | |||||
| input_dict["param_type"] = element | |||||
| elif idx == 4: | |||||
| self.is_string(element) | |||||
| input_dict["shape"] = element | |||||
| if kwargs: | |||||
| input_dict = dict(input_dict, **kwargs) | |||||
| key_list = ["index", "name", "need_compile", "param_type", "shape"] | |||||
| fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string] | |||||
| input_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||||
| self.inputs.append(input_dict) | self.inputs.append(input_dict) | ||||
| return self | return self | ||||
| def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs): | def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs): | ||||
| """ | """ | ||||
| Register op output information. | |||||
| Register TBE op output information. | |||||
| Args: | Args: | ||||
| index (int): Order of the output. Default: None. | index (int): Order of the output. Default: None. | ||||
| @@ -355,29 +418,13 @@ class TBERegOp(RegOp): | |||||
| kwargs (dict): Other information for the output. | kwargs (dict): Other information for the output. | ||||
| """ | """ | ||||
| param_list = [index, name, need_compile, param_type, shape] | param_list = [index, name, need_compile, param_type, shape] | ||||
| output_dict = {} | |||||
| for idx, element in enumerate(param_list): | |||||
| if element is not None: | |||||
| if idx == 0: | |||||
| self.is_int(element) | |||||
| output_dict["index"] = element | |||||
| elif idx == 1: | |||||
| self.is_string(element) | |||||
| output_dict["name"] = element | |||||
| elif idx == 2: | |||||
| self.is_bool(element) | |||||
| output_dict["need_compile"] = element | |||||
| elif idx == 3: | |||||
| self.is_string(element) | |||||
| output_dict["param_type"] = element | |||||
| elif idx == 4: | |||||
| self.is_string(element) | |||||
| output_dict["shape"] = element | |||||
| if kwargs: | |||||
| output_dict = dict(output_dict, **kwargs) | |||||
| key_list = ["index", "name", "need_compile", "param_type", "shape"] | |||||
| fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string] | |||||
| output_dict = self._check_param(param_list, key_list, fn_list, kwargs) | |||||
| self.outputs.append(output_dict) | self.outputs.append(output_dict) | ||||
| return self | return self | ||||
| class DataType(): | class DataType(): | ||||
| """ | """ | ||||
| Various combinations of dtype and formatself. | Various combinations of dtype and formatself. | ||||