|
|
|
@@ -20,25 +20,22 @@ |
|
|
|
|
|
|
|
#include "c_ops/primitive_c.h" |
|
|
|
#include "ir/manager.h" |
|
|
|
#include "ir/param_info.h" |
|
|
|
#include "backend/kernel_compiler/common_utils.h" |
|
|
|
#include "base/core_ops.h" |
|
|
|
#include "common/trans.h" |
|
|
|
#include "utils/config_manager.h" |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
#include "backend/session/executor.h" |
|
|
|
#include "backend/session/executor_manager.h" |
|
|
|
#include "backend/optimizer/common/common_backend_optimization.h" |
|
|
|
#include "backend/optimizer/common/helper.h" |
|
|
|
#include "runtime/device/kernel_runtime_manager.h" |
|
|
|
#include "utils/ms_utils.h" |
|
|
|
#include "ir/dtype.h" |
|
|
|
#include "ir/anf.h" |
|
|
|
#include "ir/func_graph_cloner.h" |
|
|
|
#include "utils/utils.h" |
|
|
|
#include "debug/anf_ir_dump.h" |
|
|
|
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) |
|
|
|
#include "ps/worker.h" |
|
|
|
#include "ps/common.h" |
|
|
|
#include "ps/util.h" |
|
|
|
#endif |
|
|
|
|
|
|
|
@@ -665,8 +662,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & |
|
|
|
MS_EXCEPTION_IF_NULL(attr_input); |
|
|
|
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); |
|
|
|
if (cnode_input == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString() |
|
|
|
<< ", but input[0] has not been created."; |
|
|
|
MS_LOG(ERROR) << "CNode input[0] is CNode:" << attr_input->DebugString() << ", but input[0] has not been created."; |
|
|
|
return {}; |
|
|
|
} |
|
|
|
// if the node is partial, insert the inputs of partial to the call |
|
|
|
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) { |
|
|
|
@@ -682,7 +679,9 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { |
|
|
|
return CreateCallSwitchInputs(cnode, graph); |
|
|
|
} |
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; |
|
|
|
MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString() |
|
|
|
<< "must be partial or switch."; |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) { |
|
|
|
@@ -752,6 +751,10 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { |
|
|
|
// 1. take the args of call to the partial node, as the real_args to call switch's or switch_layer's child graph |
|
|
|
// 2. the call in frontend is map to the partial/switch/switch_layer in backend and haven't been created |
|
|
|
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); |
|
|
|
if (cnode_inputs.empty()) { |
|
|
|
MS_LOG_ERROR << "Create switch or partial failed, cnode:" << cnode->DebugString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} else { |
|
|
|
// get primitive of old node |
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode); |
|
|
|
@@ -877,14 +880,16 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con |
|
|
|
return graph; |
|
|
|
} |
|
|
|
|
|
|
|
void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) { |
|
|
|
bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
// create a new cnode object |
|
|
|
auto new_cnode = CreateNewCNode(cnode, graph.get()); |
|
|
|
MS_EXCEPTION_IF_NULL(new_cnode); |
|
|
|
auto new_cnode = CreateNewCNode(cnode, graph); |
|
|
|
if (new_cnode == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
new_cnode->set_abstract(cnode->abstract()); |
|
|
|
std::string fullname; |
|
|
|
if (cnode->input(kAnfPrimitiveIndex)->isa<CNode>()) { |
|
|
|
@@ -898,6 +903,7 @@ void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr |
|
|
|
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) { |
|
|
|
graph->set_return(new_cnode); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph, |
|
|
|
@@ -909,11 +915,10 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
front_backend_graph_map_[func_graph] = graph; |
|
|
|
MS_LOG(INFO) << "Create graph: " << graph->graph_id(); |
|
|
|
|
|
|
|
bool is_trace_back = false; |
|
|
|
for (const auto &node : node_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); |
|
|
|
// Create parameter |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
auto graph_inputs = graph->MutableInputs(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs); |
|
|
|
@@ -921,25 +926,28 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP |
|
|
|
graph_inputs->push_back(new_parameter); |
|
|
|
graph->FrontBackendlMapAdd(node, new_parameter); |
|
|
|
continue; |
|
|
|
} else if (node->isa<ValueNode>()) { |
|
|
|
} |
|
|
|
// Create value node |
|
|
|
if (node->isa<ValueNode>()) { |
|
|
|
// Create common value node |
|
|
|
if (!IsValueNode<FuncGraph>(node)) { |
|
|
|
// if input is a common value node, |
|
|
|
(void)CreateNewValueNode(node, graph.get()); |
|
|
|
} else { |
|
|
|
// if input is a ValueNode<FuncGraph> |
|
|
|
FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node); |
|
|
|
if (front_backend_graph_map_.find(child_graph) == front_backend_graph_map_.end()) { |
|
|
|
(void)ConstructKernelGraph(child_graph, all_out_graph); |
|
|
|
} |
|
|
|
(void)CreateValueNodeKernelGraph(node, graph.get()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Create child kernel graph according ValueNode<FuncGraph> |
|
|
|
FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node); |
|
|
|
if (front_backend_graph_map_.find(child_graph) == front_backend_graph_map_.end()) { |
|
|
|
(void)ConstructKernelGraph(child_graph, all_out_graph); |
|
|
|
} |
|
|
|
(void)CreateValueNodeKernelGraph(node, graph.get()); |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
CreateCNodeKernelGraph(node, graph); |
|
|
|
} |
|
|
|
// Create cnode |
|
|
|
if (!CreateCNodeOfKernelGraph(node, graph.get())) { |
|
|
|
DumpIR("contruct_kernel_graph_fail.ir", func_graph); |
|
|
|
MS_LOG_EXCEPTION << "construct func graph " << func_graph->ToString() << "fail!"; |
|
|
|
} |
|
|
|
} |
|
|
|
// if a graph jump back unconditionally, return op of this graph will never be executed, so output is null. |
|
|
|
graph->set_output_null(is_trace_back); |
|
|
|
AddParameterToGraphInputs(func_graph->parameters(), graph.get()); |
|
|
|
graph->SetExecOrderByDefault(); |
|
|
|
if (ExistSummaryNode(graph.get())) { |
|
|
|
|