Merge pull request !1079 from lianliguang/convert-to-AICPU-when-AiCore-not-supported-kerneltags/v0.3.0-alpha
| @@ -85,7 +85,7 @@ const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberType | |||
| } while (0) | |||
| template <typename T> | |||
| T Ceil(T n1, T n2) { | |||
| T DivCeil(T n1, T n2) { | |||
| return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | |||
| } | |||
| @@ -371,15 +371,48 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) { | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| } | |||
| std::vector<size_t> device_shape; | |||
| size_t c0 = 4; | |||
| size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize); | |||
| size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize); | |||
| device_shape.push_back(first_dim); | |||
| device_shape.push_back(no); | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| } | |||
| std::vector<size_t> device_shape; | |||
| size_t C1 = 1; | |||
| size_t C0 = 4; | |||
| device_shape.push_back(shape[0]); | |||
| device_shape.push_back(C1); | |||
| device_shape.push_back(shape[2]); | |||
| device_shape.push_back(shape[3]); | |||
| device_shape.push_back(C0); | |||
| return device_shape; | |||
| } | |||
| } // namespace | |||
| std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) { | |||
| using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>; | |||
| const std::map<std::string, DeviceShapeTransfer> device_shape_map{ | |||
| {kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape}, | |||
| {kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape}, | |||
| {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, | |||
| }; | |||
| const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, | |||
| {kOpFormat_NHWC, NhwcDeviceShape}, | |||
| {kOpFormat_HWCN, HwchDeviceShape}, | |||
| {kOpFormat_FRAC_Z, FracZDeviceShape}, | |||
| {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, | |||
| {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, | |||
| {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, | |||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}}; | |||
| if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { | |||
| return shape; | |||
| @@ -506,13 +539,13 @@ bool NchwToFracZ(const FormatArgs &args, void *result) { | |||
| MS_LOG(ERROR) << "Illegal dtype."; | |||
| return false; | |||
| } | |||
| size_t c1 = Ceil(c, c0); | |||
| size_t c1 = DivCeil(c, c0); | |||
| size_t hw = h * w; | |||
| size_t chw = c * hw; | |||
| size_t hwc0 = hw * c0; | |||
| size_t nchw = n * chw; | |||
| size_t hf_cnt = Ceil(n, kCubeSize); | |||
| size_t hf_cnt = DivCeil(n, kCubeSize); | |||
| size_t vf_cnt = c1 * hw; | |||
| size_t fractal_ele_cnt = c0 * kCubeSize; | |||
| size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | |||
| @@ -775,7 +808,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { | |||
| MS_LOG(ERROR) << "Illegal dtype."; | |||
| return false; | |||
| } | |||
| size_t c1 = Ceil(c, c0); | |||
| size_t c1 = DivCeil(c, c0); | |||
| size_t hw = h * w; | |||
| size_t chw = c * hw; | |||
| size_t c1hwc0 = c1 * hw * c0; | |||
| @@ -34,6 +34,7 @@ namespace ascend { | |||
| namespace { | |||
| const float kWegihtBaseScore = 1; | |||
| const float kFeatureMapBaseScore = 10; | |||
| constexpr auto kPriChoosenFormat = "pri_format"; | |||
| enum MatchCountPriority : int { | |||
| MATCH_COUNT_PRIORITY_BEGIN = 0, | |||
| MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, | |||
| @@ -85,6 +86,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { | |||
| if (need_change_nd) { | |||
| priority_matched_format = kOpFormat_DEFAULT; | |||
| } | |||
| AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); | |||
| return priority_matched_format; | |||
| } | |||
| /** | |||
| @@ -394,9 +396,9 @@ void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode, | |||
| std::ostringstream buffer; | |||
| buffer << cnode->DebugString(); | |||
| if (precision_reduce) { | |||
| buffer << " reduce precision, node datatype: "; | |||
| buffer << " reduce precision, node datatype: \n"; | |||
| } else { | |||
| buffer << " raise precision, node datatype: "; | |||
| buffer << " raise precision, node datatype: \n"; | |||
| } | |||
| PrintInputAndOutputInferType(buffer, cnode); | |||
| buffer << ", select kernel:" << selected_kernel_build_info->ToString(); | |||
| @@ -464,66 +466,57 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis | |||
| } | |||
| } // namespace | |||
| std::shared_ptr<kernel::KernelBuildInfo> CanHitKernelInfo( | |||
| int *status, const CNodePtr &kernel_node, | |||
| const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) { | |||
| KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, | |||
| const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| KernelSelectStatus select_status = kNoMatched; | |||
| bool precision_reduce = false; | |||
| std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr; | |||
| // Matched kernel info | |||
| // Filter kernel info matched with me infered type | |||
| auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list); | |||
| if (!filtered_kernel_info_list.empty()) { | |||
| selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); | |||
| select_status = kStatusAllMatched; | |||
| } else { | |||
| // selected kernel info using raised precision or reduce precision | |||
| filtered_kernel_info_list = | |||
| FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce); | |||
| selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); | |||
| if (selected_kernel_info == nullptr) { | |||
| return nullptr; | |||
| return select_status; | |||
| } else { | |||
| PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); | |||
| *status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; | |||
| select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; | |||
| } | |||
| } | |||
| return selected_kernel_info; | |||
| // Set kernel info to the anfnode | |||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); | |||
| // Set format and data type for input tensor. | |||
| SetTensorDeviceInfo(*selected_kernel_info, kernel_node); | |||
| return select_status; | |||
| } | |||
| int SelectKernelInfo(const CNodePtr &kernel_node) { | |||
| KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||
| int status = kStatusAllMatched; | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel::KernelQuery(kernel_node, &kernel_info_list); | |||
| // filter kernel info matched with me infered type | |||
| auto selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list); | |||
| if (selected_kernel_info == nullptr) { | |||
| auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); | |||
| // If aicore not find valid kernel info reloading aicpu kernel info list to find it | |||
| if (select_status == kNoMatched) { | |||
| MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() | |||
| << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; | |||
| kernel::AicpuQuery(kernel_node, &kernel_info_list); | |||
| selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list); | |||
| kernel::AICpuQuery(kernel_node, &kernel_info_list); | |||
| select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); | |||
| } | |||
| if (selected_kernel_info == nullptr) { | |||
| // The kernel info not finded both in the aicpu kernel list & aicore kernel list | |||
| if (select_status == kNoMatched) { | |||
| std::ostringstream buffer; | |||
| PrintInputAndOutputInferType(buffer, kernel_node); | |||
| MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() | |||
| << "] cannot find valid kernel info, not supported the type " << buffer.str(); | |||
| } | |||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); | |||
| // Set format and data type for input tensor. | |||
| SetTensorDeviceInfo(*selected_kernel_info, kernel_node); | |||
| return status; | |||
| } | |||
| bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, | |||
| const kernel::KernelBuildInfoPtr &new_kernel_build_info) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||
| kernel::KernelQuery(kernel_node, &kernel_info_list); | |||
| auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(), | |||
| [&new_kernel_build_info](const kernel::KernelBuildInfoPtr item) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| return *item == *new_kernel_build_info; | |||
| }); | |||
| return result != kernel_info_list.end(); | |||
| return select_status; | |||
| } | |||
| } // namespace ascend | |||
| } // namespace device | |||
| @@ -21,8 +21,13 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| int SelectKernelInfo(const CNodePtr &kernel_node); | |||
| bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, const kernel::KernelBuildInfoPtr &new_kernel_build_info); | |||
| enum KernelSelectStatus { | |||
| kNoMatched = -1, | |||
| kStatusAllMatched = 0, | |||
| kStatusReducePrecision = 1, | |||
| kStatusRaisePrecision = 2, | |||
| }; | |||
| KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node); | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -35,7 +35,7 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K | |||
| std::vector<std::string> input_format, output_format; | |||
| std::vector<TypeId> input_type, output_type; | |||
| for (const auto &data_type : data_type_list) { | |||
| for (const auto &format : k4DSupportFormat) { | |||
| for (const auto &format : kOpFormatList) { | |||
| auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| input_format.clear(); | |||
| input_format.push_back(format); | |||
| @@ -35,14 +35,18 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, | |||
| return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && | |||
| AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); | |||
| }); | |||
| kernel_info_list->clear(); | |||
| if (!filtered_list.empty()) { | |||
| kernel_info_list->clear(); | |||
| (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "node" << kernel_node->DebugString() << "'s output size : [" | |||
| << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" | |||
| << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) | |||
| << "] cannot match any kernelInfo !"; | |||
| MS_LOG(WARNING) << "All kernel Info list does not match any kernel info "; | |||
| for (size_t index; index < kernel_info_list->size(); ++index) { | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); | |||
| MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); | |||
| } | |||
| MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : [" | |||
| << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" | |||
| << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !"; | |||
| } | |||
| } | |||
| } // namespace | |||
| @@ -50,7 +54,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| TbeMetadataInfo(kernel_node, kernel_info_list); | |||
| if (kernel_info_list->empty()) { | |||
| AicpuMetadataInfo(kernel_node, kernel_info_list); | |||
| } | |||
| @@ -68,12 +71,41 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel | |||
| FilterInvalidKernelInfo(kernel_node, kernel_info_list); | |||
| } | |||
| void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||
| void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| kernel_info_list->clear(); | |||
| AicpuMetadataInfo(kernel_node, kernel_info_list); | |||
| FilterInvalidKernelInfo(kernel_node, kernel_info_list); | |||
| } | |||
| bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(select_kernel_build_info); | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||
| auto cnode = kernel_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| AicpuMetadataInfo(cnode, &kernel_info_list); | |||
| FilterInvalidKernelInfo(cnode, &kernel_info_list); | |||
| return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), | |||
| [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| return *item == *select_kernel_build_info; | |||
| }); | |||
| } | |||
| bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(select_kernel_build_info); | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||
| auto cnode = kernel_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| TbeMetadataInfo(cnode, &kernel_info_list); | |||
| FilterInvalidKernelInfo(cnode, &kernel_info_list); | |||
| return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), | |||
| [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| return *item == *select_kernel_build_info; | |||
| }); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -26,7 +26,9 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | |||
| void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | |||
| void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | |||
| bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); | |||
| bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ | |||
| @@ -551,11 +551,6 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn | |||
| } | |||
| bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) { | |||
| const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, | |||
| kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, | |||
| kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, | |||
| kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; | |||
| // if format is default, it remarkes support all format | |||
| if (kOpFormatList.find(format) == kOpFormatList.end()) { | |||
| MS_LOG(EXCEPTION) << "Got the unknown format " << format; | |||
| @@ -54,6 +54,7 @@ | |||
| #include "pre_activate/pass/optimize_dependence.h" | |||
| #include "pre_activate/pass/erase_visit_attr.h" | |||
| #include "pre_activate/ascend/format_type/insert_cast.h" | |||
| #include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" | |||
| #include "pre_activate/pass/eliminate_redundant_op.h" | |||
| #include "pre_activate/pass/common_subexpression_elimination.h" | |||
| #include "pre_activate/ascend/format_type/merge_cast_to_op.h" | |||
| @@ -172,6 +173,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap | |||
| mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>()); | |||
| optimizer->AddPassManager(mixed_precision_pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| @@ -268,6 +268,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr | |||
| } | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); | |||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); | |||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); | |||
| return cast; | |||
| } | |||
| @@ -30,10 +30,6 @@ class KernelSelect { | |||
| KernelSelect() = default; | |||
| virtual ~KernelSelect() = default; | |||
| virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); } | |||
| virtual bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, | |||
| const kernel::KernelBuildInfoPtr &new_kernel_build_info) { | |||
| return device::ascend::CheckKernelAccuracySupported(kernel_node, new_kernel_build_info); | |||
| } | |||
| }; | |||
| using KernelSelectPtr = std::shared_ptr<KernelSelect>; | |||
| @@ -41,8 +37,13 @@ class SupportedChecker { | |||
| public: | |||
| SupportedChecker() = default; | |||
| virtual ~SupportedChecker() = default; | |||
| virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | |||
| return kernel::CheckSupported(anf_node, select_kernel_build_info); | |||
| virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node, | |||
| const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | |||
| return kernel::IsSupportedByAiCore(anf_node, select_kernel_build_info); | |||
| } | |||
| virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node, | |||
| const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | |||
| return kernel::IsSupportedByAiCpu(anf_node, select_kernel_build_info); | |||
| } | |||
| }; | |||
| using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>; | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" | |||
| #include <memory> | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| #include "kernel/kernel_query.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef ConvertUnSupportNodeToAICPU::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({X, Xs}); | |||
| } | |||
| const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraphPtr &, | |||
| const mindspore::AnfNodePtr &node, | |||
| const mindspore::EquivPtr &) const { | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| return nullptr; | |||
| } | |||
| auto node_name = AnfAlgo::GetCNodeName(node); | |||
| if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) { | |||
| return nullptr; | |||
| } | |||
| auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); | |||
| if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) { | |||
| return node; | |||
| } else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) { | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info); | |||
| builder->SetKernelType(AICPU_KERNEL); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" | |||
| << node->DebugString() << "]"; | |||
| } | |||
| return node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "pre_activate/ascend/ascend_helper.h" | |||
| #ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H | |||
| #define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ConvertUnSupportNodeToAICPU : public PatternProcessPass { | |||
| public: | |||
| explicit ConvertUnSupportNodeToAICPU(bool multigraph = true) | |||
| : PatternProcessPass("convert_unsupported_node_to_aicpu", multigraph), | |||
| supported_checker_(std::make_shared<SupportedChecker>()) {} | |||
| ~ConvertUnSupportNodeToAICPU() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| SupportedCheckerPtr supported_checker_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_ | |||
| #define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ | |||
| #include <string> | |||
| #include "pre_activate/common/optimizer.h" | |||
| @@ -32,4 +32,4 @@ class RunOpInsertCast : public PatternProcessPass { | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_ | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||
| #define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||
| #include <string> | |||
| #include <utility> | |||
| @@ -41,4 +41,4 @@ class RunOpInsertTransData : public PatternProcessPass { | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||
| @@ -128,7 +128,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod | |||
| auto indices_const = CreateValueNode(new_cnode); | |||
| new_cnode->add_input(indices_const); | |||
| MS_EXCEPTION_IF_NULL(supported_checker_); | |||
| if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) { | |||
| if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { | |||
| return nullptr; | |||
| } | |||
| @@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap | |||
| new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); | |||
| auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName); | |||
| if (kernel_select_->CheckKernelAccuracySupported(transdata_cnode, new_transdata_builder->Build())) { | |||
| if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) { | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata), | |||
| utils::cast<AnfNodePtr>((*equiv)[input_varptr_])}; | |||
| auto new_node = func_graph->NewCNode(inputs); | |||
| @@ -34,7 +34,7 @@ class TransposeTransDataFusion : public PatternProcessPass { | |||
| explicit TransposeTransDataFusion(bool multigraph = true) | |||
| : PatternProcessPass("transpose_transdata_fusion", multigraph) { | |||
| input_varptr_ = std::make_shared<Var>(); | |||
| kernel_select_ = std::make_shared<KernelSelect>(); | |||
| supported_checker_ = std::make_shared<SupportedChecker>(); | |||
| } | |||
| ~TransposeTransDataFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| @@ -42,7 +42,9 @@ class TransposeTransDataFusion : public PatternProcessPass { | |||
| private: | |||
| VarPtr input_varptr_; | |||
| KernelSelectPtr kernel_select_; | |||
| private: | |||
| SupportedCheckerPtr supported_checker_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -329,9 +329,9 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { | |||
| size_t reduce_precision_count = 0; | |||
| for (const auto &cnode : kernel_graph.execution_order()) { | |||
| auto status = device::ascend::SelectKernelInfo(cnode); | |||
| if (status == kStatusRaisePrecision) { | |||
| if (status == device::ascend::kStatusRaisePrecision) { | |||
| raise_precision_count++; | |||
| } else if (status == kStatusReducePrecision) { | |||
| } else if (status == device::ascend::kStatusReducePrecision) { | |||
| reduce_precision_count++; | |||
| } | |||
| MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); | |||
| @@ -27,6 +27,8 @@ | |||
| namespace mindspore { | |||
| namespace session { | |||
| namespace { | |||
| constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; | |||
| constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; | |||
| void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | |||
| std::unordered_set<AnfNodePtr> *visited_nodes) { | |||
| MS_EXCEPTION_IF_NULL(que); | |||
| @@ -180,11 +182,24 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| cnode->set_abstract(std::make_shared<abstract::AbstractNone>()); | |||
| // create kernel_info from new parameter | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| std::vector<size_t> feature_map_input_indexs; | |||
| // if the node only has the primitive(such as getNext) or the node's input has a feature map input | |||
| // then the node's output is a feature map output | |||
| if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(), | |||
| [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) { | |||
| for (size_t index = 1; index < inputs.size(); ++index) { | |||
| auto node = inputs[index]; | |||
| if (AnfAlgo::IsFeatureMapOutput(node)) { | |||
| feature_map_input_indexs.push_back(index); | |||
| } | |||
| } | |||
| if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { | |||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); | |||
| } | |||
| if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { | |||
| kernel_info->SetFeatureMapFlag(true); | |||
| AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(true), cnode); | |||
| AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); | |||
| } else { | |||
| AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(false), cnode); | |||
| } | |||
| cnode->set_kernel_info(kernel_info); | |||
| AnfAlgo::SetGraphId(graph_id_, cnode.get()); | |||
| @@ -142,6 +142,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto"; | |||
| // attr key name | |||
| constexpr auto kAttrInputNames = "input_names"; | |||
| constexpr auto kIsBackendCast = "is_backed_cast"; | |||
| constexpr auto kAttrOutputNames = "output_names"; | |||
| constexpr auto kAttrVisited = "visited"; | |||
| constexpr auto kAttrShape = "shape"; | |||
| @@ -201,10 +202,6 @@ constexpr auto kControlDependBehindIndex = 2; | |||
| // index define of depend | |||
| constexpr auto kRealInputIndexInDepend = 1; | |||
| constexpr auto kDependAttachNodeIndex = 2; | |||
| // status of kernel select result | |||
| const int kStatusReducePrecision = -1; | |||
| const int kStatusRaisePrecision = 1; | |||
| const int kStatusAllMatched = 0; | |||
| // format | |||
| constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; | |||
| constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0"; | |||
| @@ -218,18 +215,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ"; | |||
| constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; | |||
| constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; | |||
| constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04"; | |||
| const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, | |||
| kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, | |||
| kOpFormat_C1HWNCoC0}; | |||
| const std::set<std::string> k2DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z, | |||
| kOpFormat_NC1KHKWHWC0}; | |||
| const std::set<std::string> k3DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0}; | |||
| const std::set<std::string> k4DSupportFormat = k1DSupportFormat; | |||
| const std::vector<std::set<std::string>> kShapeSupportFormatMap = {k1DSupportFormat, k2DSupportFormat, k3DSupportFormat, | |||
| k4DSupportFormat}; | |||
| const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, | |||
| kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, | |||
| kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, | |||
| kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; | |||
| const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | |||
| const std::set<std::string> kOptOperatorSet = { | |||
| kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName, | |||
| kApplyAdagradOpName, kApplyAdagradDAName, kApplyAdamOpName, | |||
| @@ -1,345 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindspore/ccsrc/device/ascend/kernel_select_ascend.h" | |||
| #include "common/common_test.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "kernel/kernel.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "utils/utils.h" | |||
| #include "operator/ops.h" | |||
| #include "mindspore/ccsrc/device/kernel_info.h" | |||
| #include "mindspore/ccsrc/kernel/kernel_build_info.h" | |||
| #include <vector> | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| namespace { | |||
| using KernelInfo = device::KernelInfo; | |||
| using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; | |||
| using KernelBuildInfo = kernel::KernelBuildInfo; | |||
| using KernelGraph = session::KernelGraph; | |||
| using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>; | |||
| using KernelBuilderPtr = std::shared_ptr<KernelBuildInfoBuilder>; | |||
| using Shape = std::vector<size_t>; | |||
| using ShapeList = std::vector<Shape>; | |||
| enum MatchCountPriority { | |||
| MATCH_COUNT_PRIORITY_BEGIN = 0, | |||
| MATCH_FORMAT_COUNT = MATCH_COUNT_PRIORITY_BEGIN, | |||
| MATCH_DTYPE_COUNT, | |||
| MATCH_NZ_FORMAT_COUNT, | |||
| MATCH_5D_FORMAT_COUNT, | |||
| MATCH_OUTPUT_DTYPE_COUNT, | |||
| MATCH_COUNT_PRIORITY_END | |||
| }; | |||
| const std::set<std::string> kOpFormatList = { | |||
| kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, | |||
| kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ}; | |||
| bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) { | |||
| // if format is default,it remarkes support all format | |||
| if (kOpFormatList.find(format) == kOpFormatList.end()) { | |||
| MS_EXCEPTION(ArgumentError) << "got the unknow format " << format; | |||
| } | |||
| if (format == kOpFormat_DEFAULT) { | |||
| return true; | |||
| } | |||
| // if shape size is 0,the shape will be a scalar | |||
| if (shape.empty()) { | |||
| return true; | |||
| } | |||
| if (shape.size() > kShapeSupportFormatMap.size()) { | |||
| return false; | |||
| } | |||
| if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) { | |||
| return shape[shape.size() - 1] % 16 != 0 && shape[shape.size() - 2] % 16 != 0; | |||
| } | |||
| return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end()); | |||
| } | |||
| bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool { | |||
| if (!IsShapeMatchFormat(shape, format)) { | |||
| return false; | |||
| } | |||
| for (auto shape_value : shape) { | |||
| if (shape_value == 0) { | |||
| MS_EXCEPTION(ArgumentError) << "dimension size of the tensor shape should be a positive integer, but got [" | |||
| << shape_value << "]"; | |||
| } | |||
| } | |||
| return true; | |||
| }; | |||
| for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); | |||
| if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) { | |||
| return false; | |||
| } | |||
| } | |||
| for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); | |||
| if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| // Check input data type | |||
| for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { | |||
| AnfNodePtr cur_input = cnode->input(input_index + 1); | |||
| MS_EXCEPTION_IF_NULL(cur_input); | |||
| TypeId input_origin_type; | |||
| if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) { | |||
| // weight | |||
| input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0); | |||
| } else { | |||
| // feature map | |||
| input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | |||
| } | |||
| if (input_origin_type == kTypeUnknown) { | |||
| continue; | |||
| } | |||
| if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { | |||
| return false; | |||
| } | |||
| } | |||
| // Check output data type | |||
| for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) { | |||
| if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| /** | |||
| * compare too vector by priority,select a better vector,like compare too num,first compare highest num location,if | |||
| * equal then next num location | |||
| * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3] | |||
| */ | |||
| bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) { | |||
| MS_EXCEPTION_IF_NULL(best_item); | |||
| if (cur_item.size() != best_item->size()) { | |||
| MS_LOG(ERROR) << "item size should be same!"; | |||
| return false; | |||
| } | |||
| // Update the best_item by comparing the cur_item and best_item | |||
| for (size_t i = 0; i < cur_item.size(); i++) { | |||
| if (cur_item[i] > best_item->at(i)) { | |||
| *best_item = cur_item; | |||
| return true; | |||
| } else if (cur_item[i] == best_item->at(i)) { | |||
| continue; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node, | |||
| std::vector<int> *const cur_kernelinfo_match_counts) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts); | |||
| if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) { | |||
| MS_EXCEPTION(ArgumentError) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END; | |||
| } | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| AnfNodePtr input_anf_node = kernel_node->input(input_index + 1); | |||
| MS_EXCEPTION_IF_NULL(input_anf_node); | |||
| // if a input parameter is a weight with default format, the input shouldn't participate the judge | |||
| if (input_anf_node->isa<Parameter>()) { | |||
| auto para = input_anf_node->cast<ParameterPtr>(); | |||
| if (AnfAlgo::IsParameterWeight(para) && AnfAlgo::GetOutputDeviceDataType(para, 0) == kTypeUnknown) { | |||
| continue; | |||
| } | |||
| } | |||
| if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++; | |||
| } | |||
| if (kernel_build_info.GetInputDeviceType(input_index) == | |||
| AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT]++; | |||
| } | |||
| if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_FRAC_NZ) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_NZ_FORMAT_COUNT]++; | |||
| } | |||
| if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_NC1HWC0) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_5D_FORMAT_COUNT]++; | |||
| } | |||
| } | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||
| // cal count of same output dtype between abstract and kernel info | |||
| if (kernel_build_info.GetOutputDeviceType(output_index) == | |||
| AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++; | |||
| } | |||
| } | |||
| } | |||
| void SetKernelBuildInfo(KernelBuilderPtr builder) { | |||
| builder->SetFusionType(kernel::OPAQUE); | |||
| builder->SetKernelType(AUTO_DIFF_KERNEL); | |||
| builder->SetProcessor(kernel::AICORE); | |||
| } | |||
| void test_select(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list) { | |||
| std::vector<int> most_match_counts = {-1, -1, -1, -1, -1}; | |||
| int selected_index = -1; | |||
| for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { | |||
| std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0}; | |||
| if (!IsValidKernelInfo(kernel_node, *(kernel_info_list[info_index]))) { | |||
| continue; | |||
| } | |||
| if (!MatchInferOutputDataType(kernel_node, *(kernel_info_list[info_index]))) { | |||
| continue; | |||
| } | |||
| std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index]; | |||
| UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); | |||
| // Currently the selection policy is the match format count first, and then is datatype counts. | |||
| if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) { | |||
| selected_index = SizeToInt(info_index); | |||
| } | |||
| } | |||
| if (selected_index == -1) { | |||
| MS_EXCEPTION(NotExistsError) << "" << kernel_node->DebugString() << " Cannot find valid kernel Info !"; | |||
| } | |||
| auto index = IntToSize(selected_index); | |||
| if (index >= kernel_info_list.size()) { | |||
| MS_EXCEPTION(ArgumentError) << "index outof range"; | |||
| } | |||
| std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info_ptr = kernel_info_list[index]; | |||
| MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr); | |||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, kernel_node.get()); | |||
| } | |||
| void SetParentAbstract(std::vector<AnfNodePtr> parent_list, std::vector<std::vector<size_t>> shapes, | |||
| std::vector<TypeId> types) { | |||
| for (const auto &node : parent_list) { | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, node.get()); | |||
| } | |||
| } | |||
| } // namespace | |||
| class AscendKernelSelctTest : public UT::Common { | |||
| public: | |||
| AscendKernelSelctTest() = default; | |||
| void SetUp() override {} | |||
| void TearDown() override {} | |||
| }; | |||
| TEST_F(AscendKernelSelctTest, TestSelect) { | |||
| std::vector<KernelBuilderPtr> build_list; | |||
| std::vector<TypeId> type_list = {kNumberTypeFloat32}; | |||
| for (size_t i = 0; i <= 4; ++i) { | |||
| build_list.push_back(std::make_shared<KernelBuildInfoBuilder>()); | |||
| SetKernelBuildInfo(build_list[i]); | |||
| build_list[i]->SetInputsDeviceType(type_list); | |||
| build_list[i]->SetOutputsDeviceType(type_list); | |||
| } | |||
| std::vector<std::string> nd_fmt = {kOpFormat_DEFAULT}; | |||
| std::vector<std::string> nz_fmt = {kOpFormat_FRAC_NZ}; | |||
| auto anf_graph = std::make_shared<KernelGraph>(); | |||
| // 16's multiple should not chose format NZ | |||
| Shape nd_shapes = {2, 32, 224, 224}; | |||
| Shape nz_shapes = {3, 3, 5, 5}; | |||
| auto add_value = NewValueNode(prim::kPrimTensorAdd); | |||
| auto a_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value}); | |||
| auto b_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value}); | |||
| std::vector<AnfNodePtr> parent_list = {add_value, a_node, b_node}; | |||
| auto c_node = anf_graph->NewCNode(parent_list); | |||
| // a b | |||
| // \ / | |||
| // c | |||
| // a & b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}} | |||
| // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} | |||
| // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3,224, 224}} | |||
| // set a & b's info | |||
| SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list); | |||
| // set abstract c | |||
| AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nd_shapes}, c_node.get()); | |||
| // set format of kernel info | |||
| build_list[0]->SetOutputsFormat(nz_fmt); | |||
| build_list[1]->SetOutputsFormat(nz_fmt); | |||
| build_list[2]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nd_fmt[0]}); | |||
| build_list[3]->SetInputsFormat(std::vector<std::string>{nz_fmt[0], nz_fmt[0]}); | |||
| build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32}); | |||
| build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32}); | |||
| build_list[2]->SetOutputsFormat(nd_fmt); | |||
| build_list[3]->SetOutputsFormat(nz_fmt); | |||
| std::vector<KernelBuildInfoPtr> select_info_list; | |||
| // set select info list | |||
| select_info_list.emplace_back(build_list[2]->Build()); | |||
| select_info_list.emplace_back(build_list[3]->Build()); | |||
| // set device info for a & b | |||
| AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get()); | |||
| test_select(c_node, select_info_list); | |||
| EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT); | |||
| EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_DEFAULT); | |||
| // set a & b's info | |||
| // a b | |||
| // \ / | |||
| // c | |||
| // a: kernel_info:{output_format:{5d},dtype:{kNumberTypeFloat32}} | |||
| // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} | |||
| // b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}} | |||
| // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} | |||
| // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} | |||
| // set a & b's info | |||
| SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list); | |||
| // set abstract c | |||
| AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nz_shapes}, c_node.get()); | |||
| // set format of kernel info | |||
| build_list[0]->SetOutputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0}); | |||
| build_list[1]->SetOutputsFormat(nz_fmt); | |||
| build_list[2]->SetInputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0, nd_fmt[0]}); | |||
| build_list[3]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nz_fmt[0]}); | |||
| build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32}); | |||
| build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32}); | |||
| build_list[2]->SetOutputsFormat(nd_fmt); | |||
| build_list[3]->SetOutputsFormat(nz_fmt); | |||
| // set select info list | |||
| select_info_list.emplace_back(build_list[2]->Build()); | |||
| select_info_list.emplace_back(build_list[3]->Build()); | |||
| // set device info for a & b | |||
| AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get()); | |||
| test_select(c_node, select_info_list); | |||
| EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT); | |||
| EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_FRAC_NZ); | |||
| } | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -39,7 +39,7 @@ class MockSupportedChecker : public SupportedChecker { | |||
| public: | |||
| MockSupportedChecker() = default; | |||
| ~MockSupportedChecker() override = default; | |||
| bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||
| bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||
| return true; | |||
| } | |||
| }; // namespace opt | |||
| @@ -37,6 +37,15 @@ class TestHWTransposeTransdataFusion : public BackendCommon { | |||
| UT::PyFuncGraphFetcher get_py_fun_; | |||
| }; | |||
| class MockSupportedChecker : public SupportedChecker { | |||
| public: | |||
| MockSupportedChecker() = default; | |||
| ~MockSupportedChecker() override = default; | |||
| bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||
| return true; | |||
| } | |||
| }; | |||
| class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { | |||
| public: | |||
| MockInsertTransOpKernelSelectTrans4Dto5D() = default; | |||
| @@ -60,37 +69,6 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { | |||
| } | |||
| }; | |||
| class MockTransposeTransdataFusionKernelSelect : public KernelSelect { | |||
| public: | |||
| MockTransposeTransdataFusionKernelSelect() = default; | |||
| ~MockTransposeTransdataFusionKernelSelect() override = default; | |||
| bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, | |||
| const kernel::KernelBuildInfoPtr &new_kernel_build_info) override { | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| builder.SetInputsFormat({kOpFormat_NCHW}); | |||
| builder.SetOutputsFormat({kOpFormat_DEFAULT}); | |||
| builder.SetInputsDeviceType({kNumberTypeFloat16}); | |||
| builder.SetOutputsDeviceType({kNumberTypeFloat16}); | |||
| builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); | |||
| builder.SetFusionType(kernel::FusionType::OPAQUE); | |||
| builder.SetProcessor(kernel::Processor::AICORE); | |||
| kernel_info_list.push_back(builder.Build()); | |||
| MS_LOG(INFO) << "transpose transdata fusion success"; | |||
| MS_LOG(INFO) << "new transdata build info input format:" << new_kernel_build_info->GetInputFormat(0) | |||
| << ",outputformat:" << new_kernel_build_info->GetOutputFormat(0) | |||
| << ",kerneltype:" << new_kernel_build_info->kernel_type() | |||
| << ",fusiontype:" << new_kernel_build_info->fusion_type() | |||
| << ",process:" << new_kernel_build_info->processor(); | |||
| auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(), | |||
| [&new_kernel_build_info](kernel::KernelBuildInfoPtr item) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| return *item == *new_kernel_build_info; | |||
| }); | |||
| return result != kernel_info_list.end(); | |||
| } | |||
| }; | |||
| TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { | |||
| /* | |||
| * def before(input0, input1): | |||
| @@ -128,7 +106,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { | |||
| insert_trans_op_pass->kernel_select_ = std::make_shared<MockInsertTransOpKernelSelectTrans4Dto5D>(); | |||
| pm->AddPass(insert_trans_op_pass); | |||
| auto transpose_transdata_pass = std::make_shared<opt::TransposeTransDataFusion>(); | |||
| transpose_transdata_pass->kernel_select_ = std::make_shared<MockTransposeTransdataFusionKernelSelect>(); | |||
| transpose_transdata_pass->supported_checker_ = std::make_shared<MockSupportedChecker>(); | |||
| pm->AddPass(transpose_transdata_pass); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||