|
|
|
@@ -26,6 +26,7 @@ |
|
|
|
#include "utils/utils.h" |
|
|
|
#include "kernel/common_utils.h" |
|
|
|
#include "utils/context/ms_context.h" |
|
|
|
#include "pre_activate/common/helper.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
@@ -50,16 +51,11 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
std::vector<CNodePtr> do_mask_node_list; |
|
|
|
auto manager = graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto node_map = manager->node_users(); |
|
|
|
auto iter = node_map.find(node); |
|
|
|
if (iter == node_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Cannot find the node " << node->DebugString() << " in the graph manager!"; |
|
|
|
} |
|
|
|
auto gen_mask_output_nodes = iter->second; |
|
|
|
for (const auto &output_node : gen_mask_output_nodes) { |
|
|
|
auto gen_mask_output_nodes = GetRealNodeUsedList(graph, cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(gen_mask_output_nodes); |
|
|
|
for (const auto &output_node : *gen_mask_output_nodes) { |
|
|
|
if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) { |
|
|
|
MS_EXCEPTION_IF_NULL(output_node.first); |
|
|
|
auto output_cnode = output_node.first->cast<CNodePtr>(); |
|
|
|
do_mask_node_list.push_back(output_cnode); |
|
|
|
} |
|
|
|
@@ -76,11 +72,12 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con |
|
|
|
<< " GenMask " << node->DebugString(); |
|
|
|
} |
|
|
|
} |
|
|
|
RectifyKernelInfo(do_mask_node_list); |
|
|
|
RectifyKernelInfo(do_mask_node_list, graph); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const { |
|
|
|
void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, |
|
|
|
const FuncGraphPtr &graph) const { |
|
|
|
std::map<std::string, size_t> format_counter; |
|
|
|
std::string special_format; |
|
|
|
std::string convert_format; |
|
|
|
@@ -94,17 +91,6 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_ |
|
|
|
} else { |
|
|
|
format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1; |
|
|
|
} |
|
|
|
// if has two or more special format we need change all domask's format to default that can avoid insert more |
|
|
|
// transdata |
|
|
|
if (format_counter.size() > 2) { |
|
|
|
convert_format = kOpFormat_DEFAULT; |
|
|
|
break; |
|
|
|
} |
|
|
|
if (kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end() && |
|
|
|
special_format != do_mask_data_format) { |
|
|
|
convert_format = kOpFormat_DEFAULT; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (format_counter.size() == 1) { |
|
|
|
return; |
|
|
|
@@ -112,17 +98,23 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_ |
|
|
|
if (convert_format.empty()) { |
|
|
|
convert_format = GetConvertFormat(format_counter); |
|
|
|
} |
|
|
|
RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format); |
|
|
|
RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format, graph); |
|
|
|
} |
|
|
|
|
|
|
|
std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const { |
|
|
|
std::string convert_format; |
|
|
|
const size_t counter = 0; |
|
|
|
std::string convert_format = kOpFormat_DEFAULT; |
|
|
|
size_t counter = 0; |
|
|
|
if (format_counter.size() > 2) { |
|
|
|
return kOpFormat_DEFAULT; |
|
|
|
} |
|
|
|
if (format_counter.size() == 2 && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) { |
|
|
|
return kOpFormat_DEFAULT; |
|
|
|
} |
|
|
|
for (const auto &iter : format_counter) { |
|
|
|
if (counter < iter.second) { |
|
|
|
convert_format = iter.first; |
|
|
|
} |
|
|
|
if (counter == iter.second && kHWSpecialFormatSet.find(convert_format) == kHWSpecialFormatSet.end()) { |
|
|
|
counter = iter.second; |
|
|
|
} else if (counter == iter.second && kHWSpecialFormatSet.find(iter.first) != kHWSpecialFormatSet.end()) { |
|
|
|
convert_format = iter.first; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -130,13 +122,17 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string |
|
|
|
} |
|
|
|
|
|
|
|
void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, |
|
|
|
const std::string &format) const { |
|
|
|
const std::string &format, |
|
|
|
const FuncGraphPtr &graph) const { |
|
|
|
for (const auto &do_mask : do_mask_node_list) { |
|
|
|
auto builder = |
|
|
|
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); |
|
|
|
builder->SetInputFormat(format, 0); |
|
|
|
builder->SetOutputFormat(format, 0); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); |
|
|
|
if (AnfAlgo::GetInputFormat(do_mask, 0) != format) { |
|
|
|
auto builder = |
|
|
|
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); |
|
|
|
builder->SetInputFormat(format, 0); |
|
|
|
builder->SetOutputFormat(format, 0); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); |
|
|
|
ReSelecChildNodeKernelInfo(do_mask, graph); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -159,5 +155,30 @@ AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const Anf |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto output_node_list = GetRealNodeUsedList(graph, cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(output_node_list); |
|
|
|
for (const auto &out_node_info : *output_node_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(out_node_info.first); |
|
|
|
auto out_node = out_node_info.first->cast<CNodePtr>(); |
|
|
|
if (AnfAlgo::IsRealKernel(out_node_info.first)) { |
|
|
|
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node); |
|
|
|
kernel_selecter->SelectKernel(out_node); |
|
|
|
auto new_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node); |
|
|
|
MS_EXCEPTION_IF_NULL(new_build_info); |
|
|
|
MS_EXCEPTION_IF_NULL(ori_build_info); |
|
|
|
if ((*new_build_info) != (*ori_build_info)) { |
|
|
|
ReSelecChildNodeKernelInfo(out_node, graph); |
|
|
|
} |
|
|
|
} else if (AnfAlgo::GetCNodeName(out_node) == prim::kPrimTupleGetItem->name() || |
|
|
|
AnfAlgo::GetCNodeName(out_node) == prim::kPrimDepend->name()) { |
|
|
|
ReSelecChildNodeKernelInfo(out_node, graph); |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "Reselected the node " << cnode->DebugString() << " failed"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |