Merge pull request !1079 from lianliguang/convert-to-AICPU-when-AiCore-not-supported-kerneltags/v0.3.0-alpha
| @@ -85,7 +85,7 @@ const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberType | |||||
| } while (0) | } while (0) | ||||
| template <typename T> | template <typename T> | ||||
| T Ceil(T n1, T n2) { | |||||
| T DivCeil(T n1, T n2) { | |||||
| return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | ||||
| } | } | ||||
| @@ -371,15 +371,48 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) { | |||||
| device_shape.push_back(kCubeSize); | device_shape.push_back(kCubeSize); | ||||
| return device_shape; | return device_shape; | ||||
| } | } | ||||
| std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) { | |||||
| if (!CheckDims(shape)) { | |||||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||||
| } | |||||
| std::vector<size_t> device_shape; | |||||
| size_t c0 = 4; | |||||
| size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize); | |||||
| size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize); | |||||
| device_shape.push_back(first_dim); | |||||
| device_shape.push_back(no); | |||||
| device_shape.push_back(kCubeSize); | |||||
| device_shape.push_back(kCubeSize); | |||||
| return device_shape; | |||||
| } | |||||
| std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) { | |||||
| if (!CheckDims(shape)) { | |||||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||||
| } | |||||
| std::vector<size_t> device_shape; | |||||
| size_t C1 = 1; | |||||
| size_t C0 = 4; | |||||
| device_shape.push_back(shape[0]); | |||||
| device_shape.push_back(C1); | |||||
| device_shape.push_back(shape[2]); | |||||
| device_shape.push_back(shape[3]); | |||||
| device_shape.push_back(C0); | |||||
| return device_shape; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) { | std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) { | ||||
| using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>; | using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>; | ||||
| const std::map<std::string, DeviceShapeTransfer> device_shape_map{ | |||||
| {kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape}, | |||||
| {kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape}, | |||||
| {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, | |||||
| }; | |||||
| const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, | |||||
| {kOpFormat_NHWC, NhwcDeviceShape}, | |||||
| {kOpFormat_HWCN, HwchDeviceShape}, | |||||
| {kOpFormat_FRAC_Z, FracZDeviceShape}, | |||||
| {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, | |||||
| {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, | |||||
| {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, | |||||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}}; | |||||
| if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { | if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { | ||||
| return shape; | return shape; | ||||
| @@ -506,13 +539,13 @@ bool NchwToFracZ(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Illegal dtype."; | MS_LOG(ERROR) << "Illegal dtype."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t c1 = Ceil(c, c0); | |||||
| size_t c1 = DivCeil(c, c0); | |||||
| size_t hw = h * w; | size_t hw = h * w; | ||||
| size_t chw = c * hw; | size_t chw = c * hw; | ||||
| size_t hwc0 = hw * c0; | size_t hwc0 = hw * c0; | ||||
| size_t nchw = n * chw; | size_t nchw = n * chw; | ||||
| size_t hf_cnt = Ceil(n, kCubeSize); | |||||
| size_t hf_cnt = DivCeil(n, kCubeSize); | |||||
| size_t vf_cnt = c1 * hw; | size_t vf_cnt = c1 * hw; | ||||
| size_t fractal_ele_cnt = c0 * kCubeSize; | size_t fractal_ele_cnt = c0 * kCubeSize; | ||||
| size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | ||||
| @@ -775,7 +808,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Illegal dtype."; | MS_LOG(ERROR) << "Illegal dtype."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t c1 = Ceil(c, c0); | |||||
| size_t c1 = DivCeil(c, c0); | |||||
| size_t hw = h * w; | size_t hw = h * w; | ||||
| size_t chw = c * hw; | size_t chw = c * hw; | ||||
| size_t c1hwc0 = c1 * hw * c0; | size_t c1hwc0 = c1 * hw * c0; | ||||
| @@ -34,6 +34,7 @@ namespace ascend { | |||||
| namespace { | namespace { | ||||
| const float kWegihtBaseScore = 1; | const float kWegihtBaseScore = 1; | ||||
| const float kFeatureMapBaseScore = 10; | const float kFeatureMapBaseScore = 10; | ||||
| constexpr auto kPriChoosenFormat = "pri_format"; | |||||
| enum MatchCountPriority : int { | enum MatchCountPriority : int { | ||||
| MATCH_COUNT_PRIORITY_BEGIN = 0, | MATCH_COUNT_PRIORITY_BEGIN = 0, | ||||
| MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, | MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, | ||||
| @@ -85,6 +86,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { | |||||
| if (need_change_nd) { | if (need_change_nd) { | ||||
| priority_matched_format = kOpFormat_DEFAULT; | priority_matched_format = kOpFormat_DEFAULT; | ||||
| } | } | ||||
| AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); | |||||
| return priority_matched_format; | return priority_matched_format; | ||||
| } | } | ||||
| /** | /** | ||||
| @@ -394,9 +396,9 @@ void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode, | |||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| buffer << cnode->DebugString(); | buffer << cnode->DebugString(); | ||||
| if (precision_reduce) { | if (precision_reduce) { | ||||
| buffer << " reduce precision, node datatype: "; | |||||
| buffer << " reduce precision, node datatype: \n"; | |||||
| } else { | } else { | ||||
| buffer << " raise precision, node datatype: "; | |||||
| buffer << " raise precision, node datatype: \n"; | |||||
| } | } | ||||
| PrintInputAndOutputInferType(buffer, cnode); | PrintInputAndOutputInferType(buffer, cnode); | ||||
| buffer << ", select kernel:" << selected_kernel_build_info->ToString(); | buffer << ", select kernel:" << selected_kernel_build_info->ToString(); | ||||
| @@ -464,66 +466,57 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| std::shared_ptr<kernel::KernelBuildInfo> CanHitKernelInfo( | |||||
| int *status, const CNodePtr &kernel_node, | |||||
| const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) { | |||||
| KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, | |||||
| const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| KernelSelectStatus select_status = kNoMatched; | |||||
| bool precision_reduce = false; | bool precision_reduce = false; | ||||
| std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr; | std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr; | ||||
| // Matched kernel info | |||||
| // Filter kernel info matched with me infered type | |||||
| auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list); | auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list); | ||||
| if (!filtered_kernel_info_list.empty()) { | if (!filtered_kernel_info_list.empty()) { | ||||
| selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); | selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); | ||||
| select_status = kStatusAllMatched; | |||||
| } else { | } else { | ||||
| // selected kernel info using raised precision or reduce precision | // selected kernel info using raised precision or reduce precision | ||||
| filtered_kernel_info_list = | filtered_kernel_info_list = | ||||
| FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce); | FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce); | ||||
| selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); | selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); | ||||
| if (selected_kernel_info == nullptr) { | if (selected_kernel_info == nullptr) { | ||||
| return nullptr; | |||||
| return select_status; | |||||
| } else { | } else { | ||||
| PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); | PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); | ||||
| *status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; | |||||
| select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; | |||||
| } | } | ||||
| } | } | ||||
| return selected_kernel_info; | |||||
| // Set kernel info to the anfnode | |||||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); | |||||
| // Set format and data type for input tensor. | |||||
| SetTensorDeviceInfo(*selected_kernel_info, kernel_node); | |||||
| return select_status; | |||||
| } | } | ||||
| int SelectKernelInfo(const CNodePtr &kernel_node) { | |||||
| KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | ||||
| int status = kStatusAllMatched; | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| kernel::KernelQuery(kernel_node, &kernel_info_list); | kernel::KernelQuery(kernel_node, &kernel_info_list); | ||||
| // filter kernel info matched with me infered type | |||||
| auto selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list); | |||||
| if (selected_kernel_info == nullptr) { | |||||
| auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); | |||||
| // If aicore not find valid kernel info reloading aicpu kernel info list to find it | |||||
| if (select_status == kNoMatched) { | |||||
| MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() | MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() | ||||
| << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; | << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; | ||||
| kernel::AicpuQuery(kernel_node, &kernel_info_list); | |||||
| selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list); | |||||
| kernel::AICpuQuery(kernel_node, &kernel_info_list); | |||||
| select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); | |||||
| } | } | ||||
| if (selected_kernel_info == nullptr) { | |||||
| // The kernel info not finded both in the aicpu kernel list & aicore kernel list | |||||
| if (select_status == kNoMatched) { | |||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| PrintInputAndOutputInferType(buffer, kernel_node); | PrintInputAndOutputInferType(buffer, kernel_node); | ||||
| MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() | MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() | ||||
| << "] cannot find valid kernel info, not supported the type " << buffer.str(); | << "] cannot find valid kernel info, not supported the type " << buffer.str(); | ||||
| } | } | ||||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); | |||||
| // Set format and data type for input tensor. | |||||
| SetTensorDeviceInfo(*selected_kernel_info, kernel_node); | |||||
| return status; | |||||
| } | |||||
| bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, | |||||
| const kernel::KernelBuildInfoPtr &new_kernel_build_info) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||||
| kernel::KernelQuery(kernel_node, &kernel_info_list); | |||||
| auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(), | |||||
| [&new_kernel_build_info](const kernel::KernelBuildInfoPtr item) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| return *item == *new_kernel_build_info; | |||||
| }); | |||||
| return result != kernel_info_list.end(); | |||||
| return select_status; | |||||
| } | } | ||||
| } // namespace ascend | } // namespace ascend | ||||
| } // namespace device | } // namespace device | ||||
| @@ -21,8 +21,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| int SelectKernelInfo(const CNodePtr &kernel_node); | |||||
| bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, const kernel::KernelBuildInfoPtr &new_kernel_build_info); | |||||
| enum KernelSelectStatus { | |||||
| kNoMatched = -1, | |||||
| kStatusAllMatched = 0, | |||||
| kStatusReducePrecision = 1, | |||||
| kStatusRaisePrecision = 2, | |||||
| }; | |||||
| KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node); | |||||
| } // namespace ascend | } // namespace ascend | ||||
| } // namespace device | } // namespace device | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,7 +35,7 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K | |||||
| std::vector<std::string> input_format, output_format; | std::vector<std::string> input_format, output_format; | ||||
| std::vector<TypeId> input_type, output_type; | std::vector<TypeId> input_type, output_type; | ||||
| for (const auto &data_type : data_type_list) { | for (const auto &data_type : data_type_list) { | ||||
| for (const auto &format : k4DSupportFormat) { | |||||
| for (const auto &format : kOpFormatList) { | |||||
| auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); | auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); | ||||
| input_format.clear(); | input_format.clear(); | ||||
| input_format.push_back(format); | input_format.push_back(format); | ||||
| @@ -35,14 +35,18 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, | |||||
| return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && | return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && | ||||
| AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); | AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); | ||||
| }); | }); | ||||
| kernel_info_list->clear(); | |||||
| if (!filtered_list.empty()) { | if (!filtered_list.empty()) { | ||||
| kernel_info_list->clear(); | |||||
| (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); | (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "node" << kernel_node->DebugString() << "'s output size : [" | |||||
| << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" | |||||
| << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) | |||||
| << "] cannot match any kernelInfo !"; | |||||
| MS_LOG(WARNING) << "All kernel Info list does not match any kernel info "; | |||||
| for (size_t index; index < kernel_info_list->size(); ++index) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); | |||||
| MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); | |||||
| } | |||||
| MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : [" | |||||
| << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" | |||||
| << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !"; | |||||
| } | } | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -50,7 +54,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | MS_EXCEPTION_IF_NULL(kernel_info_list); | ||||
| TbeMetadataInfo(kernel_node, kernel_info_list); | TbeMetadataInfo(kernel_node, kernel_info_list); | ||||
| if (kernel_info_list->empty()) { | if (kernel_info_list->empty()) { | ||||
| AicpuMetadataInfo(kernel_node, kernel_info_list); | AicpuMetadataInfo(kernel_node, kernel_info_list); | ||||
| } | } | ||||
| @@ -68,12 +71,41 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel | |||||
| FilterInvalidKernelInfo(kernel_node, kernel_info_list); | FilterInvalidKernelInfo(kernel_node, kernel_info_list); | ||||
| } | } | ||||
| void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||||
| void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | MS_EXCEPTION_IF_NULL(kernel_info_list); | ||||
| kernel_info_list->clear(); | kernel_info_list->clear(); | ||||
| AicpuMetadataInfo(kernel_node, kernel_info_list); | AicpuMetadataInfo(kernel_node, kernel_info_list); | ||||
| FilterInvalidKernelInfo(kernel_node, kernel_info_list); | FilterInvalidKernelInfo(kernel_node, kernel_info_list); | ||||
| } | } | ||||
| bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(select_kernel_build_info); | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||||
| auto cnode = kernel_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| AicpuMetadataInfo(cnode, &kernel_info_list); | |||||
| FilterInvalidKernelInfo(cnode, &kernel_info_list); | |||||
| return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), | |||||
| [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| return *item == *select_kernel_build_info; | |||||
| }); | |||||
| } | |||||
| bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(select_kernel_build_info); | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||||
| auto cnode = kernel_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| TbeMetadataInfo(cnode, &kernel_info_list); | |||||
| FilterInvalidKernelInfo(cnode, &kernel_info_list); | |||||
| return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), | |||||
| [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| return *item == *select_kernel_build_info; | |||||
| }); | |||||
| } | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,7 +26,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | ||||
| void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | |||||
| void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | |||||
| bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); | |||||
| bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ | #endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ | ||||
| @@ -551,11 +551,6 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn | |||||
| } | } | ||||
| bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) { | bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) { | ||||
| const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, | |||||
| kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, | |||||
| kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, | |||||
| kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; | |||||
| // if format is default, it remarkes support all format | // if format is default, it remarkes support all format | ||||
| if (kOpFormatList.find(format) == kOpFormatList.end()) { | if (kOpFormatList.find(format) == kOpFormatList.end()) { | ||||
| MS_LOG(EXCEPTION) << "Got the unknown format " << format; | MS_LOG(EXCEPTION) << "Got the unknown format " << format; | ||||
| @@ -54,6 +54,7 @@ | |||||
| #include "pre_activate/pass/optimize_dependence.h" | #include "pre_activate/pass/optimize_dependence.h" | ||||
| #include "pre_activate/pass/erase_visit_attr.h" | #include "pre_activate/pass/erase_visit_attr.h" | ||||
| #include "pre_activate/ascend/format_type/insert_cast.h" | #include "pre_activate/ascend/format_type/insert_cast.h" | ||||
| #include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" | |||||
| #include "pre_activate/pass/eliminate_redundant_op.h" | #include "pre_activate/pass/eliminate_redundant_op.h" | ||||
| #include "pre_activate/pass/common_subexpression_elimination.h" | #include "pre_activate/pass/common_subexpression_elimination.h" | ||||
| #include "pre_activate/ascend/format_type/merge_cast_to_op.h" | #include "pre_activate/ascend/format_type/merge_cast_to_op.h" | ||||
| @@ -172,6 +173,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap | |||||
| mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); | mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>()); | |||||
| optimizer->AddPassManager(mixed_precision_pm); | optimizer->AddPassManager(mixed_precision_pm); | ||||
| (void)optimizer->Optimize(kernel_graph); | (void)optimizer->Optimize(kernel_graph); | ||||
| kernel_graph->SetExecOrderByDefault(); | kernel_graph->SetExecOrderByDefault(); | ||||
| @@ -268,6 +268,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr | |||||
| } | } | ||||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); | AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); | ||||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); | AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); | ||||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); | |||||
| return cast; | return cast; | ||||
| } | } | ||||
| @@ -30,10 +30,6 @@ class KernelSelect { | |||||
| KernelSelect() = default; | KernelSelect() = default; | ||||
| virtual ~KernelSelect() = default; | virtual ~KernelSelect() = default; | ||||
| virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); } | virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); } | ||||
| virtual bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, | |||||
| const kernel::KernelBuildInfoPtr &new_kernel_build_info) { | |||||
| return device::ascend::CheckKernelAccuracySupported(kernel_node, new_kernel_build_info); | |||||
| } | |||||
| }; | }; | ||||
| using KernelSelectPtr = std::shared_ptr<KernelSelect>; | using KernelSelectPtr = std::shared_ptr<KernelSelect>; | ||||
| @@ -41,8 +37,13 @@ class SupportedChecker { | |||||
| public: | public: | ||||
| SupportedChecker() = default; | SupportedChecker() = default; | ||||
| virtual ~SupportedChecker() = default; | virtual ~SupportedChecker() = default; | ||||
| virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | |||||
| return kernel::CheckSupported(anf_node, select_kernel_build_info); | |||||
| virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node, | |||||
| const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | |||||
| return kernel::IsSupportedByAiCore(anf_node, select_kernel_build_info); | |||||
| } | |||||
| virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node, | |||||
| const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | |||||
| return kernel::IsSupportedByAiCpu(anf_node, select_kernel_build_info); | |||||
| } | } | ||||
| }; | }; | ||||
| using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>; | using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>; | ||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" | |||||
| #include <memory> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "kernel/kernel_build_info.h" | |||||
| #include "kernel/kernel_query.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| const BaseRef ConvertUnSupportNodeToAICPU::DefinePattern() const { | |||||
| VarPtr X = std::make_shared<Var>(); | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| return VectorRef({X, Xs}); | |||||
| } | |||||
| const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraphPtr &, | |||||
| const mindspore::AnfNodePtr &node, | |||||
| const mindspore::EquivPtr &) const { | |||||
| if (node == nullptr || !node->isa<CNode>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto node_name = AnfAlgo::GetCNodeName(node); | |||||
| if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); | |||||
| if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) { | |||||
| return node; | |||||
| } else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) { | |||||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info); | |||||
| builder->SetKernelType(AICPU_KERNEL); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" | |||||
| << node->DebugString() << "]"; | |||||
| } | |||||
| return node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| #include "pre_activate/ascend/ascend_helper.h" | |||||
| #ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H | |||||
| #define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class ConvertUnSupportNodeToAICPU : public PatternProcessPass { | |||||
| public: | |||||
| explicit ConvertUnSupportNodeToAICPU(bool multigraph = true) | |||||
| : PatternProcessPass("convert_unsupported_node_to_aicpu", multigraph), | |||||
| supported_checker_(std::make_shared<SupportedChecker>()) {} | |||||
| ~ConvertUnSupportNodeToAICPU() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| SupportedCheckerPtr supported_checker_; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H | |||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_ | |||||
| #define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ | |||||
| #include <string> | #include <string> | ||||
| #include "pre_activate/common/optimizer.h" | #include "pre_activate/common/optimizer.h" | ||||
| @@ -32,4 +32,4 @@ class RunOpInsertCast : public PatternProcessPass { | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_ | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||||
| #define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| @@ -41,4 +41,4 @@ class RunOpInsertTransData : public PatternProcessPass { | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ | |||||
| @@ -128,7 +128,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod | |||||
| auto indices_const = CreateValueNode(new_cnode); | auto indices_const = CreateValueNode(new_cnode); | ||||
| new_cnode->add_input(indices_const); | new_cnode->add_input(indices_const); | ||||
| MS_EXCEPTION_IF_NULL(supported_checker_); | MS_EXCEPTION_IF_NULL(supported_checker_); | ||||
| if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) { | |||||
| if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap | |||||
| new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); | new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); | ||||
| auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName); | auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName); | ||||
| if (kernel_select_->CheckKernelAccuracySupported(transdata_cnode, new_transdata_builder->Build())) { | |||||
| if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) { | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata), | std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata), | ||||
| utils::cast<AnfNodePtr>((*equiv)[input_varptr_])}; | utils::cast<AnfNodePtr>((*equiv)[input_varptr_])}; | ||||
| auto new_node = func_graph->NewCNode(inputs); | auto new_node = func_graph->NewCNode(inputs); | ||||
| @@ -34,7 +34,7 @@ class TransposeTransDataFusion : public PatternProcessPass { | |||||
| explicit TransposeTransDataFusion(bool multigraph = true) | explicit TransposeTransDataFusion(bool multigraph = true) | ||||
| : PatternProcessPass("transpose_transdata_fusion", multigraph) { | : PatternProcessPass("transpose_transdata_fusion", multigraph) { | ||||
| input_varptr_ = std::make_shared<Var>(); | input_varptr_ = std::make_shared<Var>(); | ||||
| kernel_select_ = std::make_shared<KernelSelect>(); | |||||
| supported_checker_ = std::make_shared<SupportedChecker>(); | |||||
| } | } | ||||
| ~TransposeTransDataFusion() override = default; | ~TransposeTransDataFusion() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| @@ -42,7 +42,9 @@ class TransposeTransDataFusion : public PatternProcessPass { | |||||
| private: | private: | ||||
| VarPtr input_varptr_; | VarPtr input_varptr_; | ||||
| KernelSelectPtr kernel_select_; | |||||
| private: | |||||
| SupportedCheckerPtr supported_checker_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -329,9 +329,9 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { | |||||
| size_t reduce_precision_count = 0; | size_t reduce_precision_count = 0; | ||||
| for (const auto &cnode : kernel_graph.execution_order()) { | for (const auto &cnode : kernel_graph.execution_order()) { | ||||
| auto status = device::ascend::SelectKernelInfo(cnode); | auto status = device::ascend::SelectKernelInfo(cnode); | ||||
| if (status == kStatusRaisePrecision) { | |||||
| if (status == device::ascend::kStatusRaisePrecision) { | |||||
| raise_precision_count++; | raise_precision_count++; | ||||
| } else if (status == kStatusReducePrecision) { | |||||
| } else if (status == device::ascend::kStatusReducePrecision) { | |||||
| reduce_precision_count++; | reduce_precision_count++; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); | MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); | ||||
| @@ -27,6 +27,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| namespace { | namespace { | ||||
| constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; | |||||
| constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; | |||||
| void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | ||||
| std::unordered_set<AnfNodePtr> *visited_nodes) { | std::unordered_set<AnfNodePtr> *visited_nodes) { | ||||
| MS_EXCEPTION_IF_NULL(que); | MS_EXCEPTION_IF_NULL(que); | ||||
| @@ -180,11 +182,24 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||||
| cnode->set_abstract(std::make_shared<abstract::AbstractNone>()); | cnode->set_abstract(std::make_shared<abstract::AbstractNone>()); | ||||
| // create kernel_info from new parameter | // create kernel_info from new parameter | ||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | auto kernel_info = std::make_shared<device::KernelInfo>(); | ||||
| std::vector<size_t> feature_map_input_indexs; | |||||
| // if the node only has the primitive(such as getNext) or the node's input has a feature map input | // if the node only has the primitive(such as getNext) or the node's input has a feature map input | ||||
| // then the node's output is a feature map output | // then the node's output is a feature map output | ||||
| if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(), | |||||
| [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) { | |||||
| for (size_t index = 1; index < inputs.size(); ++index) { | |||||
| auto node = inputs[index]; | |||||
| if (AnfAlgo::IsFeatureMapOutput(node)) { | |||||
| feature_map_input_indexs.push_back(index); | |||||
| } | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { | |||||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); | |||||
| } | |||||
| if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { | |||||
| kernel_info->SetFeatureMapFlag(true); | kernel_info->SetFeatureMapFlag(true); | ||||
| AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(true), cnode); | |||||
| AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); | |||||
| } else { | |||||
| AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(false), cnode); | |||||
| } | } | ||||
| cnode->set_kernel_info(kernel_info); | cnode->set_kernel_info(kernel_info); | ||||
| AnfAlgo::SetGraphId(graph_id_, cnode.get()); | AnfAlgo::SetGraphId(graph_id_, cnode.get()); | ||||
| @@ -142,6 +142,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto"; | |||||
| // attr key name | // attr key name | ||||
| constexpr auto kAttrInputNames = "input_names"; | constexpr auto kAttrInputNames = "input_names"; | ||||
| constexpr auto kIsBackendCast = "is_backed_cast"; | |||||
| constexpr auto kAttrOutputNames = "output_names"; | constexpr auto kAttrOutputNames = "output_names"; | ||||
| constexpr auto kAttrVisited = "visited"; | constexpr auto kAttrVisited = "visited"; | ||||
| constexpr auto kAttrShape = "shape"; | constexpr auto kAttrShape = "shape"; | ||||
| @@ -201,10 +202,6 @@ constexpr auto kControlDependBehindIndex = 2; | |||||
| // index define of depend | // index define of depend | ||||
| constexpr auto kRealInputIndexInDepend = 1; | constexpr auto kRealInputIndexInDepend = 1; | ||||
| constexpr auto kDependAttachNodeIndex = 2; | constexpr auto kDependAttachNodeIndex = 2; | ||||
| // status of kernel select result | |||||
| const int kStatusReducePrecision = -1; | |||||
| const int kStatusRaisePrecision = 1; | |||||
| const int kStatusAllMatched = 0; | |||||
| // format | // format | ||||
| constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; | constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; | ||||
| constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0"; | constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0"; | ||||
| @@ -218,18 +215,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ"; | |||||
| constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; | constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; | ||||
| constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; | constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; | ||||
| constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04"; | constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04"; | ||||
| const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, | |||||
| kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, | |||||
| kOpFormat_C1HWNCoC0}; | |||||
| const std::set<std::string> k2DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z, | |||||
| kOpFormat_NC1KHKWHWC0}; | |||||
| const std::set<std::string> k3DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0}; | |||||
| const std::set<std::string> k4DSupportFormat = k1DSupportFormat; | |||||
| const std::vector<std::set<std::string>> kShapeSupportFormatMap = {k1DSupportFormat, k2DSupportFormat, k3DSupportFormat, | |||||
| k4DSupportFormat}; | |||||
| const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, | |||||
| kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, | |||||
| kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, | |||||
| kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; | |||||
| const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | ||||
| const std::set<std::string> kOptOperatorSet = { | const std::set<std::string> kOptOperatorSet = { | ||||
| kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName, | kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName, | ||||
| kApplyAdagradOpName, kApplyAdagradDAName, kApplyAdamOpName, | kApplyAdagradOpName, kApplyAdagradDAName, kApplyAdamOpName, | ||||
| @@ -1,345 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "mindspore/ccsrc/device/ascend/kernel_select_ascend.h" | |||||
| #include "common/common_test.h" | |||||
| #include "session/kernel_graph.h" | |||||
| #include "kernel/kernel.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "utils/utils.h" | |||||
| #include "operator/ops.h" | |||||
| #include "mindspore/ccsrc/device/kernel_info.h" | |||||
| #include "mindspore/ccsrc/kernel/kernel_build_info.h" | |||||
| #include <vector> | |||||
| namespace mindspore { | |||||
| namespace device { | |||||
| namespace ascend { | |||||
| namespace { | |||||
| using KernelInfo = device::KernelInfo; | |||||
| using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; | |||||
| using KernelBuildInfo = kernel::KernelBuildInfo; | |||||
| using KernelGraph = session::KernelGraph; | |||||
| using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>; | |||||
| using KernelBuilderPtr = std::shared_ptr<KernelBuildInfoBuilder>; | |||||
| using Shape = std::vector<size_t>; | |||||
| using ShapeList = std::vector<Shape>; | |||||
| enum MatchCountPriority { | |||||
| MATCH_COUNT_PRIORITY_BEGIN = 0, | |||||
| MATCH_FORMAT_COUNT = MATCH_COUNT_PRIORITY_BEGIN, | |||||
| MATCH_DTYPE_COUNT, | |||||
| MATCH_NZ_FORMAT_COUNT, | |||||
| MATCH_5D_FORMAT_COUNT, | |||||
| MATCH_OUTPUT_DTYPE_COUNT, | |||||
| MATCH_COUNT_PRIORITY_END | |||||
| }; | |||||
| const std::set<std::string> kOpFormatList = { | |||||
| kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, | |||||
| kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ}; | |||||
| bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) { | |||||
| // if format is default,it remarkes support all format | |||||
| if (kOpFormatList.find(format) == kOpFormatList.end()) { | |||||
| MS_EXCEPTION(ArgumentError) << "got the unknow format " << format; | |||||
| } | |||||
| if (format == kOpFormat_DEFAULT) { | |||||
| return true; | |||||
| } | |||||
| // if shape size is 0,the shape will be a scalar | |||||
| if (shape.empty()) { | |||||
| return true; | |||||
| } | |||||
| if (shape.size() > kShapeSupportFormatMap.size()) { | |||||
| return false; | |||||
| } | |||||
| if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) { | |||||
| return shape[shape.size() - 1] % 16 != 0 && shape[shape.size() - 2] % 16 != 0; | |||||
| } | |||||
| return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end()); | |||||
| } | |||||
| bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool { | |||||
| if (!IsShapeMatchFormat(shape, format)) { | |||||
| return false; | |||||
| } | |||||
| for (auto shape_value : shape) { | |||||
| if (shape_value == 0) { | |||||
| MS_EXCEPTION(ArgumentError) << "dimension size of the tensor shape should be a positive integer, but got [" | |||||
| << shape_value << "]"; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| }; | |||||
| for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); | |||||
| if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); | |||||
| if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| // Check input data type | |||||
| for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { | |||||
| AnfNodePtr cur_input = cnode->input(input_index + 1); | |||||
| MS_EXCEPTION_IF_NULL(cur_input); | |||||
| TypeId input_origin_type; | |||||
| if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) { | |||||
| // weight | |||||
| input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0); | |||||
| } else { | |||||
| // feature map | |||||
| input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | |||||
| } | |||||
| if (input_origin_type == kTypeUnknown) { | |||||
| continue; | |||||
| } | |||||
| if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| // Check output data type | |||||
| for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) { | |||||
| if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| /** | |||||
| * compare too vector by priority,select a better vector,like compare too num,first compare highest num location,if | |||||
| * equal then next num location | |||||
| * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3] | |||||
| */ | |||||
| bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) { | |||||
| MS_EXCEPTION_IF_NULL(best_item); | |||||
| if (cur_item.size() != best_item->size()) { | |||||
| MS_LOG(ERROR) << "item size should be same!"; | |||||
| return false; | |||||
| } | |||||
| // Update the best_item by comparing the cur_item and best_item | |||||
| for (size_t i = 0; i < cur_item.size(); i++) { | |||||
| if (cur_item[i] > best_item->at(i)) { | |||||
| *best_item = cur_item; | |||||
| return true; | |||||
| } else if (cur_item[i] == best_item->at(i)) { | |||||
| continue; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node, | |||||
| std::vector<int> *const cur_kernelinfo_match_counts) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts); | |||||
| if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) { | |||||
| MS_EXCEPTION(ArgumentError) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END; | |||||
| } | |||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||||
| AnfNodePtr input_anf_node = kernel_node->input(input_index + 1); | |||||
| MS_EXCEPTION_IF_NULL(input_anf_node); | |||||
| // if a input parameter is a weight with default format, the input shouldn't participate the judge | |||||
| if (input_anf_node->isa<Parameter>()) { | |||||
| auto para = input_anf_node->cast<ParameterPtr>(); | |||||
| if (AnfAlgo::IsParameterWeight(para) && AnfAlgo::GetOutputDeviceDataType(para, 0) == kTypeUnknown) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { | |||||
| (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++; | |||||
| } | |||||
| if (kernel_build_info.GetInputDeviceType(input_index) == | |||||
| AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) { | |||||
| (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT]++; | |||||
| } | |||||
| if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_FRAC_NZ) { | |||||
| (*cur_kernelinfo_match_counts)[MATCH_NZ_FORMAT_COUNT]++; | |||||
| } | |||||
| if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_NC1HWC0) { | |||||
| (*cur_kernelinfo_match_counts)[MATCH_5D_FORMAT_COUNT]++; | |||||
| } | |||||
| } | |||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||||
| // cal count of same output dtype between abstract and kernel info | |||||
| if (kernel_build_info.GetOutputDeviceType(output_index) == | |||||
| AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) { | |||||
| (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++; | |||||
| } | |||||
| } | |||||
| } | |||||
| void SetKernelBuildInfo(KernelBuilderPtr builder) { | |||||
| builder->SetFusionType(kernel::OPAQUE); | |||||
| builder->SetKernelType(AUTO_DIFF_KERNEL); | |||||
| builder->SetProcessor(kernel::AICORE); | |||||
| } | |||||
| void test_select(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list) { | |||||
| std::vector<int> most_match_counts = {-1, -1, -1, -1, -1}; | |||||
| int selected_index = -1; | |||||
| for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { | |||||
| std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0}; | |||||
| if (!IsValidKernelInfo(kernel_node, *(kernel_info_list[info_index]))) { | |||||
| continue; | |||||
| } | |||||
| if (!MatchInferOutputDataType(kernel_node, *(kernel_info_list[info_index]))) { | |||||
| continue; | |||||
| } | |||||
| std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index]; | |||||
| UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); | |||||
| // Currently the selection policy is the match format count first, and then is datatype counts. | |||||
| if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) { | |||||
| selected_index = SizeToInt(info_index); | |||||
| } | |||||
| } | |||||
| if (selected_index == -1) { | |||||
| MS_EXCEPTION(NotExistsError) << "" << kernel_node->DebugString() << " Cannot find valid kernel Info !"; | |||||
| } | |||||
| auto index = IntToSize(selected_index); | |||||
| if (index >= kernel_info_list.size()) { | |||||
| MS_EXCEPTION(ArgumentError) << "index outof range"; | |||||
| } | |||||
| std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info_ptr = kernel_info_list[index]; | |||||
| MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, kernel_node.get()); | |||||
| } | |||||
| void SetParentAbstract(std::vector<AnfNodePtr> parent_list, std::vector<std::vector<size_t>> shapes, | |||||
| std::vector<TypeId> types) { | |||||
| for (const auto &node : parent_list) { | |||||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, node.get()); | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| class AscendKernelSelctTest : public UT::Common { | |||||
| public: | |||||
| AscendKernelSelctTest() = default; | |||||
| void SetUp() override {} | |||||
| void TearDown() override {} | |||||
| }; | |||||
| TEST_F(AscendKernelSelctTest, TestSelect) { | |||||
| std::vector<KernelBuilderPtr> build_list; | |||||
| std::vector<TypeId> type_list = {kNumberTypeFloat32}; | |||||
| for (size_t i = 0; i <= 4; ++i) { | |||||
| build_list.push_back(std::make_shared<KernelBuildInfoBuilder>()); | |||||
| SetKernelBuildInfo(build_list[i]); | |||||
| build_list[i]->SetInputsDeviceType(type_list); | |||||
| build_list[i]->SetOutputsDeviceType(type_list); | |||||
| } | |||||
| std::vector<std::string> nd_fmt = {kOpFormat_DEFAULT}; | |||||
| std::vector<std::string> nz_fmt = {kOpFormat_FRAC_NZ}; | |||||
| auto anf_graph = std::make_shared<KernelGraph>(); | |||||
| // 16's multiple should not chose format NZ | |||||
| Shape nd_shapes = {2, 32, 224, 224}; | |||||
| Shape nz_shapes = {3, 3, 5, 5}; | |||||
| auto add_value = NewValueNode(prim::kPrimTensorAdd); | |||||
| auto a_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value}); | |||||
| auto b_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value}); | |||||
| std::vector<AnfNodePtr> parent_list = {add_value, a_node, b_node}; | |||||
| auto c_node = anf_graph->NewCNode(parent_list); | |||||
| // a b | |||||
| // \ / | |||||
| // c | |||||
| // a & b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}} | |||||
| // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} | |||||
| // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3,224, 224}} | |||||
| // set a & b's info | |||||
| SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list); | |||||
| // set abstract c | |||||
| AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nd_shapes}, c_node.get()); | |||||
| // set format of kernel info | |||||
| build_list[0]->SetOutputsFormat(nz_fmt); | |||||
| build_list[1]->SetOutputsFormat(nz_fmt); | |||||
| build_list[2]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nd_fmt[0]}); | |||||
| build_list[3]->SetInputsFormat(std::vector<std::string>{nz_fmt[0], nz_fmt[0]}); | |||||
| build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32}); | |||||
| build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32}); | |||||
| build_list[2]->SetOutputsFormat(nd_fmt); | |||||
| build_list[3]->SetOutputsFormat(nz_fmt); | |||||
| std::vector<KernelBuildInfoPtr> select_info_list; | |||||
| // set select info list | |||||
| select_info_list.emplace_back(build_list[2]->Build()); | |||||
| select_info_list.emplace_back(build_list[3]->Build()); | |||||
| // set device info for a & b | |||||
| AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get()); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get()); | |||||
| test_select(c_node, select_info_list); | |||||
| EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT); | |||||
| EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_DEFAULT); | |||||
| // set a & b's info | |||||
| // a b | |||||
| // \ / | |||||
| // c | |||||
| // a: kernel_info:{output_format:{5d},dtype:{kNumberTypeFloat32}} | |||||
| // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} | |||||
| // b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}} | |||||
| // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} | |||||
| // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}} | |||||
| // set a & b's info | |||||
| SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list); | |||||
| // set abstract c | |||||
| AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nz_shapes}, c_node.get()); | |||||
| // set format of kernel info | |||||
| build_list[0]->SetOutputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0}); | |||||
| build_list[1]->SetOutputsFormat(nz_fmt); | |||||
| build_list[2]->SetInputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0, nd_fmt[0]}); | |||||
| build_list[3]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nz_fmt[0]}); | |||||
| build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32}); | |||||
| build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32}); | |||||
| build_list[2]->SetOutputsFormat(nd_fmt); | |||||
| build_list[3]->SetOutputsFormat(nz_fmt); | |||||
| // set select info list | |||||
| select_info_list.emplace_back(build_list[2]->Build()); | |||||
| select_info_list.emplace_back(build_list[3]->Build()); | |||||
| // set device info for a & b | |||||
| AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get()); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get()); | |||||
| test_select(c_node, select_info_list); | |||||
| EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT); | |||||
| EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_FRAC_NZ); | |||||
| } | |||||
| } // namespace ascend | |||||
| } // namespace device | |||||
| } // namespace mindspore | |||||
| @@ -39,7 +39,7 @@ class MockSupportedChecker : public SupportedChecker { | |||||
| public: | public: | ||||
| MockSupportedChecker() = default; | MockSupportedChecker() = default; | ||||
| ~MockSupportedChecker() override = default; | ~MockSupportedChecker() override = default; | ||||
| bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||||
| bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||||
| return true; | return true; | ||||
| } | } | ||||
| }; // namespace opt | }; // namespace opt | ||||
| @@ -37,6 +37,15 @@ class TestHWTransposeTransdataFusion : public BackendCommon { | |||||
| UT::PyFuncGraphFetcher get_py_fun_; | UT::PyFuncGraphFetcher get_py_fun_; | ||||
| }; | }; | ||||
| class MockSupportedChecker : public SupportedChecker { | |||||
| public: | |||||
| MockSupportedChecker() = default; | |||||
| ~MockSupportedChecker() override = default; | |||||
| bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||||
| return true; | |||||
| } | |||||
| }; | |||||
| class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { | class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { | ||||
| public: | public: | ||||
| MockInsertTransOpKernelSelectTrans4Dto5D() = default; | MockInsertTransOpKernelSelectTrans4Dto5D() = default; | ||||
| @@ -60,37 +69,6 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { | |||||
| } | } | ||||
| }; | }; | ||||
| class MockTransposeTransdataFusionKernelSelect : public KernelSelect { | |||||
| public: | |||||
| MockTransposeTransdataFusionKernelSelect() = default; | |||||
| ~MockTransposeTransdataFusionKernelSelect() override = default; | |||||
| bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, | |||||
| const kernel::KernelBuildInfoPtr &new_kernel_build_info) override { | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||||
| builder.SetInputsFormat({kOpFormat_NCHW}); | |||||
| builder.SetOutputsFormat({kOpFormat_DEFAULT}); | |||||
| builder.SetInputsDeviceType({kNumberTypeFloat16}); | |||||
| builder.SetOutputsDeviceType({kNumberTypeFloat16}); | |||||
| builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); | |||||
| builder.SetFusionType(kernel::FusionType::OPAQUE); | |||||
| builder.SetProcessor(kernel::Processor::AICORE); | |||||
| kernel_info_list.push_back(builder.Build()); | |||||
| MS_LOG(INFO) << "transpose transdata fusion success"; | |||||
| MS_LOG(INFO) << "new transdata build info input format:" << new_kernel_build_info->GetInputFormat(0) | |||||
| << ",outputformat:" << new_kernel_build_info->GetOutputFormat(0) | |||||
| << ",kerneltype:" << new_kernel_build_info->kernel_type() | |||||
| << ",fusiontype:" << new_kernel_build_info->fusion_type() | |||||
| << ",process:" << new_kernel_build_info->processor(); | |||||
| auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(), | |||||
| [&new_kernel_build_info](kernel::KernelBuildInfoPtr item) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| return *item == *new_kernel_build_info; | |||||
| }); | |||||
| return result != kernel_info_list.end(); | |||||
| } | |||||
| }; | |||||
| TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { | TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { | ||||
| /* | /* | ||||
| * def before(input0, input1): | * def before(input0, input1): | ||||
| @@ -128,7 +106,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { | |||||
| insert_trans_op_pass->kernel_select_ = std::make_shared<MockInsertTransOpKernelSelectTrans4Dto5D>(); | insert_trans_op_pass->kernel_select_ = std::make_shared<MockInsertTransOpKernelSelectTrans4Dto5D>(); | ||||
| pm->AddPass(insert_trans_op_pass); | pm->AddPass(insert_trans_op_pass); | ||||
| auto transpose_transdata_pass = std::make_shared<opt::TransposeTransDataFusion>(); | auto transpose_transdata_pass = std::make_shared<opt::TransposeTransDataFusion>(); | ||||
| transpose_transdata_pass->kernel_select_ = std::make_shared<MockTransposeTransdataFusionKernelSelect>(); | |||||
| transpose_transdata_pass->supported_checker_ = std::make_shared<MockSupportedChecker>(); | |||||
| pm->AddPass(transpose_transdata_pass); | pm->AddPass(transpose_transdata_pass); | ||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | FuncGraphPtr new_graph = optimizer->Optimize(kg); | ||||