| @@ -425,7 +425,7 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo( | |||
| return kernel_info_list[selected_index]; | |||
| } | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetAllMatchedFilteredKernelInfo( | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype( | |||
| const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) { | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result; | |||
| for (const auto &kernel_build_info : kernel_info_list) { | |||
| @@ -474,7 +474,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, | |||
| std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr; | |||
| // Matched kernel info | |||
| // Filter kernel info matched with me infered type | |||
| auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list); | |||
| auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list); | |||
| if (!filtered_kernel_info_list.empty()) { | |||
| selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); | |||
| select_status = kStatusAllMatched; | |||
| @@ -508,6 +508,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { | |||
| << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; | |||
| kernel::AICpuQuery(kernel_node, &kernel_info_list); | |||
| select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); | |||
| AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); | |||
| } | |||
| // The kernel info not finded both in the aicpu kernel list & aicore kernel list | |||
| if (select_status == kNoMatched) { | |||
| @@ -47,6 +47,13 @@ enum FusionType { | |||
| OPAQUE, | |||
| UNKNOWN_FUSION_TYPE = -1, | |||
| }; | |||
| enum OpPattern { | |||
| kCommonPattern = 0, | |||
| kFormatAgnosticPattern = 1, | |||
| kBroadcastPattern = 2, | |||
| kReducePattern = 3, | |||
| kDynamicFormatPattern = 4, | |||
| }; | |||
| // Backend processor | |||
| enum Processor { | |||
| @@ -162,5 +162,10 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType( | |||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | |||
| kernel_build_info_->output_reshape_type_ = output_reshape_type; | |||
| } | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) { | |||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | |||
| kernel_build_info_->op_pattern_ = pattern; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -34,6 +34,7 @@ class KernelBuildInfo { | |||
| kernel_type_ = AUTO_DIFF_KERNEL; | |||
| fusion_type_ = OPAQUE; | |||
| processor_ = AICORE; | |||
| op_pattern_ = kCommonPattern; | |||
| input_reshape_type_ = {}; | |||
| output_reshape_type_ = {}; | |||
| inputs_format_ = {}; | |||
| @@ -70,6 +71,8 @@ class KernelBuildInfo { | |||
| std::vector<TypeId> GetAllOutputDeviceTypes() const; | |||
| OpPattern op_pattern() const { return op_pattern_; } | |||
| FusionType fusion_type() const { return fusion_type_; } | |||
| Processor processor() const { return processor_; } | |||
| @@ -88,6 +91,7 @@ class KernelBuildInfo { | |||
| private: | |||
| KernelType kernel_type_; | |||
| std::vector<std::string> inputs_format_; | |||
| OpPattern op_pattern_; | |||
| std::vector<std::string> outputs_format_; | |||
| std::vector<std::vector<Axis>> input_reshape_type_; | |||
| std::vector<std::vector<Axis>> output_reshape_type_; | |||
| @@ -125,6 +129,8 @@ class KernelBuildInfo::KernelBuildInfoBuilder { | |||
| void SetProcessor(Processor processor); | |||
| void SetOpPattern(OpPattern pattern); | |||
| std::shared_ptr<KernelBuildInfo> Build(); | |||
| private: | |||
| @@ -40,7 +40,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, | |||
| (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); | |||
| } else { | |||
| MS_LOG(WARNING) << "All kernel Info list does not match any kernel info "; | |||
| for (size_t index; index < kernel_info_list->size(); ++index) { | |||
| for (size_t index = 0; index < kernel_info_list->size(); ++index) { | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); | |||
| MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); | |||
| } | |||
| @@ -21,6 +21,7 @@ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include "ir/dtype.h" | |||
| #include "kernel/kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -100,7 +101,7 @@ class OpInfo { | |||
| std::string kernel_name() const { return kernel_name_; } | |||
| bool partial_flag() const { return partial_flag_; } | |||
| bool dynamic_format() const { return dynamic_format_; } | |||
| std::string op_pattern() const { return op_pattern_; } | |||
| OpPattern op_pattern() const { return op_pattern_; } | |||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | |||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | |||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | |||
| @@ -116,7 +117,7 @@ class OpInfo { | |||
| void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } | |||
| void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | |||
| void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } | |||
| void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } | |||
| void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } | |||
| void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } | |||
| void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } | |||
| void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } | |||
| @@ -137,7 +138,7 @@ class OpInfo { | |||
| std::string kernel_name_; | |||
| bool partial_flag_ = false; | |||
| bool dynamic_format_ = false; | |||
| std::string op_pattern_; | |||
| OpPattern op_pattern_ = kCommonPattern; | |||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; | |||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_; | |||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_; | |||
| @@ -18,6 +18,7 @@ | |||
| #include <pybind11/pybind11.h> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <map> | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/overload.h" | |||
| #include "utils/context/ms_context.h" | |||
| @@ -35,6 +36,9 @@ constexpr auto kPartialFlag = "partial_flag"; | |||
| constexpr auto kReshapeType = "reshape_type"; | |||
| constexpr auto kOpPattern = "op_pattern"; | |||
| constexpr auto kDynamicFormat = "dynamic_format"; | |||
| constexpr auto kFormatAgnostic = "formatAgnostic"; | |||
| constexpr auto kBroadcast = "broadcast"; | |||
| constexpr auto kReduce = "reduce"; | |||
| constexpr auto kDtypeFormat = "dtype_format"; | |||
| constexpr auto kAttr = "attr"; | |||
| constexpr auto kIputs = "inputs"; | |||
| @@ -95,13 +99,19 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) | |||
| } | |||
| void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) { | |||
| const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern}, | |||
| {kFormatAgnostic, kBroadcastPattern}, | |||
| {kReduce, kReducePattern}, | |||
| {kDynamicFormat, kDynamicFormatPattern}}; | |||
| op_info->set_async_flag(obj.at(kAsyncFlag)); | |||
| op_info->set_binfile_name(obj.at(kBinfileName)); | |||
| op_info->set_compute_cost(obj.at(kComputeCost)); | |||
| op_info->set_kernel_name(obj.at(kKernelName)); | |||
| op_info->set_partial_flag(obj.at(kPartialFlag)); | |||
| if (obj.find(kOpPattern) != obj.end()) { | |||
| op_info->set_op_pattern(obj.at(kOpPattern)); | |||
| if (kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) { | |||
| op_info->set_op_pattern(obj.at(kOpPattern)); | |||
| } | |||
| } | |||
| if (obj.find(kDynamicFormat) != obj.end()) { | |||
| op_info->set_dynamic_format(obj.at(kDynamicFormat)); | |||
| @@ -492,6 +492,7 @@ void SetKernelBuildCommonInfo(const std::shared_ptr<KernelBuildInfo::KernelBuild | |||
| if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) { | |||
| builder->SetFusionType(tbe::GetFusionType(fusion_type)); | |||
| } | |||
| builder->SetOpPattern(op_info_ptr->op_pattern()); | |||
| builder->SetKernelType(TBE_KERNEL); | |||
| } | |||
| @@ -509,7 +510,7 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn | |||
| if (primitive->GetAttr("dyn_input_sizes") != nullptr) { | |||
| dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes")); | |||
| } | |||
| if (inputs.size() > 0) { | |||
| if (!inputs.empty()) { | |||
| MS_EXCEPTION_IF_NULL(inputs[0]); | |||
| size_t kernel_info_cnt = inputs[0]->dtypes().size(); | |||
| for (size_t j = 0; j < kernel_info_cnt; j++) { | |||
| @@ -624,21 +625,17 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| for (auto parse_info : parse_info_list) { | |||
| if (context_ptr->execution_mode() == kPynativeMode) { | |||
| kernel_info_list->push_back(parse_info); | |||
| } else { | |||
| if (IsValidKernelInfo(kernel_node, *(parse_info))) { | |||
| if (CheckSupported(kernel_node, parse_info)) { | |||
| kernel_info_list->push_back(parse_info); | |||
| } else { | |||
| MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info."; | |||
| } | |||
| for (const auto &parse_info : parse_info_list) { | |||
| if (IsValidKernelInfo(kernel_node, *(parse_info))) { | |||
| if (CheckSupported(kernel_node, parse_info)) { | |||
| kernel_info_list->push_back(parse_info); | |||
| } else { | |||
| MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info."; | |||
| } | |||
| } | |||
| } | |||
| if (kernel_info_list->empty()) { | |||
| MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "]."; | |||
| if (kernel_info_list->empty()) { | |||
| MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "]."; | |||
| } | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| @@ -44,6 +44,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info); | |||
| builder->SetKernelType(AICPU_KERNEL); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), node); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" | |||
| << node->DebugString() << "]"; | |||
| @@ -657,6 +657,16 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_ | |||
| to_node->set_abstract(from_node->abstract()); | |||
| } | |||
| kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| // select_kernel_build_info() has checked whether return pointer is null | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| return build_info->op_pattern(); | |||
| } | |||
| // get KernelBuildType of node, such as ATT,RT,FWK and so on | |||
| KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| @@ -138,6 +138,8 @@ class AnfRuntimeAlgorithm { | |||
| static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, | |||
| const std::vector<std::vector<size_t>> &shapes, AnfNode *node); | |||
| static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); | |||
| // get op pattern of the node | |||
| static kernel::OpPattern GetOpPattern(const AnfNodePtr &node); | |||
| // get KernelBuildType of node ,such as ATT,RT,FWK and so on | |||
| static KernelType GetKernelType(const AnfNodePtr &node); | |||
| // get processor type:AICORE,AICPU... | |||
| @@ -142,6 +142,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto"; | |||
| // attr key name | |||
| constexpr auto kAttrInputNames = "input_names"; | |||
| constexpr auto kAttrIsAICPUKernel = "is_ai_cpu_kernel"; | |||
| constexpr auto kIsBackendCast = "is_backed_cast"; | |||
| constexpr auto kAttrOutputNames = "output_names"; | |||
| constexpr auto kAttrVisited = "visited"; | |||
| @@ -215,10 +216,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ"; | |||
| constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; | |||
| constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; | |||
| constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04"; | |||
| const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, | |||
| kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, | |||
| kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, | |||
| kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; | |||
| constexpr auto kOpFormat_NDHWC = "NDHWC"; | |||
| const std::set<std::string> kOpFormatList = { | |||
| kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, | |||
| kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, | |||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC}; | |||
| const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | |||
| const std::set<std::string> kOptOperatorSet = { | |||
| kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName, | |||