| @@ -425,7 +425,7 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo( | |||||
| return kernel_info_list[selected_index]; | return kernel_info_list[selected_index]; | ||||
| } | } | ||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetAllMatchedFilteredKernelInfo( | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype( | |||||
| const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) { | const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) { | ||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result; | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result; | ||||
| for (const auto &kernel_build_info : kernel_info_list) { | for (const auto &kernel_build_info : kernel_info_list) { | ||||
| @@ -474,7 +474,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, | |||||
| std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr; | std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr; | ||||
| // Matched kernel info | // Matched kernel info | ||||
| // Filter kernel info matched with me infered type | // Filter kernel info matched with me infered type | ||||
| auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list); | |||||
| auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list); | |||||
| if (!filtered_kernel_info_list.empty()) { | if (!filtered_kernel_info_list.empty()) { | ||||
| selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); | selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); | ||||
| select_status = kStatusAllMatched; | select_status = kStatusAllMatched; | ||||
| @@ -508,6 +508,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { | |||||
| << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; | << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; | ||||
| kernel::AICpuQuery(kernel_node, &kernel_info_list); | kernel::AICpuQuery(kernel_node, &kernel_info_list); | ||||
| select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); | select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); | ||||
| AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); | |||||
| } | } | ||||
| // The kernel info not finded both in the aicpu kernel list & aicore kernel list | // The kernel info not finded both in the aicpu kernel list & aicore kernel list | ||||
| if (select_status == kNoMatched) { | if (select_status == kNoMatched) { | ||||
| @@ -47,6 +47,13 @@ enum FusionType { | |||||
| OPAQUE, | OPAQUE, | ||||
| UNKNOWN_FUSION_TYPE = -1, | UNKNOWN_FUSION_TYPE = -1, | ||||
| }; | }; | ||||
| enum OpPattern { | |||||
| kCommonPattern = 0, | |||||
| kFormatAgnosticPattern = 1, | |||||
| kBroadcastPattern = 2, | |||||
| kReducePattern = 3, | |||||
| kDynamicFormatPattern = 4, | |||||
| }; | |||||
| // Backend processor | // Backend processor | ||||
| enum Processor { | enum Processor { | ||||
| @@ -162,5 +162,10 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType( | |||||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | MS_EXCEPTION_IF_NULL(kernel_build_info_); | ||||
| kernel_build_info_->output_reshape_type_ = output_reshape_type; | kernel_build_info_->output_reshape_type_ = output_reshape_type; | ||||
| } | } | ||||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | |||||
| kernel_build_info_->op_pattern_ = pattern; | |||||
| } | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,6 +34,7 @@ class KernelBuildInfo { | |||||
| kernel_type_ = AUTO_DIFF_KERNEL; | kernel_type_ = AUTO_DIFF_KERNEL; | ||||
| fusion_type_ = OPAQUE; | fusion_type_ = OPAQUE; | ||||
| processor_ = AICORE; | processor_ = AICORE; | ||||
| op_pattern_ = kCommonPattern; | |||||
| input_reshape_type_ = {}; | input_reshape_type_ = {}; | ||||
| output_reshape_type_ = {}; | output_reshape_type_ = {}; | ||||
| inputs_format_ = {}; | inputs_format_ = {}; | ||||
| @@ -70,6 +71,8 @@ class KernelBuildInfo { | |||||
| std::vector<TypeId> GetAllOutputDeviceTypes() const; | std::vector<TypeId> GetAllOutputDeviceTypes() const; | ||||
| OpPattern op_pattern() const { return op_pattern_; } | |||||
| FusionType fusion_type() const { return fusion_type_; } | FusionType fusion_type() const { return fusion_type_; } | ||||
| Processor processor() const { return processor_; } | Processor processor() const { return processor_; } | ||||
| @@ -88,6 +91,7 @@ class KernelBuildInfo { | |||||
| private: | private: | ||||
| KernelType kernel_type_; | KernelType kernel_type_; | ||||
| std::vector<std::string> inputs_format_; | std::vector<std::string> inputs_format_; | ||||
| OpPattern op_pattern_; | |||||
| std::vector<std::string> outputs_format_; | std::vector<std::string> outputs_format_; | ||||
| std::vector<std::vector<Axis>> input_reshape_type_; | std::vector<std::vector<Axis>> input_reshape_type_; | ||||
| std::vector<std::vector<Axis>> output_reshape_type_; | std::vector<std::vector<Axis>> output_reshape_type_; | ||||
| @@ -125,6 +129,8 @@ class KernelBuildInfo::KernelBuildInfoBuilder { | |||||
| void SetProcessor(Processor processor); | void SetProcessor(Processor processor); | ||||
| void SetOpPattern(OpPattern pattern); | |||||
| std::shared_ptr<KernelBuildInfo> Build(); | std::shared_ptr<KernelBuildInfo> Build(); | ||||
| private: | private: | ||||
| @@ -40,7 +40,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, | |||||
| (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); | (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); | ||||
| } else { | } else { | ||||
| MS_LOG(WARNING) << "All kernel Info list does not match any kernel info "; | MS_LOG(WARNING) << "All kernel Info list does not match any kernel info "; | ||||
| for (size_t index; index < kernel_info_list->size(); ++index) { | |||||
| for (size_t index = 0; index < kernel_info_list->size(); ++index) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); | MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); | ||||
| MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); | MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); | ||||
| } | } | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "ir/dtype.h" | #include "ir/dtype.h" | ||||
| #include "kernel/kernel.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| @@ -100,7 +101,7 @@ class OpInfo { | |||||
| std::string kernel_name() const { return kernel_name_; } | std::string kernel_name() const { return kernel_name_; } | ||||
| bool partial_flag() const { return partial_flag_; } | bool partial_flag() const { return partial_flag_; } | ||||
| bool dynamic_format() const { return dynamic_format_; } | bool dynamic_format() const { return dynamic_format_; } | ||||
| std::string op_pattern() const { return op_pattern_; } | |||||
| OpPattern op_pattern() const { return op_pattern_; } | |||||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | ||||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | ||||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | ||||
| @@ -116,7 +117,7 @@ class OpInfo { | |||||
| void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } | void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } | ||||
| void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | ||||
| void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } | void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } | ||||
| void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } | |||||
| void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } | |||||
| void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } | void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } | ||||
| void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } | void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } | ||||
| void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } | void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } | ||||
| @@ -137,7 +138,7 @@ class OpInfo { | |||||
| std::string kernel_name_; | std::string kernel_name_; | ||||
| bool partial_flag_ = false; | bool partial_flag_ = false; | ||||
| bool dynamic_format_ = false; | bool dynamic_format_ = false; | ||||
| std::string op_pattern_; | |||||
| OpPattern op_pattern_ = kCommonPattern; | |||||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; | std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; | ||||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_; | std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_; | ||||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_; | std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_; | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <pybind11/pybind11.h> | #include <pybind11/pybind11.h> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/overload.h" | #include "utils/overload.h" | ||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| @@ -35,6 +36,9 @@ constexpr auto kPartialFlag = "partial_flag"; | |||||
| constexpr auto kReshapeType = "reshape_type"; | constexpr auto kReshapeType = "reshape_type"; | ||||
| constexpr auto kOpPattern = "op_pattern"; | constexpr auto kOpPattern = "op_pattern"; | ||||
| constexpr auto kDynamicFormat = "dynamic_format"; | constexpr auto kDynamicFormat = "dynamic_format"; | ||||
| constexpr auto kFormatAgnostic = "formatAgnostic"; | |||||
| constexpr auto kBroadcast = "broadcast"; | |||||
| constexpr auto kReduce = "reduce"; | |||||
| constexpr auto kDtypeFormat = "dtype_format"; | constexpr auto kDtypeFormat = "dtype_format"; | ||||
| constexpr auto kAttr = "attr"; | constexpr auto kAttr = "attr"; | ||||
| constexpr auto kIputs = "inputs"; | constexpr auto kIputs = "inputs"; | ||||
| @@ -95,13 +99,19 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) | |||||
| } | } | ||||
| void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) { | void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) { | ||||
| const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern}, | |||||
| {kFormatAgnostic, kBroadcastPattern}, | |||||
| {kReduce, kReducePattern}, | |||||
| {kDynamicFormat, kDynamicFormatPattern}}; | |||||
| op_info->set_async_flag(obj.at(kAsyncFlag)); | op_info->set_async_flag(obj.at(kAsyncFlag)); | ||||
| op_info->set_binfile_name(obj.at(kBinfileName)); | op_info->set_binfile_name(obj.at(kBinfileName)); | ||||
| op_info->set_compute_cost(obj.at(kComputeCost)); | op_info->set_compute_cost(obj.at(kComputeCost)); | ||||
| op_info->set_kernel_name(obj.at(kKernelName)); | op_info->set_kernel_name(obj.at(kKernelName)); | ||||
| op_info->set_partial_flag(obj.at(kPartialFlag)); | op_info->set_partial_flag(obj.at(kPartialFlag)); | ||||
| if (obj.find(kOpPattern) != obj.end()) { | if (obj.find(kOpPattern) != obj.end()) { | ||||
| op_info->set_op_pattern(obj.at(kOpPattern)); | |||||
| if (kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) { | |||||
| op_info->set_op_pattern(obj.at(kOpPattern)); | |||||
| } | |||||
| } | } | ||||
| if (obj.find(kDynamicFormat) != obj.end()) { | if (obj.find(kDynamicFormat) != obj.end()) { | ||||
| op_info->set_dynamic_format(obj.at(kDynamicFormat)); | op_info->set_dynamic_format(obj.at(kDynamicFormat)); | ||||
| @@ -492,6 +492,7 @@ void SetKernelBuildCommonInfo(const std::shared_ptr<KernelBuildInfo::KernelBuild | |||||
| if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) { | if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) { | ||||
| builder->SetFusionType(tbe::GetFusionType(fusion_type)); | builder->SetFusionType(tbe::GetFusionType(fusion_type)); | ||||
| } | } | ||||
| builder->SetOpPattern(op_info_ptr->op_pattern()); | |||||
| builder->SetKernelType(TBE_KERNEL); | builder->SetKernelType(TBE_KERNEL); | ||||
| } | } | ||||
| @@ -509,7 +510,7 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn | |||||
| if (primitive->GetAttr("dyn_input_sizes") != nullptr) { | if (primitive->GetAttr("dyn_input_sizes") != nullptr) { | ||||
| dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes")); | dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes")); | ||||
| } | } | ||||
| if (inputs.size() > 0) { | |||||
| if (!inputs.empty()) { | |||||
| MS_EXCEPTION_IF_NULL(inputs[0]); | MS_EXCEPTION_IF_NULL(inputs[0]); | ||||
| size_t kernel_info_cnt = inputs[0]->dtypes().size(); | size_t kernel_info_cnt = inputs[0]->dtypes().size(); | ||||
| for (size_t j = 0; j < kernel_info_cnt; j++) { | for (size_t j = 0; j < kernel_info_cnt; j++) { | ||||
| @@ -624,21 +625,17 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke | |||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| for (auto parse_info : parse_info_list) { | |||||
| if (context_ptr->execution_mode() == kPynativeMode) { | |||||
| kernel_info_list->push_back(parse_info); | |||||
| } else { | |||||
| if (IsValidKernelInfo(kernel_node, *(parse_info))) { | |||||
| if (CheckSupported(kernel_node, parse_info)) { | |||||
| kernel_info_list->push_back(parse_info); | |||||
| } else { | |||||
| MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info."; | |||||
| } | |||||
| for (const auto &parse_info : parse_info_list) { | |||||
| if (IsValidKernelInfo(kernel_node, *(parse_info))) { | |||||
| if (CheckSupported(kernel_node, parse_info)) { | |||||
| kernel_info_list->push_back(parse_info); | |||||
| } else { | |||||
| MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info."; | |||||
| } | } | ||||
| } | } | ||||
| } | |||||
| if (kernel_info_list->empty()) { | |||||
| MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "]."; | |||||
| if (kernel_info_list->empty()) { | |||||
| MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "]."; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -44,6 +44,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph | |||||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info); | auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info); | ||||
| builder->SetKernelType(AICPU_KERNEL); | builder->SetKernelType(AICPU_KERNEL); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | ||||
| AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), node); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" | MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" | ||||
| << node->DebugString() << "]"; | << node->DebugString() << "]"; | ||||
| @@ -657,6 +657,16 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_ | |||||
| to_node->set_abstract(from_node->abstract()); | to_node->set_abstract(from_node->abstract()); | ||||
| } | } | ||||
| kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto kernel_info = node->kernel_info(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| // select_kernel_build_info() has checked whether return pointer is null | |||||
| auto build_info = kernel_info->select_kernel_build_info(); | |||||
| MS_EXCEPTION_IF_NULL(build_info); | |||||
| return build_info->op_pattern(); | |||||
| } | |||||
| // get KernelBuildType of node, such as ATT,RT,FWK and so on | // get KernelBuildType of node, such as ATT,RT,FWK and so on | ||||
| KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { | KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| @@ -138,6 +138,8 @@ class AnfRuntimeAlgorithm { | |||||
| static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, | static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, | ||||
| const std::vector<std::vector<size_t>> &shapes, AnfNode *node); | const std::vector<std::vector<size_t>> &shapes, AnfNode *node); | ||||
| static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); | static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); | ||||
| // get op pattern of the node | |||||
| static kernel::OpPattern GetOpPattern(const AnfNodePtr &node); | |||||
| // get KernelBuildType of node ,such as ATT,RT,FWK and so on | // get KernelBuildType of node ,such as ATT,RT,FWK and so on | ||||
| static KernelType GetKernelType(const AnfNodePtr &node); | static KernelType GetKernelType(const AnfNodePtr &node); | ||||
| // get processor type:AICORE,AICPU... | // get processor type:AICORE,AICPU... | ||||
| @@ -142,6 +142,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto"; | |||||
| // attr key name | // attr key name | ||||
| constexpr auto kAttrInputNames = "input_names"; | constexpr auto kAttrInputNames = "input_names"; | ||||
| constexpr auto kAttrIsAICPUKernel = "is_ai_cpu_kernel"; | |||||
| constexpr auto kIsBackendCast = "is_backed_cast"; | constexpr auto kIsBackendCast = "is_backed_cast"; | ||||
| constexpr auto kAttrOutputNames = "output_names"; | constexpr auto kAttrOutputNames = "output_names"; | ||||
| constexpr auto kAttrVisited = "visited"; | constexpr auto kAttrVisited = "visited"; | ||||
| @@ -215,10 +216,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ"; | |||||
| constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; | constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; | ||||
| constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; | constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; | ||||
| constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04"; | constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04"; | ||||
| const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, | |||||
| kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, | |||||
| kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, | |||||
| kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; | |||||
| constexpr auto kOpFormat_NDHWC = "NDHWC"; | |||||
| const std::set<std::string> kOpFormatList = { | |||||
| kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, | |||||
| kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, | |||||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC}; | |||||
| const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | ||||
| const std::set<std::string> kOptOperatorSet = { | const std::set<std::string> kOptOperatorSet = { | ||||
| kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName, | kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName, | ||||