From cbb243a20795616b19f29c7f96c4e328578c49e4 Mon Sep 17 00:00:00 2001 From: chenfei Date: Thu, 18 Mar 2021 17:31:01 +0800 Subject: [PATCH] print graph output --- .../ccsrc/backend/session/session_basic.cc | 64 +++++++------------ .../ccsrc/backend/session/session_basic.h | 2 +- 2 files changed, 23 insertions(+), 43 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 7cb64d7570..f896501cf6 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -691,47 +691,39 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const } } -std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { +AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(graph); - std::vector parameters; - std::vector pre_graph_out = {node}; if (IgnoreCreateParameterForMakeTuple(node)) { - pre_graph_out.clear(); + return nullptr; } + auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); + auto parameters = AnfAlgo::GetAllOutput(new_parameter); + std::vector pre_graph_out = {node}; // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) { pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState}); } - auto valid_inputs = graph->MutableValidInputs(); - MS_EXCEPTION_IF_NULL(valid_inputs); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { - auto new_parameter = graph->NewParameter(abstract); - parameters.push_back(new_parameter); + for (const auto ¶meter : parameters) { + auto valid_inputs = graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); valid_inputs->push_back(true); - graph_inputs->push_back(new_parameter); - }; + graph_inputs->push_back(parameter); + } + size_t param_index = 0; for (const auto &out_node : pre_graph_out) { - MS_EXCEPTION_IF_NULL(out_node); - auto abstract = out_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // create multiple parameters if is a tuple output real kernel - if (abstract->isa() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]"; - for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) { - create_parameter((*tuple_abstract)[output_idx]); + size_t output_size = AnfAlgo::GetOutputTensorNum(out_node); + for (size_t i = 0; i < output_size; i++) { + if (param_index >= parameters.size()) { + MS_LOG(EXCEPTION) << "Parameters size:" << parameters.size() << "out of range.Node:" << node->DebugString() + << ",out_node:" << out_node->DebugString(); } - continue; + InitInternalOutputParameter(out_node, parameters[param_index++]); } - // create single parameter if is a abstract real kernel - create_parameter(out_node->abstract()); - InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]); } - return parameters; + return new_parameter; } ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) { @@ -770,20 +762,7 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, Kern MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; - auto parameters = CreateParameterFromTuple(anf, graph); - if (parameters.empty()) { - MS_LOG(INFO) << "Empty parameter from cnode"; - return nullptr; - } - if (parameters.size() == 1) { - return parameters[0]; - } - std::vector make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)}; - (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); - auto make_tuple = graph->NewCNode(make_tuple_input); - MS_EXCEPTION_IF_NULL(make_tuple); - MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters"; - return make_tuple; + return CreateParameterFromTuple(anf, graph); } void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector *cnode_inputs) { @@ -884,6 +863,7 @@ CNodePtr SessionBasic::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr KernelGraphPtr kernel_graph = NewKernelGraph(); MS_EXCEPTION_IF_NULL(kernel_graph); auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get()); + MS_EXCEPTION_IF_NULL(parameter); parameter->set_abstract(cnode->abstract()); auto primitive = NewValueNode(std::make_shared(prim::kPrimReturn->name())); auto return_node = kernel_graph->NewCNode({primitive, parameter}); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index eb4fd4921a..6e78b65b09 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -214,7 +214,7 @@ class SessionBasic : public std::enable_shared_from_this { const std::vector &graph_inputs, InputTensorInfo *input_tensor_info); // create a new kernel graph and update the graph sum KernelGraphPtr NewKernelGraph(); - std::vector CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph); + AnfNodePtr CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph); virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph); ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);