|
|
|
@@ -395,8 +395,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, |
|
|
|
KernelGraph *graph) { |
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
std::vector<AnfNodePtr> parameters; |
|
|
|
@@ -418,7 +417,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr |
|
|
|
parameter->set_abstract(abstract); |
|
|
|
auto new_parameter = graph->NewParameter(parameter); |
|
|
|
parameters.push_back(new_parameter); |
|
|
|
valid_inputs->push_back(valid_input); |
|
|
|
valid_inputs->push_back(true); |
|
|
|
graph_inputs->push_back(new_parameter); |
|
|
|
}; |
|
|
|
for (const auto &out_node : pre_graph_out) { |
|
|
|
@@ -442,8 +441,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr |
|
|
|
return parameters; |
|
|
|
} |
|
|
|
|
|
|
|
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, |
|
|
|
KernelGraph *graph) { |
|
|
|
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
if (!anf->isa<Parameter>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; |
|
|
|
@@ -471,15 +469,15 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf |
|
|
|
TraceManager::EndTrace(); |
|
|
|
} |
|
|
|
graph_inputs->push_back(new_parameter); |
|
|
|
valid_inputs->push_back(valid_input); |
|
|
|
valid_inputs->push_back(true); |
|
|
|
return new_parameter; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { |
|
|
|
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; |
|
|
|
auto parameters = CreateParameterFromTuple(anf, valid_input, graph); |
|
|
|
auto parameters = CreateParameterFromTuple(anf, graph); |
|
|
|
if (parameters.empty()) { |
|
|
|
MS_LOG(INFO) << "Empty parameter from cnode"; |
|
|
|
return nullptr; |
|
|
|
@@ -495,14 +493,11 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool |
|
|
|
return make_tuple; |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, |
|
|
|
bool *from_other_graph, |
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph, |
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(from_other_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(other_graph_cnode); |
|
|
|
*from_other_graph = false; |
|
|
|
// get primitive of old node |
|
|
|
std::vector<AnfNodePtr> cnode_inputs; |
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode); |
|
|
|
@@ -544,7 +539,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K |
|
|
|
} |
|
|
|
continue; |
|
|
|
} else if (anf->isa<Parameter>()) { |
|
|
|
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); |
|
|
|
auto new_parameter = CreateNewParameterFromParameter(anf, graph); |
|
|
|
cnode_inputs.push_back(new_parameter); |
|
|
|
if (GetGraphIdByNode(anf) == kInvalidGraphId) { |
|
|
|
graph->FrontBackendlMapAdd(anf, new_parameter); |
|
|
|
@@ -558,9 +553,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K |
|
|
|
} else if (optimize_control_depend) { |
|
|
|
cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); |
|
|
|
} else { |
|
|
|
*from_other_graph = true; |
|
|
|
// the input node is a cnode from other graph |
|
|
|
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); |
|
|
|
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph); |
|
|
|
if (parameter_from_cnode == nullptr) { |
|
|
|
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx))); |
|
|
|
} |
|
|
|
@@ -587,7 +581,7 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra |
|
|
|
} else { |
|
|
|
KernelGraphPtr kernel_graph = NewKernelGraph(); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), true, kernel_graph.get()); |
|
|
|
auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), kernel_graph.get()); |
|
|
|
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())); |
|
|
|
auto return_node = kernel_graph->NewCNode({primitive, parameter}); |
|
|
|
kernel_graph->set_return(return_node); |
|
|
|
@@ -806,7 +800,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con |
|
|
|
auto graph = NewKernelGraph(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_LOG(INFO) << "Create graph: " << graph->graph_id(); |
|
|
|
size_t from_other_graph_depend_num = 0; |
|
|
|
for (const auto &node : lst) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); |
|
|
|
@@ -816,16 +809,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
// create a new cnode object |
|
|
|
bool from_other_graph = false; |
|
|
|
// only first depend from other graph can create |
|
|
|
bool valid_input = true; |
|
|
|
if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { |
|
|
|
valid_input = false; |
|
|
|
} |
|
|
|
auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) { |
|
|
|
from_other_graph_depend_num++; |
|
|
|
} |
|
|
|
auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(new_cnode); |
|
|
|
new_cnode->set_abstract(cnode->abstract()); |
|
|
|
new_cnode->set_scope(cnode->scope()); |
|
|
|
|