@@ -17,7 +17,7 @@
 #include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
 
 #include <string>
-#include <unordered_set>
+#include <set>
 #include <utility>
 #include <vector>
@@ -37,8 +37,8 @@
 namespace mindspore {
 namespace opt {
 namespace {
-std::unordered_set<PrimitivePtr> GetExpandOps() {
-  std::unordered_set<PrimitivePtr> expand_ops = {
+std::vector<PrimitivePtr> GetExpandOps() {
+  std::vector<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
     prim::kPrimGeLUGrad,
 #if ENABLE_D
@@ -55,7 +55,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimReduceMean,
     prim::kPrimMaximumGrad,
     prim::kPrimMinimumGrad,
-    prim::kPrimGkDropout,
+    prim::kPrimDropout,
     prim::kPrimDropoutGrad,
     prim::kPrimSoftmax,
     prim::kPrimLayerNorm,
@@ -68,41 +68,20 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimAssignAdd,
 #endif
   };
-  auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
-  auto &flags = context::GraphKernelFlags::GetInstance();
-  auto &enable_ops_only = flags.enable_expand_ops_only;
-  if (!enable_ops_only.empty()) {
-    expand_ops.clear();
-    std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::inserter(expand_ops, expand_ops.end()),
-                   new_prim);
-  } else {
-    auto &enable_ops = flags.enable_expand_ops;
-    auto &disable_ops = flags.disable_expand_ops;
-    if (!enable_ops.empty()) {
-      std::transform(enable_ops.begin(), enable_ops.end(), std::inserter(expand_ops, expand_ops.end()), new_prim);
-    }
-    if (!disable_ops.empty()) {
-      for (auto iter = expand_ops.begin(); iter != expand_ops.end();) {
-        if (std::find(disable_ops.begin(), disable_ops.end(), (*iter)->name()) != disable_ops.end()) {
-          expand_ops.erase(iter++);
-        } else {
-          ++iter;
-        }
-      }
-    }
-  }
+  const auto &flags = context::GraphKernelFlags::GetInstance();
+  OpListFilter(&expand_ops, flags.enable_expand_ops_only, flags.enable_expand_ops, flags.disable_expand_ops);
   return expand_ops;
 }
 }  // namespace
 
-bool GraphKernelExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
+bool DefaultExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
   DumpOption dump_option;
   dump_option.extract_opinfo_from_anfnode = true;
   kernel::AkgKernelJsonGenerator json_generator(dump_option);
   return json_generator.CollectJson(node, kernel_json);
 }
 
-FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
+FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
   nlohmann::json kernel_json;
   if (!ExpandJsonInfo(node, &kernel_json)) {
     MS_LOG(ERROR) << "Expand json info to: " << node->DebugString(2) << " failed, ori_json:\n" << kernel_json.dump();
@@ -121,7 +100,6 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
   }
   std::string kernel_desc_str = py::cast<std::string>(ret);
   if (kernel_desc_str.empty()) {
-    MS_LOG(INFO) << "Jump expand node: " << node->fullname_with_scope();
     return nullptr;
   }
   // decode json to func_graph.
@@ -129,10 +107,10 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
   return JsonDescToAnf(kernel_desc_str, ori_inputs);
 }
 
-void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
+void DefaultExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
   const auto &ori_parameter = func_graph->parameters();
   auto todos = TopoSort(func_graph->get_return());
-  std::unordered_set<AnfNodePtr> used_param;
+  std::set<AnfNodePtr> used_param;
   for (auto node : todos) {
     if (node->isa<Parameter>()) {
       used_param.insert(node);
@@ -152,21 +130,39 @@ void GraphKernelExpander::EliminateRedundantParameters(const FuncGraphPtr &func_
   *inputs = std::move(new_inputs);
 }
 
-AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func_graph,
-                                                        const FuncGraphPtr &new_func_graph, const CNodePtr &node) {
-  std::vector<AnfNodePtr> inputs(node->inputs().begin() + 1, node->inputs().end());
+AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
+  auto func_graph = old_node->func_graph();
+  std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
   AnfNodePtrList kernel_nodes;
   AnfNodePtrList outputs;
   EliminateRedundantParameters(new_func_graph, &inputs);
   kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
   kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
   auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
-  SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
-  MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope()
+  SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(old_node));
+  MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope()
                 << " with: " << graph_kernel_node->fullname_with_scope();
   return graph_kernel_node;
 }
 
+AnfNodePtr DefaultExpander::Run(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto new_func_graph = CreateExpandFuncGraph(cnode);
+  if (new_func_graph == nullptr) {
+    return nullptr;
+  }
+  new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(cnode)));
+  auto graph_kernel_node = CreateExpandGraphKernel(new_func_graph, cnode);
+  if (AnfAlgo::GetOutputTensorNum(node) != AnfAlgo::GetOutputTensorNum(graph_kernel_node)) {
+    MS_LOG(ERROR) << "The output num of composite node (" << AnfAlgo::GetOutputTensorNum(graph_kernel_node)
+                  << ") does not match the original basic node (" << AnfAlgo::GetOutputTensorNum(node) << ")."
+                  << node->fullname_with_scope();
+    return nullptr;
+  }
+  return graph_kernel_node;
+}
+
 bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
   bool changed = false;
   auto todos = TopoSort(func_graph->get_return());
@@ -181,35 +177,31 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
     }
 
     MS_LOG(INFO) << "Expanding node: " << node->fullname_with_scope();
-    auto new_func_graph = CreateExpandFuncGraph(node);
-    if (new_func_graph == nullptr) {
+    auto new_node = GetExpander(node)->Run(node);
+    if (new_node == nullptr) {
+      MS_LOG(INFO) << "Skipped node: " << node->fullname_with_scope();
       continue;
     }
-    mng->AddFuncGraph(new_func_graph);
-
-    auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node);
-    new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node)));
-
-    // replace origin node.
-    (void)mng->Replace(node, graph_kernel_node);
+    (void)mng->Replace(node, new_node);
     changed = true;
   }
   return changed;
 }
 
-bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
-  expand_ops_ = GetExpandOps();
-  MS_EXCEPTION_IF_NULL(func_graph);
-  if (expand_ops_.count(prim::kPrimGkDropout) > 0) {
-    std::shared_ptr<Pass> pass = std::make_shared<opt::SubstituteDropout>();
-    pass->Run(func_graph);
+ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
+  std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
+    {prim::kPrimDropout, std::make_shared<DropoutExpander>()},
+  };
+  for (auto &e : expanders) {
+    if (IsPrimitiveCNode(node, e.first)) {
+      return e.second;
+    }
   }
+  return std::make_shared<DefaultExpander>();
+}
 
-  auto mng = func_graph->manager();
-  if (mng == nullptr) {
-    mng = Manage(func_graph, true);
-    func_graph->set_manager(mng);
-  }
+bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
+  expand_ops_ = GetExpandOps();
   return DoExpand(func_graph);
 }
 }  // namespace opt