From 691b0648e30d4d64912dbd6fe405c23959e925ed Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Tue, 12 May 2020 11:09:21 +0800 Subject: [PATCH] convert unsupported kernel in aicore to aicpu --- mindspore/ccsrc/common/trans.cc | 51 ++- .../device/ascend/kernel_select_ascend.cc | 59 ++- .../device/ascend/kernel_select_ascend.h | 9 +- .../ccsrc/kernel/hccl/hccl_kernel_metadata.cc | 2 +- mindspore/ccsrc/kernel/kernel_query.cc | 46 ++- mindspore/ccsrc/kernel/kernel_query.h | 4 +- .../ccsrc/kernel/tbe/tbe_kernel_select.cc | 5 - .../ascend/ascend_backend_optimization.cc | 2 + .../pre_activate/ascend/ascend_helper.cc | 1 + .../ccsrc/pre_activate/ascend/ascend_helper.h | 13 +- .../convert_unsupported_transnode_to_aicpu.cc | 54 +++ .../convert_unsupported_transnode_to_aicpu.h | 37 ++ .../format_type/insert_cast_for_runop.h | 6 +- .../format_type/insert_transdata_for_runop.h | 6 +- .../ascend/ir_fission/topk_split.cc | 2 +- .../ir_fusion/transpose_transdata_fusion.cc | 2 +- .../ir_fusion/transpose_transdata_fusion.h | 6 +- mindspore/ccsrc/session/ascend_session.cc | 4 +- mindspore/ccsrc/session/kernel_graph.cc | 19 +- mindspore/ccsrc/utils/utils.h | 20 +- .../cpp/device/ascend_kernel_select_test.cc | 345 ------------------ .../ascend/ir_fission/topk_split_test.cc | 2 +- .../transpose_transdata_fusion_test.cc | 42 +-- 23 files changed, 266 insertions(+), 471 deletions(-) create mode 100644 mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc create mode 100644 mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h delete mode 100644 tests/ut/cpp/device/ascend_kernel_select_test.cc diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 225ec05196..954261f912 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -85,7 +85,7 @@ const std::map type_map = {{kNumberTypeBool, 1}, {kNumberType } while (0) template -T Ceil(T n1, T n2) { +T DivCeil(T n1, T 
n2) { return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; } @@ -371,15 +371,48 @@ std::vector C1hwncoc0DeviceShape(const std::vector &shape) { device_shape.push_back(kCubeSize); return device_shape; } + +std::vector FracZc04DeviceShape(const std::vector &shape) { + if (!CheckDims(shape)) { + MS_LOG(EXCEPTION) << "Check dims failed."; + } + std::vector device_shape; + size_t c0 = 4; + size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize); + size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize); + device_shape.push_back(first_dim); + device_shape.push_back(no); + device_shape.push_back(kCubeSize); + device_shape.push_back(kCubeSize); + return device_shape; +} + +std::vector Nc1hwc04DeviceShape(const std::vector &shape) { + if (!CheckDims(shape)) { + MS_LOG(EXCEPTION) << "Check dims failed."; + } + std::vector device_shape; + size_t C1 = 1; + size_t C0 = 4; + device_shape.push_back(shape[0]); + device_shape.push_back(C1); + device_shape.push_back(shape[2]); + device_shape.push_back(shape[3]); + device_shape.push_back(C0); + return device_shape; +} } // namespace std::vector TransShapeToDevice(const std::vector &shape, const std::string &format) { using DeviceShapeTransfer = std::function(const std::vector &)>; - const std::map device_shape_map{ - {kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape}, - {kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape}, - {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, - }; + const std::map device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, + {kOpFormat_NHWC, NhwcDeviceShape}, + {kOpFormat_HWCN, HwchDeviceShape}, + {kOpFormat_FRAC_Z, FracZDeviceShape}, + {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, + {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, + {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, + {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}}; if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { return shape; @@ -506,13 +539,13 @@ bool 
NchwToFracZ(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Illegal dtype."; return false; } - size_t c1 = Ceil(c, c0); + size_t c1 = DivCeil(c, c0); size_t hw = h * w; size_t chw = c * hw; size_t hwc0 = hw * c0; size_t nchw = n * chw; - size_t hf_cnt = Ceil(n, kCubeSize); + size_t hf_cnt = DivCeil(n, kCubeSize); size_t vf_cnt = c1 * hw; size_t fractal_ele_cnt = c0 * kCubeSize; size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; @@ -775,7 +808,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Illegal dtype."; return false; } - size_t c1 = Ceil(c, c0); + size_t c1 = DivCeil(c, c0); size_t hw = h * w; size_t chw = c * hw; size_t c1hwc0 = c1 * hw * c0; diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index d615d261c7..4563512c1d 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -34,6 +34,7 @@ namespace ascend { namespace { const float kWegihtBaseScore = 1; const float kFeatureMapBaseScore = 10; +constexpr auto kPriChoosenFormat = "pri_format"; enum MatchCountPriority : int { MATCH_COUNT_PRIORITY_BEGIN = 0, MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, @@ -85,6 +86,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { if (need_change_nd) { priority_matched_format = kOpFormat_DEFAULT; } + AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); return priority_matched_format; } /** @@ -394,9 +396,9 @@ void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode, std::ostringstream buffer; buffer << cnode->DebugString(); if (precision_reduce) { - buffer << " reduce precision, node datatype: "; + buffer << " reduce precision, node datatype: \n"; } else { - buffer << " raise precision, node datatype: "; + buffer << " raise precision, node datatype: \n"; } PrintInputAndOutputInferType(buffer, cnode); buffer << ", select kernel:" << 
selected_kernel_build_info->ToString(); @@ -464,66 +466,57 @@ std::vector> FilterRaisedOrReducePrecis } } // namespace -std::shared_ptr CanHitKernelInfo( - int *status, const CNodePtr &kernel_node, - const std::vector> &kernel_info_list) { +KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, + const std::vector> &kernel_info_list) { MS_EXCEPTION_IF_NULL(kernel_node); + KernelSelectStatus select_status = kNoMatched; bool precision_reduce = false; std::shared_ptr selected_kernel_info = nullptr; + // Matched kernel info + // Filter kernel info matched with me infered type auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list); if (!filtered_kernel_info_list.empty()) { selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); + select_status = kStatusAllMatched; } else { // selected kernel info using raised precision or reduce precision filtered_kernel_info_list = FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce); selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); if (selected_kernel_info == nullptr) { - return nullptr; + return select_status; } else { PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); - *status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; + select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; } } - return selected_kernel_info; + // Set kernel info to the anfnode + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); + // Set format and data type for input tensor. 
+ SetTensorDeviceInfo(*selected_kernel_info, kernel_node); + return select_status; } -int SelectKernelInfo(const CNodePtr &kernel_node) { +KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { std::vector> kernel_info_list; - int status = kStatusAllMatched; MS_EXCEPTION_IF_NULL(kernel_node); kernel::KernelQuery(kernel_node, &kernel_info_list); - // filter kernel info matched with me infered type - auto selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list); - if (selected_kernel_info == nullptr) { + auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); + // If aicore not find valid kernel info reloading aicpu kernel info list to find it + if (select_status == kNoMatched) { MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; - kernel::AicpuQuery(kernel_node, &kernel_info_list); - selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list); + kernel::AICpuQuery(kernel_node, &kernel_info_list); + select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); } - if (selected_kernel_info == nullptr) { + // The kernel info not finded both in the aicpu kernel list & aicore kernel list + if (select_status == kNoMatched) { std::ostringstream buffer; PrintInputAndOutputInferType(buffer, kernel_node); MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() << "] cannot find valid kernel info, not supported the type " << buffer.str(); } - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); - // Set format and data type for input tensor. 
- SetTensorDeviceInfo(*selected_kernel_info, kernel_node); - return status; -} - -bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, - const kernel::KernelBuildInfoPtr &new_kernel_build_info) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector> kernel_info_list; - kernel::KernelQuery(kernel_node, &kernel_info_list); - auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(), - [&new_kernel_build_info](const kernel::KernelBuildInfoPtr item) { - MS_EXCEPTION_IF_NULL(item); - return *item == *new_kernel_build_info; - }); - return result != kernel_info_list.end(); + return select_status; } } // namespace ascend } // namespace device diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.h b/mindspore/ccsrc/device/ascend/kernel_select_ascend.h index af353815bf..c4c777c18a 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.h +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.h @@ -21,8 +21,13 @@ namespace mindspore { namespace device { namespace ascend { -int SelectKernelInfo(const CNodePtr &kernel_node); -bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, const kernel::KernelBuildInfoPtr &new_kernel_build_info); +enum KernelSelectStatus { + kNoMatched = -1, + kStatusAllMatched = 0, + kStatusReducePrecision = 1, + kStatusRaisePrecision = 2, +}; +KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node); } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc index dc26045374..6c101c92bb 100755 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc +++ b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc @@ -35,7 +35,7 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector input_format, output_format; std::vector input_type, output_type; for (const auto &data_type : data_type_list) { - for (const auto &format : k4DSupportFormat) { + for (const auto 
&format : kOpFormatList) { auto builder = std::make_shared(); input_format.clear(); input_format.push_back(format); diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc index a2a5958a3f..8cdf91fd9f 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -35,14 +35,18 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); }); - kernel_info_list->clear(); if (!filtered_list.empty()) { + kernel_info_list->clear(); (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); } else { - MS_LOG(EXCEPTION) << "node" << kernel_node->DebugString() << "'s output size : [" - << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" - << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) - << "] cannot match any kernelInfo !"; + MS_LOG(WARNING) << "All kernel Info list does not match any kernel info "; + for (size_t index = 0; index < kernel_info_list->size(); ++index) { + MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); + MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); + } + MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : [" + << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" + << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !"; } } } // namespace @@ -50,7 +54,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vectorempty()) { AicpuMetadataInfo(kernel_node, kernel_info_list); } @@ -68,12 +71,41 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { +void AICpuQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_info_list); kernel_info_list->clear(); 
AicpuMetadataInfo(kernel_node, kernel_info_list); FilterInvalidKernelInfo(kernel_node, kernel_info_list); } +bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(select_kernel_build_info); + std::vector> kernel_info_list; + auto cnode = kernel_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + AicpuMetadataInfo(cnode, &kernel_info_list); + FilterInvalidKernelInfo(cnode, &kernel_info_list); + return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), + [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { + MS_EXCEPTION_IF_NULL(item); + return *item == *select_kernel_build_info; + }); +} + +bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(select_kernel_build_info); + std::vector> kernel_info_list; + auto cnode = kernel_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + TbeMetadataInfo(cnode, &kernel_info_list); + FilterInvalidKernelInfo(cnode, &kernel_info_list); + return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), + [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { + MS_EXCEPTION_IF_NULL(item); + return *item == *select_kernel_build_info; + }); +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_query.h b/mindspore/ccsrc/kernel/kernel_query.h index 3e16b6b612..52ab018988 100644 --- a/mindspore/ccsrc/kernel/kernel_query.h +++ b/mindspore/ccsrc/kernel/kernel_query.h @@ -26,7 +26,9 @@ namespace mindspore { namespace kernel { void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -void AicpuQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +void AICpuQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr 
&select_kernel_build_info); +bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc index 9aa5784966..33743b3175 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc @@ -551,11 +551,6 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr &shape, const std::string &format) { - const std::set kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, - kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, - kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, - kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; - // if format is default, it remarkes support all format if (kOpFormatList.find(format) == kOpFormatList.end()) { MS_LOG(EXCEPTION) << "Got the unknown format " << format; diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 3b36ceef09..356926e2ff 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -54,6 +54,7 @@ #include "pre_activate/pass/optimize_dependence.h" #include "pre_activate/pass/erase_visit_attr.h" #include "pre_activate/ascend/format_type/insert_cast.h" +#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" #include "pre_activate/pass/eliminate_redundant_op.h" #include "pre_activate/pass/common_subexpression_elimination.h" #include "pre_activate/ascend/format_type/merge_cast_to_op.h" @@ -172,6 +173,7 @@ void AscendMixPrecision(const std::shared_ptr &kernel_grap mixed_precision_pm->AddPass(std::make_shared()); 
mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); optimizer->AddPassManager(mixed_precision_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index 9fdb2080b3..7e503ef349 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -268,6 +268,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr } AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); + AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); return cast; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index 1840966358..2f270b109b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -30,10 +30,6 @@ class KernelSelect { KernelSelect() = default; virtual ~KernelSelect() = default; virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); } - virtual bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, - const kernel::KernelBuildInfoPtr &new_kernel_build_info) { - return device::ascend::CheckKernelAccuracySupported(kernel_node, new_kernel_build_info); - } }; using KernelSelectPtr = std::shared_ptr; @@ -41,8 +37,13 @@ class SupportedChecker { public: SupportedChecker() = default; virtual ~SupportedChecker() = default; - virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { - return kernel::CheckSupported(anf_node, select_kernel_build_info); + virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node, + const 
kernel::KernelBuildInfoPtr &select_kernel_build_info) { + return kernel::IsSupportedByAiCore(anf_node, select_kernel_build_info); + } + virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node, + const kernel::KernelBuildInfoPtr &select_kernel_build_info) { + return kernel::IsSupportedByAiCpu(anf_node, select_kernel_build_info); } }; using SupportedCheckerPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc new file mode 100644 index 0000000000..120462fd53 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" +#include +#include "session/anf_runtime_algorithm.h" +#include "kernel/kernel_build_info.h" +#include "kernel/kernel_query.h" +namespace mindspore { +namespace opt { +const BaseRef ConvertUnSupportNodeToAICPU::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({X, Xs}); +} + +const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraphPtr &, + const mindspore::AnfNodePtr &node, + const mindspore::EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto node_name = AnfAlgo::GetCNodeName(node); + if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) { + return nullptr; + } + auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); + if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) { + return node; + } else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) { + auto builder = std::make_shared(kernel_builder_info); + builder->SetKernelType(AICPU_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + } else { + MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" + << node->DebugString() << "]"; + } + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h new file mode 100644 index 0000000000..80cc8170ac --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance 
with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "pre_activate/common/optimizer.h" +#include "pre_activate/ascend/ascend_helper.h" +#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H +#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H +namespace mindspore { +namespace opt { +class ConvertUnSupportNodeToAICPU : public PatternProcessPass { + public: + explicit ConvertUnSupportNodeToAICPU(bool multigraph = true) + : PatternProcessPass("convert_unsupported_node_to_aicpu", multigraph), + supported_checker_(std::make_shared()) {} + ~ConvertUnSupportNodeToAICPU() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + SupportedCheckerPtr supported_checker_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h index 8bc42eb26a..4467cc5198 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_ -#define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ #include #include "pre_activate/common/optimizer.h" @@ -32,4 +32,4 @@ class RunOpInsertCast : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_ +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h index 298a3deda9..f699cdd580 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_ -#define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ #include #include @@ -41,4 +41,4 @@ class RunOpInsertTransData : public PatternProcessPass { } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_ +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc index 4bdd5f0382..64f5ba0cf6 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc @@ -128,7 +128,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod auto indices_const = CreateValueNode(new_cnode); new_cnode->add_input(indices_const); MS_EXCEPTION_IF_NULL(supported_checker_); - if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) { + if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { return nullptr; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc index 1386005d1b..1651718703 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc @@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); auto new_fusion_transdata = 
std::make_shared(kTransDataOpName); - if (kernel_select_->CheckKernelAccuracySupported(transdata_cnode, new_transdata_builder->Build())) { + if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) { std::vector inputs = {NewValueNode(new_fusion_transdata), utils::cast((*equiv)[input_varptr_])}; auto new_node = func_graph->NewCNode(inputs); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h index bc9f17f340..833588cf45 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h @@ -34,7 +34,7 @@ class TransposeTransDataFusion : public PatternProcessPass { explicit TransposeTransDataFusion(bool multigraph = true) : PatternProcessPass("transpose_transdata_fusion", multigraph) { input_varptr_ = std::make_shared(); - kernel_select_ = std::make_shared(); + supported_checker_ = std::make_shared(); } ~TransposeTransDataFusion() override = default; const BaseRef DefinePattern() const override; @@ -42,7 +42,9 @@ class TransposeTransDataFusion : public PatternProcessPass { private: VarPtr input_varptr_; - KernelSelectPtr kernel_select_; + + private: + SupportedCheckerPtr supported_checker_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 1311cb465e..deec2c648a 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -329,9 +329,9 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { size_t reduce_precision_count = 0; for (const auto &cnode : kernel_graph.execution_order()) { auto status = device::ascend::SelectKernelInfo(cnode); - if (status == kStatusRaisePrecision) { + if (status == device::ascend::kStatusRaisePrecision) { 
raise_precision_count++; - } else if (status == kStatusReducePrecision) { + } else if (status == device::ascend::kStatusReducePrecision) { reduce_precision_count++; } MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 95ac38c405..24b30b233b 100755 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -27,6 +27,8 @@ namespace mindspore { namespace session { namespace { +constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; +constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que, std::unordered_set *visited_nodes) { MS_EXCEPTION_IF_NULL(que); @@ -180,11 +182,24 @@ CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { cnode->set_abstract(std::make_shared()); // create kernel_info from new parameter auto kernel_info = std::make_shared(); + std::vector feature_map_input_indexs; // if the node only has the primitive(such as getNext) or the node's input has a feature map input // then the node's output is a feature map output - if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(), - [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) { + for (size_t index = 1; index < inputs.size(); ++index) { + auto node = inputs[index]; + if (AnfAlgo::IsFeatureMapOutput(node)) { + feature_map_input_indexs.push_back(index); + } + } + if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { + AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); + } + if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { kernel_info->SetFeatureMapFlag(true); + AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(true), cnode); + AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); + } else { + AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(false), cnode); } 
cnode->set_kernel_info(kernel_info); AnfAlgo::SetGraphId(graph_id_, cnode.get()); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index d5313247d2..34826accf0 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -139,6 +139,7 @@ constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2D // attr key name constexpr auto kAttrInputNames = "input_names"; +constexpr auto kIsBackendCast = "is_backed_cast"; constexpr auto kAttrOutputNames = "output_names"; constexpr auto kAttrVisited = "visited"; constexpr auto kAttrShape = "shape"; @@ -196,10 +197,6 @@ constexpr auto kControlDependBehindIndex = 2; // index define of depend constexpr auto kRealInputIndexInDepend = 1; constexpr auto kDependAttachNodeIndex = 2; -// status of kernel select result -const int kStatusReducePrecision = -1; -const int kStatusRaisePrecision = 1; -const int kStatusAllMatched = 0; // format constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0"; @@ -213,18 +210,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ"; constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04"; -const std::set k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, - kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, - kOpFormat_C1HWNCoC0}; - -const std::set k2DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z, - kOpFormat_NC1KHKWHWC0}; -const std::set k3DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0}; -const std::set k4DSupportFormat = k1DSupportFormat; -const std::vector> kShapeSupportFormatMap = {k1DSupportFormat, k2DSupportFormat, k3DSupportFormat, - k4DSupportFormat}; +const std::set kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, + kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, + 
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, + kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; const std::set kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; - const std::set kOptOperatorSet = { kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName, kApplyAdagradOpName, kApplyAdagradDAName, kApplyAdamOpName, diff --git a/tests/ut/cpp/device/ascend_kernel_select_test.cc b/tests/ut/cpp/device/ascend_kernel_select_test.cc deleted file mode 100644 index 79986d375d..0000000000 --- a/tests/ut/cpp/device/ascend_kernel_select_test.cc +++ /dev/null @@ -1,345 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "mindspore/ccsrc/device/ascend/kernel_select_ascend.h" -#include "common/common_test.h" -#include "session/kernel_graph.h" -#include "kernel/kernel.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "mindspore/ccsrc/device/kernel_info.h" -#include "mindspore/ccsrc/kernel/kernel_build_info.h" -#include -namespace mindspore { -namespace device { -namespace ascend { -namespace { -using KernelInfo = device::KernelInfo; -using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; -using KernelBuildInfo = kernel::KernelBuildInfo; -using KernelGraph = session::KernelGraph; -using KernelBuildInfoPtr = std::shared_ptr; -using KernelBuilderPtr = std::shared_ptr; -using Shape = std::vector; -using ShapeList = std::vector; -enum MatchCountPriority { - MATCH_COUNT_PRIORITY_BEGIN = 0, - MATCH_FORMAT_COUNT = MATCH_COUNT_PRIORITY_BEGIN, - MATCH_DTYPE_COUNT, - MATCH_NZ_FORMAT_COUNT, - MATCH_5D_FORMAT_COUNT, - MATCH_OUTPUT_DTYPE_COUNT, - MATCH_COUNT_PRIORITY_END -}; - -const std::set kOpFormatList = { - kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, - kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ}; - -bool IsShapeMatchFormat(const std::vector &shape, const std::string &format) { - // if format is default,it remarkes support all format - if (kOpFormatList.find(format) == kOpFormatList.end()) { - MS_EXCEPTION(ArgumentError) << "got the unknow format " << format; - } - if (format == kOpFormat_DEFAULT) { - return true; - } - // if shape size is 0,the shape will be a scalar - if (shape.empty()) { - return true; - } - if (shape.size() > kShapeSupportFormatMap.size()) { - return false; - } - if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) { - return shape[shape.size() - 1] % 16 != 0 && shape[shape.size() - 2] % 16 != 0; - } - return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == 
kShapeSupportFormatMap[shape.size() - 1].end()); -} - -bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { - MS_EXCEPTION_IF_NULL(kernel_node); - auto check_function = [](const std::vector &shape, const std::string &format) -> bool { - if (!IsShapeMatchFormat(shape, format)) { - return false; - } - for (auto shape_value : shape) { - if (shape_value == 0) { - MS_EXCEPTION(ArgumentError) << "dimension size of the tensor shape should be a positive integer, but got [" - << shape_value << "]"; - } - } - return true; - }; - for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); - if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) { - return false; - } - } - for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); - if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) { - return false; - } - } - return true; -} - -bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { - MS_EXCEPTION_IF_NULL(cnode); - // Check input data type - for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { - AnfNodePtr cur_input = cnode->input(input_index + 1); - MS_EXCEPTION_IF_NULL(cur_input); - TypeId input_origin_type; - if (cur_input->isa() && AnfAlgo::IsParameterWeight(cur_input->cast())) { - // weight - input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0); - } else { - // feature map - input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); - } - if (input_origin_type == kTypeUnknown) { - continue; - } - if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { - return false; - } - } - // Check output data type - for (size_t output_index = 0; 
output_index < kernel_build_info.GetOutputNum(); ++output_index) { - if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) { - return false; - } - } - return true; -} - -/** - * compare too vector by priority,select a better vector,like compare too num,first compare highest num location,if - * equal then next num location - * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3] - */ -bool PriorityChooseItem(const std::vector &cur_item, std::vector *best_item) { - MS_EXCEPTION_IF_NULL(best_item); - if (cur_item.size() != best_item->size()) { - MS_LOG(ERROR) << "item size should be same!"; - return false; - } - // Update the best_item by comparing the cur_item and best_item - for (size_t i = 0; i < cur_item.size(); i++) { - if (cur_item[i] > best_item->at(i)) { - *best_item = cur_item; - return true; - } else if (cur_item[i] == best_item->at(i)) { - continue; - } else { - return false; - } - } - return false; -} - -void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr &kernel_node, - std::vector *const cur_kernelinfo_match_counts) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts); - if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) { - MS_EXCEPTION(ArgumentError) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END; - } - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - AnfNodePtr input_anf_node = kernel_node->input(input_index + 1); - MS_EXCEPTION_IF_NULL(input_anf_node); - // if a input parameter is a weight with default format, the input shouldn't participate the judge - if (input_anf_node->isa()) { - auto para = input_anf_node->cast(); - if (AnfAlgo::IsParameterWeight(para) && AnfAlgo::GetOutputDeviceDataType(para, 0) == kTypeUnknown) { - continue; - } - } - if (kernel_build_info.GetInputFormat(input_index) == 
AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { - (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++; - } - if (kernel_build_info.GetInputDeviceType(input_index) == - AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) { - (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT]++; - } - if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_FRAC_NZ) { - (*cur_kernelinfo_match_counts)[MATCH_NZ_FORMAT_COUNT]++; - } - if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_NC1HWC0) { - (*cur_kernelinfo_match_counts)[MATCH_5D_FORMAT_COUNT]++; - } - } - - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { - // cal count of same output dtype between abstract and kernel info - if (kernel_build_info.GetOutputDeviceType(output_index) == - AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) { - (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++; - } - } -} - -void SetKernelBuildInfo(KernelBuilderPtr builder) { - builder->SetFusionType(kernel::OPAQUE); - builder->SetKernelType(AUTO_DIFF_KERNEL); - builder->SetProcessor(kernel::AICORE); -} - -void test_select(const CNodePtr &kernel_node, std::vector> kernel_info_list) { - std::vector most_match_counts = {-1, -1, -1, -1, -1}; - int selected_index = -1; - for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { - std::vector cur_kernel_info_match_counts = {0, 0, 0, 0, 0}; - if (!IsValidKernelInfo(kernel_node, *(kernel_info_list[info_index]))) { - continue; - } - if (!MatchInferOutputDataType(kernel_node, *(kernel_info_list[info_index]))) { - continue; - } - std::shared_ptr kernel_info_ptr = kernel_info_list[info_index]; - UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); - // Currently the selection policy is the match format count first, and then is datatype counts. 
- if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) { - selected_index = SizeToInt(info_index); - } - } - if (selected_index == -1) { - MS_EXCEPTION(NotExistsError) << "" << kernel_node->DebugString() << " Cannot find valid kernel Info !"; - } - auto index = IntToSize(selected_index); - if (index >= kernel_info_list.size()) { - MS_EXCEPTION(ArgumentError) << "index outof range"; - } - std::shared_ptr selected_kernel_info_ptr = kernel_info_list[index]; - MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, kernel_node.get()); -} - -void SetParentAbstract(std::vector parent_list, std::vector> shapes, - std::vector types) { - for (const auto &node : parent_list) { - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, node.get()); - } -} -} // namespace -class AscendKernelSelctTest : public UT::Common { - public: - AscendKernelSelctTest() = default; - void SetUp() override {} - void TearDown() override {} -}; - -TEST_F(AscendKernelSelctTest, TestSelect) { - std::vector build_list; - std::vector type_list = {kNumberTypeFloat32}; - for (size_t i = 0; i <= 4; ++i) { - build_list.push_back(std::make_shared()); - SetKernelBuildInfo(build_list[i]); - build_list[i]->SetInputsDeviceType(type_list); - build_list[i]->SetOutputsDeviceType(type_list); - } - - std::vector nd_fmt = {kOpFormat_DEFAULT}; - std::vector nz_fmt = {kOpFormat_FRAC_NZ}; - auto anf_graph = std::make_shared(); - - // 16's multiple should not chose format NZ - Shape nd_shapes = {2, 32, 224, 224}; - - Shape nz_shapes = {3, 3, 5, 5}; - auto add_value = NewValueNode(prim::kPrimTensorAdd); - auto a_node = anf_graph->NewCNode(std::vector{add_value}); - auto b_node = anf_graph->NewCNode(std::vector{add_value}); - std::vector parent_list = {add_value, a_node, b_node}; - - auto c_node = anf_graph->NewCNode(parent_list); - - // a b - // \ / - // c - // a & b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}} - // 
infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} - // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3,224, 224}} - - // set a & b's info - SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list); - // set abstract c - AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nd_shapes}, c_node.get()); - // set format of kernel info - build_list[0]->SetOutputsFormat(nz_fmt); - build_list[1]->SetOutputsFormat(nz_fmt); - - build_list[2]->SetInputsFormat(std::vector{nd_fmt[0], nd_fmt[0]}); - build_list[3]->SetInputsFormat(std::vector{nz_fmt[0], nz_fmt[0]}); - build_list[2]->SetInputsDeviceType(std::vector{kNumberTypeFloat32, kNumberTypeFloat32}); - build_list[3]->SetInputsDeviceType(std::vector{kNumberTypeFloat32, kNumberTypeFloat32}); - build_list[2]->SetOutputsFormat(nd_fmt); - build_list[3]->SetOutputsFormat(nz_fmt); - std::vector select_info_list; - // set select info list - select_info_list.emplace_back(build_list[2]->Build()); - select_info_list.emplace_back(build_list[3]->Build()); - - // set device info for a & b - AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get()); - AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get()); - - test_select(c_node, select_info_list); - EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT); - EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_DEFAULT); - - // set a & b's info - // a b - // \ / - // c - // a: kernel_info:{output_format:{5d},dtype:{kNumberTypeFloat32}} - // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} - // b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}} - // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} - // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} - - // set a & b's info - SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list); - // set abstract c - AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nz_shapes}, c_node.get()); - // set format of kernel info - 
build_list[0]->SetOutputsFormat(std::vector{kOpFormat_NC1HWC0}); - build_list[1]->SetOutputsFormat(nz_fmt); - - build_list[2]->SetInputsFormat(std::vector{kOpFormat_NC1HWC0, nd_fmt[0]}); - build_list[3]->SetInputsFormat(std::vector{nd_fmt[0], nz_fmt[0]}); - build_list[2]->SetInputsDeviceType(std::vector{kNumberTypeFloat32, kNumberTypeFloat32}); - build_list[3]->SetInputsDeviceType(std::vector{kNumberTypeFloat32, kNumberTypeFloat32}); - build_list[2]->SetOutputsFormat(nd_fmt); - build_list[3]->SetOutputsFormat(nz_fmt); - // set select info list - select_info_list.emplace_back(build_list[2]->Build()); - select_info_list.emplace_back(build_list[3]->Build()); - - // set device info for a & b - AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get()); - AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get()); - - test_select(c_node, select_info_list); - EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT); - EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_FRAC_NZ); -} -} // namespace ascend -} // namespace device -} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc index 43ddc046b7..1a3440e780 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc @@ -39,7 +39,7 @@ class MockSupportedChecker : public SupportedChecker { public: MockSupportedChecker() = default; ~MockSupportedChecker() override = default; - bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { + bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { return true; } }; // namespace opt diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc 
b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc index af59ef7e9a..25cd12edfe 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc @@ -37,6 +37,15 @@ class TestHWTransposeTransdataFusion : public BackendCommon { UT::PyFuncGraphFetcher get_py_fun_; }; +class MockSupportedChecker : public SupportedChecker { + public: + MockSupportedChecker() = default; + ~MockSupportedChecker() override = default; + bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { + return true; + } +}; + class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { public: MockInsertTransOpKernelSelectTrans4Dto5D() = default; @@ -60,37 +69,6 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { } }; -class MockTransposeTransdataFusionKernelSelect : public KernelSelect { - public: - MockTransposeTransdataFusionKernelSelect() = default; - ~MockTransposeTransdataFusionKernelSelect() override = default; - bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, - const kernel::KernelBuildInfoPtr &new_kernel_build_info) override { - std::vector> kernel_info_list; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetInputsFormat({kOpFormat_NCHW}); - builder.SetOutputsFormat({kOpFormat_DEFAULT}); - builder.SetInputsDeviceType({kNumberTypeFloat16}); - builder.SetOutputsDeviceType({kNumberTypeFloat16}); - builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); - builder.SetFusionType(kernel::FusionType::OPAQUE); - builder.SetProcessor(kernel::Processor::AICORE); - kernel_info_list.push_back(builder.Build()); - MS_LOG(INFO) << "transpose transdata fusion success"; - MS_LOG(INFO) << "new transdata build info input format:" << new_kernel_build_info->GetInputFormat(0) - << ",outputformat:" << new_kernel_build_info->GetOutputFormat(0) - 
<< ",kerneltype:" << new_kernel_build_info->kernel_type() - << ",fusiontype:" << new_kernel_build_info->fusion_type() - << ",process:" << new_kernel_build_info->processor(); - auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(), - [&new_kernel_build_info](kernel::KernelBuildInfoPtr item) { - MS_EXCEPTION_IF_NULL(item); - return *item == *new_kernel_build_info; - }); - return result != kernel_info_list.end(); - } -}; - TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { /* * def before(input0, input1): @@ -128,7 +106,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { insert_trans_op_pass->kernel_select_ = std::make_shared(); pm->AddPass(insert_trans_op_pass); auto transpose_transdata_pass = std::make_shared(); - transpose_transdata_pass->kernel_select_ = std::make_shared(); + transpose_transdata_pass->supported_checker_ = std::make_shared(); pm->AddPass(transpose_transdata_pass); optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(kg);