|
|
@@ -494,54 +494,52 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, Kern |
|
|
return make_tuple; |
|
|
return make_tuple; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph, |
|
|
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) { |
|
|
|
|
|
|
|
|
void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) { |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(other_graph_cnode); |
|
|
|
|
|
// get primitive of old node |
|
|
|
|
|
std::vector<AnfNodePtr> cnode_inputs; |
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode_inputs); |
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode); |
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode); |
|
|
if (prim != nullptr) { |
|
|
if (prim != nullptr) { |
|
|
// push attr to inputs[0] of new cnode |
|
|
// push attr to inputs[0] of new cnode |
|
|
cnode_inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))); |
|
|
|
|
|
|
|
|
cnode_inputs->push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))); |
|
|
} else { |
|
|
} else { |
|
|
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); |
|
|
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); |
|
|
MS_EXCEPTION_IF_NULL(fg); |
|
|
MS_EXCEPTION_IF_NULL(fg); |
|
|
auto new_fg = BasicClone(fg); |
|
|
auto new_fg = BasicClone(fg); |
|
|
cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg)); |
|
|
|
|
|
|
|
|
cnode_inputs->push_back(std::make_shared<ValueNode>(new_fg)); |
|
|
} |
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs, |
|
|
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(other_graph_cnode); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode_inputs); |
|
|
auto origin_inputs = cnode->inputs(); |
|
|
auto origin_inputs = cnode->inputs(); |
|
|
bool optimize_depend = false; |
|
|
|
|
|
bool optimize_control_depend = false; |
|
|
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && |
|
|
|
|
|
origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) { |
|
|
|
|
|
optimize_depend = true; |
|
|
|
|
|
} |
|
|
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3) { |
|
|
|
|
|
optimize_control_depend = true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && |
|
|
|
|
|
origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>(); |
|
|
|
|
|
bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3; |
|
|
// if has multiple depends,only select first depend as parameter |
|
|
// if has multiple depends,only select first depend as parameter |
|
|
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { |
|
|
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { |
|
|
auto anf = origin_inputs[input_idx]; |
|
|
auto anf = origin_inputs[input_idx]; |
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
// anf has been created before |
|
|
// anf has been created before |
|
|
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { |
|
|
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { |
|
|
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); |
|
|
|
|
|
|
|
|
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); |
|
|
continue; |
|
|
continue; |
|
|
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { |
|
|
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { |
|
|
cnode_inputs.push_back((*other_graph_cnode)[anf]); |
|
|
|
|
|
|
|
|
cnode_inputs->push_back((*other_graph_cnode)[anf]); |
|
|
continue; |
|
|
continue; |
|
|
} else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) { |
|
|
} else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) { |
|
|
// if input is a value node, |
|
|
// if input is a value node, |
|
|
auto new_value_node = CreateNewValueNode(anf, graph); |
|
|
auto new_value_node = CreateNewValueNode(anf, graph); |
|
|
if (new_value_node != nullptr) { |
|
|
if (new_value_node != nullptr) { |
|
|
cnode_inputs.emplace_back(new_value_node); |
|
|
|
|
|
|
|
|
cnode_inputs->emplace_back(new_value_node); |
|
|
} |
|
|
} |
|
|
continue; |
|
|
continue; |
|
|
} else if (anf->isa<Parameter>()) { |
|
|
} else if (anf->isa<Parameter>()) { |
|
|
auto new_parameter = CreateNewParameterFromParameter(anf, graph); |
|
|
auto new_parameter = CreateNewParameterFromParameter(anf, graph); |
|
|
cnode_inputs.push_back(new_parameter); |
|
|
|
|
|
|
|
|
cnode_inputs->push_back(new_parameter); |
|
|
if (GetGraphIdByNode(anf) == kInvalidGraphId) { |
|
|
if (GetGraphIdByNode(anf) == kInvalidGraphId) { |
|
|
graph->FrontBackendlMapAdd(anf, new_parameter); |
|
|
graph->FrontBackendlMapAdd(anf, new_parameter); |
|
|
} else { |
|
|
} else { |
|
|
@@ -549,20 +547,31 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph, |
|
|
} |
|
|
} |
|
|
continue; |
|
|
continue; |
|
|
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) { |
|
|
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) { |
|
|
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); |
|
|
|
|
|
|
|
|
cnode_inputs->push_back(origin_inputs[kRealInputIndexInDepend]); |
|
|
continue; |
|
|
continue; |
|
|
} else if (optimize_control_depend) { |
|
|
} else if (optimize_control_depend) { |
|
|
cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); |
|
|
|
|
|
|
|
|
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); |
|
|
} else { |
|
|
} else { |
|
|
// the input node is a cnode from other graph |
|
|
// the input node is a cnode from other graph |
|
|
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph); |
|
|
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph); |
|
|
if (parameter_from_cnode == nullptr) { |
|
|
if (parameter_from_cnode == nullptr) { |
|
|
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx))); |
|
|
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx))); |
|
|
} |
|
|
} |
|
|
cnode_inputs.push_back(parameter_from_cnode); |
|
|
|
|
|
|
|
|
cnode_inputs->push_back(parameter_from_cnode); |
|
|
(*other_graph_cnode)[anf] = parameter_from_cnode; |
|
|
(*other_graph_cnode)[anf] = parameter_from_cnode; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph, |
|
|
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(other_graph_cnode); |
|
|
|
|
|
// get primitive of old node |
|
|
|
|
|
std::vector<AnfNodePtr> cnode_inputs; |
|
|
|
|
|
GetCNodeInfo(cnode, &cnode_inputs); |
|
|
|
|
|
GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode); |
|
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); |
|
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); |
|
|
auto new_cnode = graph->NewCNode(cnode_inputs); |
|
|
auto new_cnode = graph->NewCNode(cnode_inputs); |
|
|
TraceManager::EndTrace(); |
|
|
TraceManager::EndTrace(); |
|
|
@@ -593,6 +602,42 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra |
|
|
return partial_node; |
|
|
return partial_node; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
|
|
std::vector<AnfNodePtr> cnode_inputs = { |
|
|
|
|
|
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; |
|
|
|
|
|
auto attr_input = cnode->input(kAnfPrimitiveIndex); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(attr_input); |
|
|
|
|
|
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); |
|
|
|
|
|
auto switch_cnode = cnode_input->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(switch_cnode); |
|
|
|
|
|
if (cnode->inputs().size() < 2) { |
|
|
|
|
|
cnode_inputs = switch_cnode->inputs(); |
|
|
|
|
|
return cnode_inputs; |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), |
|
|
|
|
|
switch_cnode->input(kFirstDataInputIndex)}; |
|
|
|
|
|
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { |
|
|
|
|
|
auto node = switch_cnode->input(index); |
|
|
|
|
|
// there is real input in call, should put it to true and false branch in switch |
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { |
|
|
|
|
|
auto partial_node = node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_node); |
|
|
|
|
|
std::vector<AnfNodePtr> partial_inputs = partial_node->inputs(); |
|
|
|
|
|
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); |
|
|
|
|
|
auto new_partial = graph->NewCNode(partial_inputs); |
|
|
|
|
|
switch_inputs.emplace_back(new_partial); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
if (switch_inputs.size() < kSwitchInputSize) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize; |
|
|
|
|
|
} |
|
|
|
|
|
auto switch_node = graph->NewCNode(switch_inputs); |
|
|
|
|
|
cnode_inputs.emplace_back(switch_node); |
|
|
|
|
|
return cnode_inputs; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { |
|
|
std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
@@ -618,32 +663,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & |
|
|
}); |
|
|
}); |
|
|
return cnode_inputs; |
|
|
return cnode_inputs; |
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { |
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { |
|
|
auto switch_cnode = cnode_input->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(switch_cnode); |
|
|
|
|
|
if (cnode->inputs().size() < 2) { |
|
|
|
|
|
cnode_inputs = switch_cnode->inputs(); |
|
|
|
|
|
return cnode_inputs; |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), |
|
|
|
|
|
switch_cnode->input(kFirstDataInputIndex)}; |
|
|
|
|
|
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { |
|
|
|
|
|
auto node = switch_cnode->input(index); |
|
|
|
|
|
// there is real input in call, should put it to true and false branch in switch |
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { |
|
|
|
|
|
auto partial_node = node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_node); |
|
|
|
|
|
std::vector<AnfNodePtr> partial_inputs = partial_node->inputs(); |
|
|
|
|
|
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); |
|
|
|
|
|
auto new_partial = graph->NewCNode(partial_inputs); |
|
|
|
|
|
switch_inputs.emplace_back(new_partial); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
if (switch_inputs.size() < kSwitchInputSize) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize; |
|
|
|
|
|
} |
|
|
|
|
|
auto switch_node = graph->NewCNode(switch_inputs); |
|
|
|
|
|
cnode_inputs.emplace_back(switch_node); |
|
|
|
|
|
return cnode_inputs; |
|
|
|
|
|
|
|
|
return CreateCallSwitchInputs(cnode, graph); |
|
|
} |
|
|
} |
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; |
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; |
|
|
} |
|
|
} |
|
|
|