/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "mindspore/ccsrc/device/ascend/kernel_select_ascend.h"
#include "common/common_test.h"
#include "session/kernel_graph.h"
#include "kernel/kernel.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "operator/ops.h"
#include "mindspore/ccsrc/device/kernel_info.h"
#include "mindspore/ccsrc/kernel/kernel_build_info.h"
// NOTE(review): the original angle-bracket include was lost in extraction; these are the
// standard headers this file uses directly.
#include <memory>
#include <set>
#include <string>
#include <vector>

namespace mindspore {
namespace device {
namespace ascend {
namespace {
// NOTE(review): template arguments throughout this file were reconstructed from how each
// name is used; confirm against the upstream file.
using KernelInfo = device::KernelInfo;
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
using KernelBuildInfo = kernel::KernelBuildInfo;
using KernelGraph = session::KernelGraph;
using KernelBuildInfoPtr = std::shared_ptr<kernel::KernelBuildInfo>;
using KernelBuilderPtr = std::shared_ptr<KernelBuildInfoBuilder>;
using Shape = std::vector<size_t>;
using ShapeList = std::vector<Shape>;

// Slots of the match-count vector, ordered from the highest comparison priority to the lowest.
enum MatchCountPriority {
  MATCH_COUNT_PRIORITY_BEGIN = 0,
  MATCH_FORMAT_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
  MATCH_DTYPE_COUNT,
  MATCH_NZ_FORMAT_COUNT,
  MATCH_5D_FORMAT_COUNT,
  MATCH_OUTPUT_DTYPE_COUNT,
  MATCH_COUNT_PRIORITY_END
};

// Formats accepted by IsShapeMatchFormat; any other format raises ArgumentError there.
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
                                             kOpFormat_NCHW,    kOpFormat_NHWC,        kOpFormat_HWCN,
                                             kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,      kOpFormat_C1HWNCoC0,
                                             kOpFormat_FRAC_NZ};

/**
 * Check whether a tensor shape may be used with the given format.
 *
 * Raises ArgumentError when `format` is not in kOpFormatList. The default format and an
 * empty (scalar) shape are always accepted. For FRAC_NZ with rank >= 2 the shape is accepted
 * only when neither of the two innermost dimensions is a multiple of 16. Otherwise the answer
 * is looked up in kShapeSupportFormatMap by rank.
 */
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  if (kOpFormatList.find(format) == kOpFormatList.end()) {
    MS_EXCEPTION(ArgumentError) << "got the unknow format " << format;
  }
  // The default format supports every shape.
  if (format == kOpFormat_DEFAULT) {
    return true;
  }
  // An empty shape denotes a scalar.
  if (shape.empty()) {
    return true;
  }
  if (shape.size() > kShapeSupportFormatMap.size()) {
    return false;
  }
  if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
    return shape[shape.size() - 1] % 16 != 0 && shape[shape.size() - 2] % 16 != 0;
  }
  const auto &supported_formats = kShapeSupportFormatMap[shape.size() - 1];
  return supported_formats.find(format) != supported_formats.end();
}

/**
 * Check that every output and input format of `kernel_build_info` is compatible with the
 * corresponding inferred shape of `kernel_node`.
 *
 * Raises ArgumentError when a shape that passed the format check contains a zero dimension.
 */
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
    if (!IsShapeMatchFormat(shape, format)) {
      return false;
    }
    for (auto shape_value : shape) {
      if (shape_value == 0) {
        MS_EXCEPTION(ArgumentError) << "dimension size of the tensor shape should be a positive integer, but got ["
                                    << shape_value << "]";
      }
    }
    return true;
  };
  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
    if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
      return false;
    }
  }
  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
    if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
      return false;
    }
  }
  return true;
}

/**
 * Check that the device data types in `kernel_build_info` equal the data types known for
 * `cnode`: every input whose origin type is not kTypeUnknown, and every output.
 */
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  // Check input data type
  for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
    // Input 0 of the cnode is skipped, hence the +1 offset.
    AnfNodePtr cur_input = cnode->input(input_index + 1);
    MS_EXCEPTION_IF_NULL(cur_input);
    TypeId input_origin_type;
    if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) {
      // weight
      input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
    } else {
      // feature map
      input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
    }
    if (input_origin_type == kTypeUnknown) {
      continue;
    }
    if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
      return false;
    }
  }
  // Check output data type
  for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
    if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
      return false;
    }
  }
  return true;
}

/**
 * Compare two vectors by priority and keep the better one, the way two numbers are compared:
 * start at the highest-priority position and move to the next position only on a tie.
 * example: [3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
 *
 * Returns true and overwrites `*best_item` when `cur_item` is strictly better. Returns false
 * when it is worse, equal, or of a different size (the size mismatch is also logged).
 */
bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
  MS_EXCEPTION_IF_NULL(best_item);
  if (cur_item.size() != best_item->size()) {
    MS_LOG(ERROR) << "item size should be same!";
    return false;
  }
  // The first position that differs decides the comparison.
  for (size_t i = 0; i < cur_item.size(); i++) {
    if (cur_item[i] > best_item->at(i)) {
      *best_item = cur_item;
      return true;
    }
    if (cur_item[i] < best_item->at(i)) {
      return false;
    }
  }
  return false;
}

/**
 * Accumulate, into `cur_kernelinfo_match_counts`, how well `kernel_build_info` matches
 * `kernel_node`. The vector is indexed by MatchCountPriority and must hold at least
 * MATCH_COUNT_PRIORITY_END entries, otherwise ArgumentError is raised.
 */
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
                          std::vector<int> *const cur_kernelinfo_match_counts) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
  if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
    MS_EXCEPTION(ArgumentError) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
  }
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
    AnfNodePtr input_anf_node = kernel_node->input(input_index + 1);
    MS_EXCEPTION_IF_NULL(input_anf_node);
    // A weight parameter whose device data type is still unknown does not take part in the scoring.
    if (input_anf_node->isa<Parameter>()) {
      auto para = input_anf_node->cast<ParameterPtr>();
      if (AnfAlgo::IsParameterWeight(para) && AnfAlgo::GetOutputDeviceDataType(para, 0) == kTypeUnknown) {
        continue;
      }
    }
    if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
    }
    if (kernel_build_info.GetInputDeviceType(input_index) ==
        AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT]++;
    }
    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_FRAC_NZ) {
      (*cur_kernelinfo_match_counts)[MATCH_NZ_FORMAT_COUNT]++;
    }
    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_NC1HWC0) {
      (*cur_kernelinfo_match_counts)[MATCH_5D_FORMAT_COUNT]++;
    }
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
    // Count outputs whose device data type equals the inferred (abstract) data type.
    if (kernel_build_info.GetOutputDeviceType(output_index) ==
        AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
    }
  }
}

// Set the fusion type, kernel type and processor shared by every builder in this test.
void SetKernelBuildInfo(KernelBuilderPtr builder) {
  MS_EXCEPTION_IF_NULL(builder);
  builder->SetFusionType(kernel::OPAQUE);
  builder->SetKernelType(AUTO_DIFF_KERNEL);
  builder->SetProcessor(kernel::AICORE);
}

/**
 * Pick the best candidate from `kernel_info_list` for `kernel_node` and attach it to the node.
 *
 * Candidates failing IsValidKernelInfo or MatchInferOutputDataType are skipped. The rest are
 * ranked by their match counts through PriorityChooseItem; on a tie the earlier candidate wins.
 * Raises NotExistsError when no candidate qualifies.
 */
void test_select(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list) {
  // Guard up front: the error path below dereferences kernel_node even when the list is empty.
  MS_EXCEPTION_IF_NULL(kernel_node);
  std::vector<int> most_match_counts = {-1, -1, -1, -1, -1};
  int selected_index = -1;
  for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
    std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
    std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
    // Check before the first dereference instead of only after selection.
    MS_EXCEPTION_IF_NULL(kernel_info_ptr);
    if (!IsValidKernelInfo(kernel_node, *kernel_info_ptr)) {
      continue;
    }
    if (!MatchInferOutputDataType(kernel_node, *kernel_info_ptr)) {
      continue;
    }
    UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
    // Currently the selection policy is the match format count first, and then the datatype counts.
    if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
      selected_index = SizeToInt(info_index);
    }
  }
  if (selected_index == -1) {
    MS_EXCEPTION(NotExistsError) << "" << kernel_node->DebugString() << " Cannot find valid kernel Info !";
  }
  auto index = IntToSize(selected_index);
  if (index >= kernel_info_list.size()) {
    MS_EXCEPTION(ArgumentError) << "index outof range";
  }
  std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info_ptr = kernel_info_list[index];
  MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr);
  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, kernel_node.get());
}

// Apply the same inferred types and shapes to every node in `parent_list`.
void SetParentAbstract(std::vector<AnfNodePtr> parent_list, std::vector<std::vector<size_t>> shapes,
                       std::vector<TypeId> types) {
  for (const auto &node : parent_list) {
    AnfAlgo::SetOutputInferTypeAndShape(types, shapes, node.get());
  }
}
}  // namespace

class AscendKernelSelctTest : public UT::Common {
 public:
  AscendKernelSelctTest() = default;
  void SetUp() override {}
  void TearDown() override {}
};

TEST_F(AscendKernelSelctTest, TestSelect) {
  std::vector<KernelBuilderPtr> build_list;
  std::vector<TypeId> type_list = {kNumberTypeFloat32};
  for (size_t i = 0; i <= 4; ++i) {
    build_list.push_back(std::make_shared<KernelBuildInfoBuilder>());
    SetKernelBuildInfo(build_list[i]);
    build_list[i]->SetInputsDeviceType(type_list);
    build_list[i]->SetOutputsDeviceType(type_list);
  }
  std::vector<std::string> nd_fmt = {kOpFormat_DEFAULT};
  std::vector<std::string> nz_fmt = {kOpFormat_FRAC_NZ};
  auto anf_graph = std::make_shared<KernelGraph>();

  // A shape whose last two dimensions are multiples of 16 should not choose format NZ.
  Shape nd_shapes = {2, 32, 224, 224};
  Shape nz_shapes = {3, 3, 5, 5};
  auto add_value = NewValueNode(prim::kPrimTensorAdd);
  auto a_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value});
  auto b_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value});
  std::vector<AnfNodePtr> parent_list = {add_value, a_node, b_node};

  auto c_node = anf_graph->NewCNode(parent_list);

  //  a   b
  //   \ /
  //    c
  // a & b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}}
  //        infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{2, 32, 224, 224}}

  // set a & b's info
  SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list);
  // set abstract c
  AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nd_shapes}, c_node.get());
  // set format of kernel info
  build_list[0]->SetOutputsFormat(nz_fmt);
  build_list[1]->SetOutputsFormat(nz_fmt);
  build_list[2]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nd_fmt[0]});
  build_list[3]->SetInputsFormat(std::vector<std::string>{nz_fmt[0], nz_fmt[0]});
  build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  build_list[2]->SetOutputsFormat(nd_fmt);
  build_list[3]->SetOutputsFormat(nz_fmt);
  std::vector<KernelBuildInfoPtr> select_info_list;
  // set select info list
  select_info_list.emplace_back(build_list[2]->Build());
  select_info_list.emplace_back(build_list[3]->Build());

  // set device info for a & b
  AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get());
  AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get());

  test_select(c_node, select_info_list);
  EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT);
  EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_DEFAULT);

  // set a & b's info
  //  a   b
  //   \ /
  //    c
  // a: kernel_info:{output_format:{5d},dtype:{kNumberTypeFloat32}}
  //    infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  // b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}}
  //    infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}

  // set a & b's info
  SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list);
  // set abstract c
  AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nz_shapes}, c_node.get());
  // set format of kernel info
  build_list[0]->SetOutputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0});
  build_list[1]->SetOutputsFormat(nz_fmt);
  build_list[2]->SetInputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0, nd_fmt[0]});
  build_list[3]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nz_fmt[0]});
  build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  build_list[2]->SetOutputsFormat(nd_fmt);
  build_list[3]->SetOutputsFormat(nz_fmt);
  // Drop the candidates of the first scenario so that only this scenario's two candidates
  // take part in the selection.
  select_info_list.clear();
  // set select info list
  select_info_list.emplace_back(build_list[2]->Build());
  select_info_list.emplace_back(build_list[3]->Build());

  // set device info for a & b
  AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get());
  AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get());

  test_select(c_node, select_info_list);
  EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT);
  EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_FRAC_NZ);
}
}  // namespace ascend
}  // namespace device
}  // namespace mindspore