| @@ -25,6 +25,18 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| namespace { | |||
| // sort format according the number of occurrences. | |||
| bool cmp_format_num(const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) { | |||
| if (a.second != b.second) { | |||
| return a.second > b.second; | |||
| } else if (a.first == kOpFormat_DEFAULT) { | |||
| return a.second + 1 > b.second; | |||
| } else if (b.first == kOpFormat_DEFAULT) { | |||
| return a.second > b.second + 1; | |||
| } | |||
| return a.second > b.second; | |||
| } | |||
| TypeId GetPrimitivePrecision(const CNodePtr &cnode) { | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode); | |||
| @@ -44,6 +56,7 @@ TypeId GetPrimitivePrecision(const CNodePtr &cnode) { | |||
| return except_type; | |||
| } | |||
| } // namespace | |||
| void ResetKernelBuildInfo(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| @@ -185,15 +198,12 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, | |||
| auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; | |||
| MS_EXCEPTION_IF_NULL(input_kernel_node); | |||
| if (!input_kernel_node->isa<Parameter>()) { | |||
| auto pre_format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i); | |||
| ++all_input_formats[pre_format]; | |||
| ++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)]; | |||
| continue; | |||
| } | |||
| auto para = input_kernel_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(para); | |||
| if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { | |||
| auto pre_format = AnfAlgo::GetOutputFormat(para, 0); | |||
| ++all_input_formats[pre_format]; | |||
| ++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)]; | |||
| continue; | |||
| } | |||
| *use_same_format = false; | |||
| @@ -207,17 +217,8 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, | |||
| for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) { | |||
| pairs.push_back(std::make_pair(iter->first, iter->second)); | |||
| } | |||
| auto cmp_func = [](const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) { | |||
| if (a.second != b.second) { | |||
| return a.second > b.second; | |||
| } else if (a.first == kOpFormat_DEFAULT) { | |||
| return a.second + 1 > b.second; | |||
| } else if (b.first == kOpFormat_DEFAULT) { | |||
| return a.second > b.second + 1; | |||
| } | |||
| return a.second > b.second; | |||
| }; | |||
| std::sort(pairs.begin(), pairs.end(), cmp_func); | |||
| std::sort(pairs.begin(), pairs.end(), cmp_format_num); | |||
| *default_format = pairs.begin()->first; | |||
| } | |||
| @@ -237,10 +238,9 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, | |||
| } | |||
| } | |||
| void UpdateGraphKernelInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list, | |||
| const std::string &default_format, bool use_same_format, | |||
| std::vector<std::string> *graph_input_format, | |||
| std::vector<TypeId> *graph_input_type) { | |||
| void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list, | |||
| const std::string &default_format, bool use_same_format, | |||
| std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type) { | |||
| MS_EXCEPTION_IF_NULL(graph_input_format); | |||
| MS_EXCEPTION_IF_NULL(graph_input_type); | |||
| // We set same format to all inputs of graph kernel subgraph, and process this latter. | |||
| @@ -338,21 +338,22 @@ void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_ | |||
| } | |||
| } | |||
| void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list, | |||
| const std::vector<AnfNodePtr> &input_list, const FuncGraphManagerPtr &mng, | |||
| const std::string &default_format, std::vector<std::string> *graph_input_format, | |||
| std::vector<TypeId> *graph_input_type) { | |||
| void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list, | |||
| const FuncGraphManagerPtr &mng, const std::string &default_format, | |||
| std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type, | |||
| std::vector<bool> *need_update) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| MS_EXCEPTION_IF_NULL(graph_input_format); | |||
| MS_EXCEPTION_IF_NULL(graph_input_type); | |||
| // update graph input format and dtype use inner ops. | |||
| MS_EXCEPTION_IF_NULL(need_update); | |||
| // check graph input format and dtype use inner ops. | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (graph_input_format->size() != input_num) { | |||
| if (graph_input_format->size() != input_num || graph_input_type->size() != input_num || | |||
| need_update->size() != input_num) { | |||
| MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() | |||
| << "], [%" << graph_input_format->size() << "] != [%" << input_num << "]"; | |||
| << "], [" << graph_input_format->size() << "] != [" << input_num << "]"; | |||
| } | |||
| std::vector<bool> need_update(input_num, false); | |||
| auto &node_users = mng->node_users(); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto &input = input_list[i]; | |||
| @@ -372,36 +373,48 @@ void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNo | |||
| << kernel_node->DebugString() | |||
| << "] selected different format. we use defult: " << default_format; | |||
| (*graph_input_format)[i] = default_format; | |||
| need_update[i] = true; | |||
| (*need_update)[i] = true; | |||
| } | |||
| if (kernel_node->input(i + 1)->isa<Parameter>()) { | |||
| auto user_dtype = AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)); | |||
| if (user_dtype != (*graph_input_type)[i]) { | |||
| TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0); | |||
| MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" | |||
| << kernel_node->DebugString() | |||
| << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype); | |||
| (*graph_input_type)[i] = default_dtype; | |||
| need_update[i] = true; | |||
| } | |||
| if (kernel_node->input(i + 1)->isa<Parameter>() || | |||
| AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) { | |||
| continue; | |||
| } | |||
| TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0); | |||
| MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" | |||
| << kernel_node->DebugString() | |||
| << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype); | |||
| (*graph_input_type)[i] = default_dtype; | |||
| (*need_update)[i] = true; | |||
| } | |||
| } | |||
| } | |||
| void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list, | |||
| const std::vector<AnfNodePtr> &input_list, const std::vector<bool> &need_update, | |||
| const std::vector<std::string> &graph_input_format, | |||
| const std::vector<TypeId> &graph_input_type) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| // update graph input format and dtype use inner ops. | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (graph_input_format.size() != input_num || graph_input_type.size() != input_num || | |||
| need_update.size() != input_num) { | |||
| MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() | |||
| << "], [" << graph_input_format.size() << "] != [" << input_num << "]"; | |||
| } | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| if (!need_update[i]) { | |||
| continue; | |||
| } | |||
| need_update[i] = false; | |||
| MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString() | |||
| << "] to: " << (*graph_input_format)[i]; | |||
| << "] to: " << graph_input_format[i]; | |||
| MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString() | |||
| << "] to: " << TypeIdLabel((*graph_input_type)[i]); | |||
| << "] to: " << TypeIdLabel(graph_input_type[i]); | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| std::vector<std::string> outputs_format = {(*graph_input_format)[i]}; | |||
| std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]}; | |||
| std::vector<std::string> outputs_format = {graph_input_format[i]}; | |||
| std::vector<TypeId> outputs_device_type = {graph_input_type[i]}; | |||
| builder.SetOutputsFormat(outputs_format); | |||
| builder.SetOutputsDeviceType(outputs_device_type); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get()); | |||
| @@ -487,8 +500,8 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func | |||
| std::vector<std::string> graph_input_format; | |||
| std::vector<TypeId> graph_input_type; | |||
| UpdateGraphKernelInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, | |||
| &graph_input_type); | |||
| UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, | |||
| &graph_input_type); | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| @@ -502,8 +515,10 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func | |||
| kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); | |||
| // update graph input format and dtype use inner ops. | |||
| UpdateFormatsAndDtypes(kernel_node, node_list, input_list, mng, default_format, &graph_input_format, | |||
| &graph_input_type); | |||
| std::vector<bool> need_update(AnfAlgo::GetInputTensorNum(kernel_node), false); | |||
| CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type, | |||
| &need_update); | |||
| UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type); | |||
| // set fix_precision for kernel when the me prim has fix_precision attr | |||
| UpdateKernelInfo(node_list); | |||