diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index d36369be7f..412c3d8dc8 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -156,6 +156,7 @@ TensorPtr GetCNodeOutputStubTensor(const KernelWithIndex &kernel_with_index, } void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr &kernel, + const std::map &cnode_refcount, std::map *op_output_info) { MS_EXCEPTION_IF_NULL(single_op_graph); MS_EXCEPTION_IF_NULL(kernel); @@ -163,6 +164,10 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr OutputTensorInfo output_tensor_info; size_t out_idx = 0; for (const auto &output : single_op_graph->outputs()) { + KernelWithIndex kernel_with_index = std::make_pair(kernel, out_idx++); + if (cnode_refcount.find(kernel_with_index) == cnode_refcount.end()) { + continue; + } const auto &output_kernel_with_index = AnfAlgo::VisitKernel(output, 0); const auto &output_node = output_kernel_with_index.first; const auto &output_index = output_kernel_with_index.second; @@ -187,7 +192,6 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr device::DeviceAddressPtr device_address = std::make_shared(nullptr, 0, output_format, output_type); stub_output_tensor->set_device_address(device_address); - KernelWithIndex kernel_with_index = std::make_pair(kernel, out_idx++); output_tensor_info.output_stub_tensor = stub_output_tensor; output_tensor_info.is_weight = !dynamic_cast(output_node->kernel_info())->is_feature_map(); (*op_output_info)[kernel_with_index] = output_tensor_info; @@ -700,7 +704,8 @@ void AscendSession::GetOpInputStubTensors(const CNodePtr &cnode, const std::map< } void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map ¶meter_index, - const std::vector &graph_inputs) { + const std::vector &graph_inputs, + const std::map &cnode_refcount) { if (built_graph_id_.find(graph_id) != built_graph_id_.end()) { return; } @@ -722,13 +727,13 @@ void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::mapsecond; - GenOpOutputStubTensor(single_op_graph, kernel, &op_output_info); + GenOpOutputStubTensor(single_op_graph, kernel, cnode_refcount, &op_output_info); continue; } const auto &single_op_graph = PreBuildOp(op_run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask); MS_EXCEPTION_IF_NULL(single_op_graph); - GenOpOutputStubTensor(single_op_graph, kernel, &op_output_info); + GenOpOutputStubTensor(single_op_graph, kernel, cnode_refcount, &op_output_info); opt::HideNopNode(single_op_graph.get()); // The graph info could have been changed in PreBuildOp const GraphInfo &new_graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors); diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index 8fb895a500..8a36cd43bb 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -59,7 +59,8 @@ class AscendSession : public SessionBasic { void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector *input_tensors, VectorRef *outputs, const std::vector &tensors_mask) override; void BuildOpsInGraph(const GraphId &graph_id, const std::map ¶meter_index, - const std::vector &graph_inputs) override; + const std::vector &graph_inputs, + const std::map &cnode_refcount) override; private: // compile child graph when session have multiple child graphs diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index a5b7d35581..c457de1dcd 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -2156,9 +2156,9 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector< GetParameterIndex(kernel_graph.get(), inputs, ¶meter_index); std::map>> output_indexes; CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes); - std::map cnode_ref; - GetRefCount(kernel_graph.get(), &cnode_ref); - BuildOpsInGraph(graph_id, parameter_index, inputs); + std::map cnode_refcount; + GetRefCount(kernel_graph.get(), &cnode_refcount); + BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount); std::map op_output_map; for (const auto &kernel : kernel_graph->execution_order()) { @@ -2177,8 +2177,8 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector< input_tensor_info.input_tensors_mask); // Handle inputs and outputs of current op - HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map); - HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_ref, &op_output_map, outputs); + HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map); + HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs); } MS_LOG(INFO) << "Finish!"; } diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 06585e2851..4f16acd2ae 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -178,7 +178,8 @@ class SessionBasic : public std::enable_shared_from_this { const std::vector &tensors_mask) {} void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs); virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map ¶meter_index, - const std::vector &graph_inputs) {} + const std::vector &graph_inputs, + const std::map &cnode_refcount) {} void RunInfer(NotNull func_graph, const std::vector &inputs); virtual void SetSummaryNodes(KernelGraph *graph);