|
|
|
@@ -451,6 +451,126 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K |
|
|
|
return new_cnode; |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
std::vector<AnfNodePtr> cnode_inputs; |
|
|
|
auto attr_input = cnode->input(kAnfPrimitiveIndex); |
|
|
|
MS_EXCEPTION_IF_NULL(attr_input); |
|
|
|
if (IsValueNode<FuncGraph>(attr_input)) { |
|
|
|
// create primitive of cnode:call |
|
|
|
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))}; |
|
|
|
// create a ValueNode<KernelGraph> as input of cnode:call |
|
|
|
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { |
|
|
|
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input)); |
|
|
|
} else { |
|
|
|
auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph); |
|
|
|
if (new_value_node != nullptr) { |
|
|
|
cnode_inputs.emplace_back(new_value_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} else if (attr_input->isa<CNode>()) { |
|
|
|
// create primitive of cnode:call(switch) |
|
|
|
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))}; |
|
|
|
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { |
|
|
|
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); |
|
|
|
auto prim = GetCNodePrimitive(cnode_input); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
if (prim->name() != kSwitchOpName) { |
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be switch."; |
|
|
|
} |
|
|
|
cnode_inputs.emplace_back(cnode_input); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString() |
|
|
|
<< ", but input[0] has not been created."; |
|
|
|
} |
|
|
|
} else { |
|
|
|
// get primitive of old node |
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
// push attr to inputs[0] of new cnode |
|
|
|
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))}; |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { |
|
|
|
auto anf = cnode->inputs()[input_idx]; |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
// anf has been created before |
|
|
|
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { |
|
|
|
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); |
|
|
|
continue; |
|
|
|
} else if (anf->isa<ValueNode>()) { |
|
|
|
if (!IsValueNode<FuncGraph>(anf)) { |
|
|
|
// if input is a common value node, |
|
|
|
auto new_value_node = CreateNewValueNode(anf, graph); |
|
|
|
if (new_value_node != nullptr) { |
|
|
|
cnode_inputs.emplace_back(new_value_node); |
|
|
|
} |
|
|
|
} else { |
|
|
|
// if input is a ValueNode<FuncGraph> |
|
|
|
auto new_value_node = CreateValueNodeKernelGraph(anf, graph); |
|
|
|
if (new_value_node != nullptr) { |
|
|
|
cnode_inputs.emplace_back(new_value_node); |
|
|
|
} |
|
|
|
} |
|
|
|
continue; |
|
|
|
} else if (anf->isa<Parameter>()) { |
|
|
|
auto new_parameter = CreateNewParameter(anf, graph); |
|
|
|
cnode_inputs.push_back(new_parameter); |
|
|
|
continue; |
|
|
|
} |
|
|
|
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; |
|
|
|
} |
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); |
|
|
|
auto new_cnode = graph->NewCNode(cnode_inputs); |
|
|
|
TraceManager::EndTrace(); |
|
|
|
return new_cnode; |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
auto value_node = anf->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf); |
|
|
|
MS_EXCEPTION_IF_NULL(sub_func_graph); |
|
|
|
if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph."; |
|
|
|
} |
|
|
|
auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph]; |
|
|
|
|
|
|
|
ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph); |
|
|
|
new_value_node->set_abstract(value_node->abstract()); |
|
|
|
// create new kernel_info of new value_node |
|
|
|
auto kernel_info = std::make_shared<device::KernelInfo>(); |
|
|
|
kernel_info->SetFeatureMapFlag(false); |
|
|
|
new_value_node->set_kernel_info(kernel_info); |
|
|
|
// create kernel_build_info for new value node |
|
|
|
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); |
|
|
|
AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get()); |
|
|
|
|
|
|
|
graph->FrontBackendlMapAdd(anf, new_value_node); |
|
|
|
graph->AddValueNodeToGraph(new_value_node); |
|
|
|
|
|
|
|
return new_value_node; |
|
|
|
} |
|
|
|
|
|
|
|
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
if (!anf->isa<Parameter>()) { |
|
|
|
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; |
|
|
|
} |
|
|
|
|
|
|
|
auto graph_inputs = graph->MutableInputs(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs); |
|
|
|
|
|
|
|
auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
|
graph_inputs->push_back(new_parameter); |
|
|
|
graph->FrontBackendlMapAdd(anf, new_parameter); |
|
|
|
|
|
|
|
return new_parameter; |
|
|
|
} |
|
|
|
|
|
|
|
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { |
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode; |
|
|
|
auto graph = NewKernelGraph(); |
|
|
|
@@ -494,7 +614,69 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con |
|
|
|
return graph; |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &) { return nullptr; } |
|
|
|
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) { |
|
|
|
MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph."; |
|
|
|
return front_backend_graph_map_[func_graph]; |
|
|
|
} |
|
|
|
auto node_list = TopoSort(func_graph->get_return()); |
|
|
|
auto graph = NewKernelGraph(); |
|
|
|
front_backend_graph_map_[func_graph] = graph; |
|
|
|
MS_LOG(INFO) << "Create graph: " << graph->graph_id(); |
|
|
|
|
|
|
|
for (const auto &node : node_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
MS_LOG(DEBUG) << "Node " << node->DebugString() << " is not CNode"; |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
|
|
|
|
// recurse control ops: call, partial |
|
|
|
auto attr_input = cnode->input(kAnfPrimitiveIndex); |
|
|
|
MS_EXCEPTION_IF_NULL(attr_input); |
|
|
|
if (IsValueNode<FuncGraph>(attr_input)) { |
|
|
|
// recurse call subgraph |
|
|
|
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(attr_input); |
|
|
|
ConstructKernelGraph(sub_func_graph); |
|
|
|
} else if (IsValueNode<Primitive>(attr_input)) { |
|
|
|
auto prim = GetCNodePrimitive(node); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
if (prim->name() == kPartialOpName) { |
|
|
|
// recurse partial subgraph |
|
|
|
auto func_graph_node = cnode->input(kAnfPartialFuncGraphIndex); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph_node); |
|
|
|
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node); |
|
|
|
ConstructKernelGraph(sub_func_graph); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// create a new cnode object |
|
|
|
auto new_cnode = CreateNewCNode(cnode, graph.get()); |
|
|
|
MS_EXCEPTION_IF_NULL(new_cnode); |
|
|
|
new_cnode->set_abstract(cnode->abstract()); |
|
|
|
new_cnode->set_scope(cnode->scope()); |
|
|
|
graph->FrontBackendlMapAdd(node, new_cnode); |
|
|
|
|
|
|
|
// set original return to kernel_graph |
|
|
|
if (IsPrimitive(new_cnode->input(kAnfPrimitiveIndex), prim::kPrimReturn)) { |
|
|
|
graph->set_return(new_cnode); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_); |
|
|
|
FuncGraphManagerPtr manager = context_->manager(); |
|
|
|
if (manager) { |
|
|
|
manager->AddFuncGraph(graph); |
|
|
|
graph->set_manager(manager); |
|
|
|
} |
|
|
|
graph->SetExecOrderByDefault(); |
|
|
|
return graph; |
|
|
|
} |
|
|
|
|
|
|
|
// run graph steps |
|
|
|
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, |
|
|
|
|