|
|
|
@@ -92,6 +92,51 @@ GraphId GetDistinctionLabel(const KernelGraphPtr &graph) { |
|
|
|
// else use first node of execution order as label |
|
|
|
return AnfAlgo::GetStreamDistinctionLabel(graph->execution_order()[0].get()); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
std::vector<AnfNodePtr> graph_inputs = graph->inputs(); |
|
|
|
auto valid_inputs = graph->ValidInputs(); |
|
|
|
size_t real_args_size = 0; |
|
|
|
std::vector<BaseRef> real_args = {}; |
|
|
|
for (size_t i = 0; i < args.size(); i++) { |
|
|
|
if (utils::isa<AnfNodePtr>(args[i])) { |
|
|
|
auto tmp_args = AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem}); |
|
|
|
for (auto &real_arg : tmp_args) { |
|
|
|
auto anf_node = utils::cast<AnfNodePtr>(real_arg); |
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
auto abstract = anf_node->abstract(); |
|
|
|
MS_EXCEPTION_IF_NULL(abstract); |
|
|
|
// create multiple parameters if is a tuple output real kernel |
|
|
|
if (abstract->isa<abstract::AbstractTuple>() && |
|
|
|
!AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) { |
|
|
|
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); |
|
|
|
real_args_size += tuple_abstract->size(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
real_args_size += 1; |
|
|
|
real_args.push_back(real_arg); |
|
|
|
} |
|
|
|
} else { |
|
|
|
real_args_size += 1; |
|
|
|
real_args.push_back(args[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
if (graph_inputs.size() != valid_inputs.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size() |
|
|
|
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; |
|
|
|
} |
|
|
|
if (real_args_size != graph_inputs.size()) { |
|
|
|
for (size_t j = 0; j < valid_inputs.size(); j++) { |
|
|
|
if (valid_inputs[j]) { |
|
|
|
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() |
|
|
|
<< " not equal"; |
|
|
|
} |
|
|
|
return real_args; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { |
|
|
|
@@ -763,38 +808,26 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { |
|
|
|
UpdateGraphOrder(g); |
|
|
|
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs(); |
|
|
|
auto valid_inputs = to_graph->ValidInputs(); |
|
|
|
size_t real_args_size = 0; |
|
|
|
for (size_t i = 0; i < args.size(); i++) { |
|
|
|
real_args_size += AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem}).size(); |
|
|
|
} |
|
|
|
if (real_args_size != graph_inputs.size()) { |
|
|
|
for (size_t j = 0; j < valid_inputs.size(); j++) { |
|
|
|
if (valid_inputs[j]) { |
|
|
|
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() |
|
|
|
<< " not equal"; |
|
|
|
} |
|
|
|
auto real_args = GetRealArgs(to_graph, args); |
|
|
|
size_t input_index = 0; |
|
|
|
if (graph_inputs.size() != valid_inputs.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size() |
|
|
|
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < args.size(); i++) { |
|
|
|
for (size_t i = 0; i < real_args.size(); i++) { |
|
|
|
if (input_index >= graph_inputs.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); |
|
|
|
} |
|
|
|
if (utils::isa<AnfNodePtr>(args[i])) { |
|
|
|
if (utils::isa<AnfNodePtr>(real_args[i])) { |
|
|
|
// arg is a anf node |
|
|
|
for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem})) { |
|
|
|
if (!valid_inputs[input_index]) { |
|
|
|
MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto real_arg = utils::cast<AnfNodePtr>(real_args[i]); |
|
|
|
auto real_arg_output_num = AnfAlgo::GetOutputTensorNum(real_arg); |
|
|
|
if (!AnfAlgo::CheckPrimitiveType(real_arg, prim::kPrimTupleGetItem) && real_arg_output_num > 1) { |
|
|
|
input_index += real_arg_output_num; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (valid_inputs[input_index]) { |
|
|
|
SetChildGraphParameter(real_arg, graph_inputs[input_index]); |
|
|
|
input_index++; |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString(); |
|
|
|
} |
|
|
|
input_index++; |
|
|
|
} else if (utils::isa<ValuePtr>(args[i])) { |
|
|
|
auto value = utils::cast<ValuePtr>(args[i]); |
|
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
|
|