|
|
|
@@ -25,6 +25,18 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace device { |
|
|
|
namespace ascend { |
|
|
|
namespace { |
|
|
|
// sort format according the number of occurrences. |
|
|
|
bool cmp_format_num(const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) { |
|
|
|
if (a.second != b.second) { |
|
|
|
return a.second > b.second; |
|
|
|
} else if (a.first == kOpFormat_DEFAULT) { |
|
|
|
return a.second + 1 > b.second; |
|
|
|
} else if (b.first == kOpFormat_DEFAULT) { |
|
|
|
return a.second > b.second + 1; |
|
|
|
} |
|
|
|
return a.second > b.second; |
|
|
|
} |
|
|
|
|
|
|
|
TypeId GetPrimitivePrecision(const CNodePtr &cnode) { |
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(cnode); |
|
|
|
@@ -44,6 +56,7 @@ TypeId GetPrimitivePrecision(const CNodePtr &cnode) { |
|
|
|
|
|
|
|
return except_type; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void ResetKernelBuildInfo(const CNodePtr &kernel_node) { |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
@@ -185,15 +198,12 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, |
|
|
|
auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; |
|
|
|
MS_EXCEPTION_IF_NULL(input_kernel_node); |
|
|
|
if (!input_kernel_node->isa<Parameter>()) { |
|
|
|
auto pre_format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i); |
|
|
|
++all_input_formats[pre_format]; |
|
|
|
++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)]; |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto para = input_kernel_node->cast<ParameterPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(para); |
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { |
|
|
|
auto pre_format = AnfAlgo::GetOutputFormat(para, 0); |
|
|
|
++all_input_formats[pre_format]; |
|
|
|
++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)]; |
|
|
|
continue; |
|
|
|
} |
|
|
|
*use_same_format = false; |
|
|
|
@@ -207,17 +217,8 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, |
|
|
|
for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) { |
|
|
|
pairs.push_back(std::make_pair(iter->first, iter->second)); |
|
|
|
} |
|
|
|
auto cmp_func = [](const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) { |
|
|
|
if (a.second != b.second) { |
|
|
|
return a.second > b.second; |
|
|
|
} else if (a.first == kOpFormat_DEFAULT) { |
|
|
|
return a.second + 1 > b.second; |
|
|
|
} else if (b.first == kOpFormat_DEFAULT) { |
|
|
|
return a.second > b.second + 1; |
|
|
|
} |
|
|
|
return a.second > b.second; |
|
|
|
}; |
|
|
|
std::sort(pairs.begin(), pairs.end(), cmp_func); |
|
|
|
|
|
|
|
std::sort(pairs.begin(), pairs.end(), cmp_format_num); |
|
|
|
*default_format = pairs.begin()->first; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -237,10 +238,9 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void UpdateGraphKernelInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list, |
|
|
|
const std::string &default_format, bool use_same_format, |
|
|
|
std::vector<std::string> *graph_input_format, |
|
|
|
std::vector<TypeId> *graph_input_type) { |
|
|
|
void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list, |
|
|
|
const std::string &default_format, bool use_same_format, |
|
|
|
std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_format); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_type); |
|
|
|
// We set same format to all inputs of graph kernel subgraph, and process this latter. |
|
|
|
@@ -338,21 +338,22 @@ void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_ |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list, |
|
|
|
const std::vector<AnfNodePtr> &input_list, const FuncGraphManagerPtr &mng, |
|
|
|
const std::string &default_format, std::vector<std::string> *graph_input_format, |
|
|
|
std::vector<TypeId> *graph_input_type) { |
|
|
|
void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list, |
|
|
|
const FuncGraphManagerPtr &mng, const std::string &default_format, |
|
|
|
std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type, |
|
|
|
std::vector<bool> *need_update) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
MS_EXCEPTION_IF_NULL(mng); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_format); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_type); |
|
|
|
// update graph input format and dtype use inner ops. |
|
|
|
MS_EXCEPTION_IF_NULL(need_update); |
|
|
|
// check graph input format and dtype use inner ops. |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
if (graph_input_format->size() != input_num) { |
|
|
|
if (graph_input_format->size() != input_num || graph_input_type->size() != input_num || |
|
|
|
need_update->size() != input_num) { |
|
|
|
MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() |
|
|
|
<< "], [%" << graph_input_format->size() << "] != [%" << input_num << "]"; |
|
|
|
<< "], [" << graph_input_format->size() << "] != [" << input_num << "]"; |
|
|
|
} |
|
|
|
std::vector<bool> need_update(input_num, false); |
|
|
|
auto &node_users = mng->node_users(); |
|
|
|
for (size_t i = 0; i < input_num; ++i) { |
|
|
|
auto &input = input_list[i]; |
|
|
|
@@ -372,36 +373,48 @@ void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNo |
|
|
|
<< kernel_node->DebugString() |
|
|
|
<< "] selected different format. we use defult: " << default_format; |
|
|
|
(*graph_input_format)[i] = default_format; |
|
|
|
need_update[i] = true; |
|
|
|
(*need_update)[i] = true; |
|
|
|
} |
|
|
|
|
|
|
|
if (kernel_node->input(i + 1)->isa<Parameter>()) { |
|
|
|
auto user_dtype = AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)); |
|
|
|
if (user_dtype != (*graph_input_type)[i]) { |
|
|
|
TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0); |
|
|
|
MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" |
|
|
|
<< kernel_node->DebugString() |
|
|
|
<< "] selected different dtype. we use default: " << TypeIdLabel(default_dtype); |
|
|
|
(*graph_input_type)[i] = default_dtype; |
|
|
|
need_update[i] = true; |
|
|
|
} |
|
|
|
if (kernel_node->input(i + 1)->isa<Parameter>() || |
|
|
|
AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0); |
|
|
|
MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" |
|
|
|
<< kernel_node->DebugString() |
|
|
|
<< "] selected different dtype. we use default: " << TypeIdLabel(default_dtype); |
|
|
|
(*graph_input_type)[i] = default_dtype; |
|
|
|
(*need_update)[i] = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list, |
|
|
|
const std::vector<AnfNodePtr> &input_list, const std::vector<bool> &need_update, |
|
|
|
const std::vector<std::string> &graph_input_format, |
|
|
|
const std::vector<TypeId> &graph_input_type) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
// update graph input format and dtype use inner ops. |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
if (graph_input_format.size() != input_num || graph_input_type.size() != input_num || |
|
|
|
need_update.size() != input_num) { |
|
|
|
MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() |
|
|
|
<< "], [" << graph_input_format.size() << "] != [" << input_num << "]"; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < input_num; ++i) { |
|
|
|
if (!need_update[i]) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
need_update[i] = false; |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString() |
|
|
|
<< "] to: " << (*graph_input_format)[i]; |
|
|
|
<< "] to: " << graph_input_format[i]; |
|
|
|
MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString() |
|
|
|
<< "] to: " << TypeIdLabel((*graph_input_type)[i]); |
|
|
|
<< "] to: " << TypeIdLabel(graph_input_type[i]); |
|
|
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; |
|
|
|
std::vector<std::string> outputs_format = {(*graph_input_format)[i]}; |
|
|
|
std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]}; |
|
|
|
std::vector<std::string> outputs_format = {graph_input_format[i]}; |
|
|
|
std::vector<TypeId> outputs_device_type = {graph_input_type[i]}; |
|
|
|
builder.SetOutputsFormat(outputs_format); |
|
|
|
builder.SetOutputsDeviceType(outputs_device_type); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get()); |
|
|
|
@@ -487,8 +500,8 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func |
|
|
|
|
|
|
|
std::vector<std::string> graph_input_format; |
|
|
|
std::vector<TypeId> graph_input_type; |
|
|
|
UpdateGraphKernelInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, |
|
|
|
&graph_input_type); |
|
|
|
UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, |
|
|
|
&graph_input_type); |
|
|
|
|
|
|
|
auto mng = func_graph->manager(); |
|
|
|
if (mng == nullptr) { |
|
|
|
@@ -502,8 +515,10 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func |
|
|
|
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); |
|
|
|
|
|
|
|
// update graph input format and dtype use inner ops. |
|
|
|
UpdateFormatsAndDtypes(kernel_node, node_list, input_list, mng, default_format, &graph_input_format, |
|
|
|
&graph_input_type); |
|
|
|
std::vector<bool> need_update(AnfAlgo::GetInputTensorNum(kernel_node), false); |
|
|
|
CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type, |
|
|
|
&need_update); |
|
|
|
UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type); |
|
|
|
|
|
|
|
// set fix_precision for kernel when the me prim has fix_precision attr |
|
|
|
UpdateKernelInfo(node_list); |
|
|
|
|