|
|
|
@@ -288,6 +288,22 @@ bool ExistSummaryNode(const KernelGraph *graph) { |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
const auto &node_inputs = cnode->inputs(); |
|
|
|
for (size_t i = 1; i < node_inputs.size(); ++i) { |
|
|
|
if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
GraphId SessionBasic::graph_sum_ = 0; |
|
|
|
@@ -354,8 +370,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
std::vector<AnfNodePtr> parameters; |
|
|
|
std::vector<AnfNodePtr> pre_graph_out = {node}; |
|
|
|
if (IgnoreCreateParameterForMakeTuple(node)) { |
|
|
|
pre_graph_out.clear(); |
|
|
|
} |
|
|
|
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive |
|
|
|
if (!AnfAlgo::IsRealKernel(node)) { |
|
|
|
if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) { |
|
|
|
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); |
|
|
|
} |
|
|
|
auto valid_inputs = graph->MutableValidInputs(); |
|
|
|
@@ -431,7 +450,8 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool |
|
|
|
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; |
|
|
|
auto parameters = CreateParameterFromTuple(anf, valid_input, graph); |
|
|
|
if (parameters.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "No parameter exist!!"; |
|
|
|
MS_LOG(INFO) << "Empty parameter from cnode"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (parameters.size() == 1) { |
|
|
|
return parameters[0]; |
|
|
|
@@ -505,11 +525,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K |
|
|
|
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); |
|
|
|
continue; |
|
|
|
} else if (optimize_control_depend) { |
|
|
|
cnode_inputs.push_back(NewValueNode(MakeValue(input_idx))); |
|
|
|
cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); |
|
|
|
} else { |
|
|
|
*from_other_graph = true; |
|
|
|
// the input node is a cnode from other graph |
|
|
|
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); |
|
|
|
if (parameter_from_cnode == nullptr) { |
|
|
|
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx))); |
|
|
|
} |
|
|
|
cnode_inputs.push_back(parameter_from_cnode); |
|
|
|
(*other_graph_cnode)[anf] = parameter_from_cnode; |
|
|
|
} |
|
|
|
@@ -878,7 +901,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
auto tensor = inputs[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
auto input_node = input_nodes[i]; |
|
|
|
if (TensorNeedSync(input_node, tensor) && input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { |
|
|
|
MS_EXCEPTION_IF_NULL(input_node); |
|
|
|
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); |
|
|
|
if (ms_context->execution_mode() == kPynativeMode || |
|
|
|
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) { |
|
|
|
|