From: @HulkTang | Reviewed-by: | Signed-off-by: | tag: v1.1.0
| @@ -319,6 +319,7 @@ void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<Kern | |||
| MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = " | |||
| << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second; | |||
| } | |||
| // Reduce reference count number, when it was reduced to zero, release the useless output of pre node. | |||
| ref_iter->second -= 1; | |||
| if (ref_iter->second != 0) { | |||
| continue; | |||
| @@ -138,6 +138,11 @@ void RunOpsInGraphTask::Run() { | |||
| session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_); | |||
| } | |||
// Execute the cleanup task: forward the collected useless tensors to the
// session, which drops their device addresses (CleanUselessTensorsImpl).
void CleanUselessTensorsTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->CleanUselessTensorsImpl(useless_tensors_);
}
// Create communication group `group_name_` with member `ranks_` via the global
// CommManager; the boolean outcome is stored in result_ for the caller to read.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
// Destroy communication group `group_name_` via the global CommManager;
// the boolean outcome is stored in result_ for the caller to read.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
| @@ -371,6 +376,15 @@ void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, | |||
| *outputs = task->outputs_; | |||
| } | |||
| void Executor::CleanUselessTensors(const SessionPtr &session, | |||
| const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) { | |||
| MS_EXCEPTION_IF_NULL(useless_tensors); | |||
| auto task = std::make_shared<CleanUselessTensorsTask>(); | |||
| task->session_ = session; | |||
| task->useless_tensors_ = useless_tensors; | |||
| SyncRunTask(task); | |||
| } | |||
| bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) { | |||
| auto task = std::make_shared<CreateCommGroupTask>(); | |||
| task->group_name_ = group_name; | |||
| @@ -47,7 +47,8 @@ enum TaskType { | |||
| kRunOp, | |||
| kCreateCommGroup, | |||
| kDestroyCommGroup, | |||
| kRunOpsInGraph | |||
| kRunOpsInGraph, | |||
| kCleanUselessTensors | |||
| }; | |||
| class Task { | |||
| @@ -110,6 +111,14 @@ class RunOpsInGraphTask : public Task { | |||
| GraphId graph_id_{0}; | |||
| }; | |||
// Task that releases the device memory of tensors no longer needed.
// Scheduled by Executor::CleanUselessTensors and run synchronously via
// SyncRunTask; Run() forwards useless_tensors_ to the owning session.
class CleanUselessTensorsTask : public Task {
 public:
  CleanUselessTensorsTask() { type_ = kCleanUselessTensors; }
  ~CleanUselessTensorsTask() override = default;
  void Run() override;
  // Tensors whose device addresses should be dropped; populated by
  // Executor::CleanUselessTensors before the task is queued.
  std::shared_ptr<std::vector<tensor::TensorPtr>> useless_tensors_{nullptr};
};
| class RunOpTask : public Task { | |||
| public: | |||
| RunOpTask() { type_ = kRunOp; } | |||
| @@ -165,6 +174,8 @@ class Executor { | |||
| const std::vector<int64_t> &tensors_mask); | |||
| void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs); | |||
| void CleanUselessTensors(const SessionPtr &session, | |||
| const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors); | |||
| void OnRunGraphFinished(); | |||
| bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks); | |||
| bool DestroyCommGroup(const std::string &group_name); | |||
| @@ -1656,6 +1656,11 @@ void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tens | |||
| executor_->RunOpsInGraph(shared_from_this(), graph_id, inputs, outputs); | |||
| } | |||
// Public entry point: delegate cleanup of the given tensors to the executor,
// which wraps it in a synchronous task bound to this session.
void SessionBasic::CleanUselessTensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->CleanUselessTensors(shared_from_this(), useless_tensors);
}
| void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs); | |||
| @@ -1704,6 +1709,12 @@ void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &ro | |||
| root_graph->UpdateGraphDynamicAttr(); | |||
| } | |||
| void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) { | |||
| for (const auto &tensor : *useless_tensors) { | |||
| tensor->set_device_address(nullptr); | |||
| } | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) { | |||
| auto input_nodes = kernel_graph->inputs(); | |||
| @@ -82,6 +82,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| void RunOp(OpRunInfo *, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||
| const std::vector<int64_t> &tensors_mask); | |||
| void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | |||
| void CleanUselessTensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors); | |||
| virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); | |||
| @@ -141,6 +142,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| friend class RunGraphTask; | |||
| friend class RunOpTask; | |||
| friend class RunOpsInGraphTask; | |||
| friend class CleanUselessTensorsTask; | |||
| virtual bool IsSupportSummary() { return true; } | |||
| virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| VectorRef *outputs, | |||
| @@ -161,6 +163,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| const std::vector<int64_t> &tensors_mask) {} | |||
| virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) {} | |||
| void CleanUselessTensorsImpl(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors); | |||
| void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs); | |||
| virtual void SetSummaryNodes(KernelGraph *graph); | |||
| @@ -355,6 +355,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor | |||
| MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first; | |||
| auto input_value_node = NewValueNode(input_value.first); | |||
| input_value_node->set_has_new_value(true); | |||
| input_value_node->set_used_graph_count(para_ref_size); | |||
| manager->Replace(paras[i], input_value_node); | |||
| } | |||
| } | |||
| @@ -305,18 +305,6 @@ class PynativeEliminater : public OptimizerCaller { | |||
| return out; | |||
| } | |||
| void OnlySaveAbstractInfo(const ValueNodePtr &value_node) { | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto &value = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| if (value->isa<tensor::Tensor>()) { | |||
| auto tensor = value->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape()); | |||
| value_node->set_value(MakeValue(new_tensor)); | |||
| } | |||
| } | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4); | |||
| @@ -363,11 +351,31 @@ class PynativeEliminater : public OptimizerCaller { | |||
| // {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout} | |||
| PatternNode<AnfNodePtr> binop_grad_common; | |||
| PatternNode<AnfNodePtr> getitem_vnode; | |||
| PatternNode<AnfNodePtr> arg1; | |||
| PatternNode<AnfNodePtr> arg2; | |||
| PatternNode<AnfNodePtr> arg3; | |||
| PatternNode<AnfNodePtr> arg4; | |||
| std::vector<PatternNode<AnfNodePtr>> args(4); | |||
| auto resolve_binop = PPrimitive(prim::kPrimResolve, symbol_str_vnode, binop_grad_common); | |||
| auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]); | |||
| if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && | |||
| CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) { | |||
| for (size_t i = 0; i < 2; i++) { | |||
| auto rep = (args[i]).GetNode(node); | |||
| if (rep != nullptr && rep->isa<ValueNode>()) { | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto &value = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| // when the use count of value node equals to one, it only used in binop_grad_common function | |||
| if (value->isa<tensor::Tensor>() && value_node->used_graph_count() == 1) { | |||
| auto tensor = value->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape()); | |||
| value_node->set_value(new_tensor); | |||
| } | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| // resolve(CommonOPS, getitem)((tensors), 3) | |||
| PatternNode<AnfNodePtr> arg1; | |||
| auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode); | |||
| auto pattern2 = PCNode(resolve2, arg, arg1); | |||
| if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") && | |||
| @@ -1039,6 +1039,18 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { | |||
| } | |||
| } | |||
| void PynativeExecutor::CleanTensorsInValueNode() { | |||
| // Only need clean in ms backend policy and session should not be nullptr in ms backend. | |||
| if (session == nullptr) { | |||
| return; | |||
| } | |||
| auto useless_tensors = std::make_shared<std::vector<tensor::TensorPtr>>(); | |||
| for (const auto &id_tensor_pair : tensor_id_with_tensor_) { | |||
| std::copy(id_tensor_pair.second.begin(), id_tensor_pair.second.end(), std::back_inserter(*useless_tensors)); | |||
| } | |||
| session->CleanUselessTensors(useless_tensors); | |||
| } | |||
| AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { | |||
| auto &out = graph_info_map_[curr_g_].node_map[obj_id]; | |||
| if (out.second.size() == 1 && out.second[0] == -1) { | |||
| @@ -2039,6 +2051,7 @@ py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) | |||
| MS_LOG(DEBUG) << "Eval run " << backend; | |||
| grad_is_running = true; | |||
| BaseRef value = (*run)(arg_list); | |||
| CleanTensorsInValueNode(); | |||
| grad_is_running = false; | |||
| MS_LOG(DEBUG) << "Run end " << value.ToString(); | |||
| return BaseRefToPyData(value); | |||
| @@ -141,6 +141,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||
| // Update the abstract and device address info of value node and tensors in bprop graph | |||
| void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real); | |||
| void SaveTensorsInValueNode(const ResourcePtr &resource); | |||
| void CleanTensorsInValueNode(); | |||
| // construct grad graph | |||
| void PushCurrentGraphToStack(); | |||
| @@ -381,6 +381,9 @@ class ValueNode : public ANode { | |||
| void set_has_new_value(bool flag) { has_new_value_ = flag; } | |||
| bool has_new_value() const { return has_new_value_; } | |||
| size_t used_graph_count() const { return used_graph_count_; } | |||
| void set_used_graph_count(size_t used_graph_count) { used_graph_count_ = used_graph_count; } | |||
| std::string ToString() const override; | |||
| std::string DebugString(int recursive_level = 1) const override; | |||
| std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } | |||
| @@ -401,6 +404,7 @@ class ValueNode : public ANode { | |||
| private: | |||
| ValuePtr value_; | |||
| bool has_new_value_ = false; | |||
| size_t used_graph_count_{0}; | |||
| }; | |||
| template <typename T> | |||