Merge pull request !28313 from zjun/fix_forward_release (tags/v1.6.0)
| @@ -1364,37 +1364,41 @@ void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithInde | |||
| } | |||
| void SessionBasic::GetForwardOpOutputRefCount(const KernelGraph *graph, | |||
| std::multiset<std::string> *forward_op_output_tensor_id) { | |||
| std::map<std::string, size_t> *forward_op_output_tensor_id) { | |||
| if (!pynative::PynativeExecutor::GetInstance()->grad_executor()->grad_is_running()) { | |||
| return; | |||
| } | |||
| const auto &graph_value_nodes = graph->graph_value_nodes(); | |||
| std::vector<tensor::TensorPtr> tensor_value_list; | |||
| for (const auto &v : graph_value_nodes) { | |||
| const auto &value = GetValueNode(v); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| TensorValueToTensor(value, &tensor_value_list); | |||
| } | |||
| const auto &forward_op_output_id = pynative::PynativeExecutor::GetInstance()->grad_executor()->forward_op_output_id(); | |||
| MS_LOG(DEBUG) << "Total forward op out put size " << forward_op_output_id.size(); | |||
| for (const auto &t : tensor_value_list) { | |||
| if (forward_op_output_id.find(t->id()) != forward_op_output_id.end()) { | |||
| (*forward_op_output_tensor_id).emplace(t->id()); | |||
| for (const auto &kernel : graph->execution_order()) { | |||
| const auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel); | |||
| for (size_t i = 1; i <= input_tensor_num; ++i) { | |||
| const auto &input = kernel->input(i); | |||
| auto kernel_with_index = AnfAlgo::VisitKernel(input, 0); | |||
| auto real_input = kernel_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(real_input); | |||
| if (real_input->isa<ValueNode>()) { | |||
| const auto &tensor = GetValueNodeOutputTensor(real_input, kernel_with_index.second); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| if (forward_op_output_id.find(tensor->id()) != forward_op_output_id.end()) { | |||
| (*forward_op_output_tensor_id)[tensor->id()] += 1; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Total value nodes in graph size " << graph_value_nodes.size() << ", total tensor value node size " | |||
| << tensor_value_list.size() << ", forward op output tensor size " | |||
| << forward_op_output_tensor_id->size(); | |||
| MS_LOG(DEBUG) << "Forward op output tensor in bprop graph size " << forward_op_output_tensor_id->size(); | |||
| } | |||
| void SessionBasic::ReleaseForwardOpOutput(const std::vector<tensor::TensorPtr> &input_tensors, | |||
| std::multiset<std::string> *forward_op_output_tensor_id) { | |||
| std::map<std::string, size_t> *forward_op_output_tensor_id) { | |||
| MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id); | |||
| for (const auto &tensor : input_tensors) { | |||
| auto it = forward_op_output_tensor_id->find(tensor->id()); | |||
| if (it != forward_op_output_tensor_id->end()) { | |||
| tensor->set_device_address(nullptr); | |||
| forward_op_output_tensor_id->erase(it); | |||
| if (--(it->second) == 0) { | |||
| tensor->set_device_address(nullptr); | |||
| forward_op_output_tensor_id->erase(it); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -2425,7 +2429,7 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector< | |||
| graph_output_info.graph_outputs = outputs; | |||
| CreateOutputPlaceholder(kernel_graph, inputs, graph_output_info.graph_outputs, &graph_output_info.output_indexes); | |||
| std::map<KernelWithIndex, size_t> cnode_refcount; | |||
| std::multiset<std::string> forward_op_output_tensor_id; | |||
| std::map<std::string, size_t> forward_op_output_tensor_id; | |||
| GetRefCount(kernel_graph.get(), &cnode_refcount); | |||
| GetForwardOpOutputRefCount(kernel_graph.get(), &forward_op_output_tensor_id); | |||
| BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount); | |||
| @@ -196,9 +196,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| VectorRef *const outputs, | |||
| std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes); | |||
| void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count); | |||
| void GetForwardOpOutputRefCount(const KernelGraph *graph, std::multiset<std::string> *forward_op_output_tensor_id); | |||
| void GetForwardOpOutputRefCount(const KernelGraph *graph, std::map<std::string, size_t> *forward_op_output_tensor_id); | |||
| void ReleaseForwardOpOutput(const std::vector<tensor::TensorPtr> &input_tensors, | |||
| std::multiset<std::string> *forward_op_output_tensor_id); | |||
| std::map<std::string, size_t> *forward_op_output_tensor_id); | |||
| void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count, | |||
| std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map); | |||
| @@ -638,7 +638,7 @@ void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<Kern | |||
| } | |||
| void GraphCompiler::CalculateForwardOpOutputCount(const KernelGraphPtr &graph, | |||
| std::multiset<std::string> *forward_op_output_tensor_id) const { | |||
| std::map<std::string, size_t> *forward_op_output_tensor_id) const { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| forward_op_output_tensor_id->clear(); | |||
| session_->GetForwardOpOutputRefCount(graph.get(), forward_op_output_tensor_id); | |||
| @@ -652,7 +652,7 @@ void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernel | |||
| } | |||
| void GraphCompiler::UpdateForwardOpOutputRefCount(const std::vector<tensor::TensorPtr> &input_tensor, | |||
| std::multiset<std::string> *forward_op_output_tensor_id) const { | |||
| std::map<std::string, size_t> *forward_op_output_tensor_id) const { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id); | |||
| session_->ReleaseForwardOpOutput(input_tensor, forward_op_output_tensor_id); | |||
| @@ -146,7 +146,7 @@ class GraphCompiler { | |||
| // Calculate forward op output ref count of PyNative back graph. | |||
| void CalculateForwardOpOutputCount(const KernelGraphPtr &graph, | |||
| std::multiset<std::string> *forward_op_output_tensor_id) const; | |||
| std::map<std::string, size_t> *forward_op_output_tensor_id) const; | |||
| // Update ref count of PyNative back propagation operators. | |||
| void UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index, | |||
| @@ -155,7 +155,7 @@ class GraphCompiler { | |||
| // Update forward op output ref count of PyNative back graph. | |||
| void UpdateForwardOpOutputRefCount(const std::vector<tensor::TensorPtr> &input_tensor, | |||
| std::multiset<std::string> *forward_op_output_tensor_id) const; | |||
| std::map<std::string, size_t> *forward_op_output_tensor_id) const; | |||
| // Handle single op output tensor and recover output of original complete kernel graph. | |||
| void RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs, | |||
| @@ -195,7 +195,7 @@ class MindRTBackend : public Backend { | |||
| std::map<GraphId, std::map<KernelWithIndex, size_t>> cnode_ref_counts_; | |||
| // Cache forward op output value node tensor ref count of kernels for back propagation graph in PyNative mode. | |||
| std::multiset<std::string> forward_op_output_tensor_id_; | |||
| std::map<std::string, size_t> forward_op_output_tensor_id_; | |||
| FuncGraph *root_graph_; | |||
| GraphPartitionPtr graph_partition_; | |||