From: @HulkTang
@@ -319,6 +319,7 @@ void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<Kern
     MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
                       << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
   }
+  // Reduce the reference count; when it reaches zero, release the useless output of the previous node.
   ref_iter->second -= 1;
   if (ref_iter->second != 0) {
     continue;
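The added comment captures the mechanism this patch builds on: each consumer of an output decrements a shared reference count, and the output's memory is released only when the count reaches zero. A minimal standalone sketch of that release-on-zero pattern, using hypothetical types rather than the surrounding MindSpore code:

#include <map>
#include <memory>
#include <vector>

struct Output {
  std::shared_ptr<std::vector<float>> buffer;  // stand-in for a device/host memory block
};

// Decrement the reference count for `key`; when it reaches zero, release the output's buffer.
void ReleaseWhenUnused(std::map<int, size_t> *ref_count, std::map<int, Output> *outputs, int key) {
  auto iter = ref_count->find(key);
  if (iter == ref_count->end()) {
    return;  // unknown key: nothing tracked for it
  }
  iter->second -= 1;
  if (iter->second != 0) {
    return;  // another consumer still needs this output
  }
  (*outputs)[key].buffer.reset();  // last consumer is done: free the memory
}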
@@ -138,6 +138,11 @@ void RunOpsInGraphTask::Run() {
   session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
 }
+
+void CleanUselessTensorsTask::Run() {
+  MS_EXCEPTION_IF_NULL(session_);
+  session_->CleanUselessTensorsImpl(useless_tensors_);
+}
 
 void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
 
 void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
@@ -371,6 +376,15 @@ void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
   *outputs = task->outputs_;
 }
+
+void Executor::CleanUselessTensors(const SessionPtr &session,
+                                   const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
+  MS_EXCEPTION_IF_NULL(useless_tensors);
+  auto task = std::make_shared<CleanUselessTensorsTask>();
+  task->session_ = session;
+  task->useless_tensors_ = useless_tensors;
+  SyncRunTask(task);
+}
 
 bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
   auto task = std::make_shared<CreateCommGroupTask>();
   task->group_name_ = group_name;
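Executor::CleanUselessTensors follows the same dispatch pattern as the existing entry points: the arguments are packed into a task object that is handed to the synchronous task runner. The sketch below reduces that pattern to its core, with hypothetical Task/PrintTask/SyncRun names; the real SyncRunTask queues the task for a worker thread and blocks until it completes, which the stand-in simply runs inline.

#include <iostream>
#include <memory>
#include <string>

struct Task {
  virtual ~Task() = default;
  virtual void Run() = 0;
};

struct PrintTask : public Task {
  std::string message_;
  void Run() override { std::cout << message_ << std::endl; }
};

// Stand-in for SyncRunTask: run the task on the calling thread and return when it is done.
void SyncRun(const std::shared_ptr<Task> &task) { task->Run(); }

int main() {
  auto task = std::make_shared<PrintTask>();
  task->message_ = "clean useless tensors";
  SyncRun(task);  // caller returns only after the task has run
  return 0;
}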
@@ -47,7 +47,8 @@ enum TaskType {
   kRunOp,
   kCreateCommGroup,
   kDestroyCommGroup,
-  kRunOpsInGraph
+  kRunOpsInGraph,
+  kCleanUselessTensors
 };
 
 class Task {
@@ -110,6 +111,14 @@ class RunOpsInGraphTask : public Task {
   GraphId graph_id_{0};
 };
 
+class CleanUselessTensorsTask : public Task {
+ public:
+  CleanUselessTensorsTask() { type_ = kCleanUselessTensors; }
+  ~CleanUselessTensorsTask() override = default;
+  void Run() override;
+  std::shared_ptr<std::vector<tensor::TensorPtr>> useless_tensors_{nullptr};
+};
+
 class RunOpTask : public Task {
  public:
   RunOpTask() { type_ = kRunOp; }
@@ -165,6 +174,8 @@ class Executor {
                  const std::vector<int64_t> &tensors_mask);
   void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                      VectorRef *outputs);
+  void CleanUselessTensors(const SessionPtr &session,
+                           const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
   void OnRunGraphFinished();
   bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
   bool DestroyCommGroup(const std::string &group_name);
@@ -1656,6 +1656,11 @@ void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tens
   executor_->RunOpsInGraph(shared_from_this(), graph_id, inputs, outputs);
 }
+
+void SessionBasic::CleanUselessTensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
+  MS_EXCEPTION_IF_NULL(executor_);
+  executor_->CleanUselessTensors(shared_from_this(), useless_tensors);
+}
 
 void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(executor_);
   executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);
@@ -1704,6 +1709,12 @@ void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &ro
   root_graph->UpdateGraphDynamicAttr();
 }
+
+void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
+  for (const auto &tensor : *useless_tensors) {
+    tensor->set_device_address(nullptr);
+  }
+}
 
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) {
   auto input_nodes = kernel_graph->inputs();
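CleanUselessTensorsImpl does not free anything directly; it only detaches each tensor from its device address, and the device memory is reclaimed once the last owner of that address goes away. The toy program below illustrates that shared-ownership release with a hypothetical FakeTensor/DeviceBuffer pair (not the real MindSpore Tensor or DeviceAddress classes):

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

struct DeviceBuffer {
  explicit DeviceBuffer(size_t bytes) : bytes_(bytes) {}
  ~DeviceBuffer() { std::cout << "freed " << bytes_ << " device bytes" << std::endl; }
  size_t bytes_;
};

struct FakeTensor {
  std::shared_ptr<DeviceBuffer> device_address;
  void set_device_address(std::shared_ptr<DeviceBuffer> addr) { device_address = std::move(addr); }
};

int main() {
  std::vector<FakeTensor> useless_tensors(2);
  useless_tensors[0].set_device_address(std::make_shared<DeviceBuffer>(1024));
  useless_tensors[1].set_device_address(std::make_shared<DeviceBuffer>(2048));
  for (auto &tensor : useless_tensors) {
    tensor.set_device_address(nullptr);  // drop the only reference: the buffer is freed right here
  }
  return 0;
}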
@@ -82,6 +82,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void RunOp(OpRunInfo *, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
              const std::vector<int64_t> &tensors_mask);
   void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
+  void CleanUselessTensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
   virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
@@ -141,6 +142,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   friend class RunGraphTask;
   friend class RunOpTask;
   friend class RunOpsInGraphTask;
+  friend class CleanUselessTensorsTask;
   virtual bool IsSupportSummary() { return true; }
   virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
                                    VectorRef *outputs,
@@ -161,6 +163,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
                            const std::vector<int64_t> &tensors_mask) {}
   virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                  VectorRef *outputs) {}
+  void CleanUselessTensorsImpl(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
   void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
   virtual void SetSummaryNodes(KernelGraph *graph);
@@ -355,6 +355,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
       MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
       auto input_value_node = NewValueNode(input_value.first);
       input_value_node->set_has_new_value(true);
+      input_value_node->set_used_graph_count(para_ref_size);
       manager->Replace(paras[i], input_value_node);
     }
   }
@@ -305,18 +305,6 @@ class PynativeEliminater : public OptimizerCaller {
     return out;
   }
 
-  void OnlySaveAbstractInfo(const ValueNodePtr &value_node) {
-    MS_EXCEPTION_IF_NULL(value_node);
-    auto &value = value_node->value();
-    MS_EXCEPTION_IF_NULL(value);
-    if (value->isa<tensor::Tensor>()) {
-      auto tensor = value->cast<tensor::TensorPtr>();
-      MS_EXCEPTION_IF_NULL(tensor);
-      auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
-      value_node->set_value(MakeValue(new_tensor));
-    }
-  }
-
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
     MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4);
@@ -363,11 +351,31 @@ class PynativeEliminater : public OptimizerCaller {
     // {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout}
     PatternNode<AnfNodePtr> binop_grad_common;
     PatternNode<AnfNodePtr> getitem_vnode;
-    PatternNode<AnfNodePtr> arg1;
-    PatternNode<AnfNodePtr> arg2;
-    PatternNode<AnfNodePtr> arg3;
-    PatternNode<AnfNodePtr> arg4;
+    std::vector<PatternNode<AnfNodePtr>> args(4);
+    auto resolve_binop = PPrimitive(prim::kPrimResolve, symbol_str_vnode, binop_grad_common);
+    auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]);
+    if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
+                                             CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) {
+      for (size_t i = 0; i < 2; i++) {
+        auto rep = (args[i]).GetNode(node);
+        if (rep != nullptr && rep->isa<ValueNode>()) {
+          auto value_node = rep->cast<ValueNodePtr>();
+          MS_EXCEPTION_IF_NULL(value_node);
+          auto &value = value_node->value();
+          MS_EXCEPTION_IF_NULL(value);
+          // When the use count of the value node equals one, it is only used in the binop_grad_common function.
+          if (value->isa<tensor::Tensor>() && value_node->used_graph_count() == 1) {
+            auto tensor = value->cast<tensor::TensorPtr>();
+            MS_EXCEPTION_IF_NULL(tensor);
+            auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
+            value_node->set_value(new_tensor);
+          }
+        }
+      }
+      return nullptr;
+    }
     // resolve(CommonOPS, getitem)((tensors), 3)
+    PatternNode<AnfNodePtr> arg1;
     auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode);
     auto pattern2 = PCNode(resolve2, arg, arg1);
     if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") &&
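The new branch above slims single-use constants captured by the binop_grad_common pattern: when used_graph_count() is 1, the value node's tensor is replaced with a fresh tensor that keeps only the dtype and shape, so the original data buffer can be reclaimed while shape and type inference still have what they need. A standalone illustration of that "keep the metadata, drop the payload" idea, using a plain struct instead of the MindSpore Tensor API (the sketch drops the buffer in place rather than allocating a new empty tensor, which has the same net effect):

#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

struct MetaTensor {
  int type_id;                                 // dtype tag, kept
  std::vector<int64_t> shape;                  // shape, kept for inference
  std::shared_ptr<std::vector<uint8_t>> data;  // payload, the part worth reclaiming
  size_t use_count;                            // mirrors ValueNode::used_graph_count()
};

// Drop the payload of a constant that is referenced exactly once; shared constants keep their data.
void SlimIfSingleUse(MetaTensor *tensor) {
  if (tensor == nullptr || tensor->use_count != 1) {
    return;
  }
  tensor->data.reset();  // buffer is freed once its last owner lets go; dtype and shape remain
}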
@@ -1039,6 +1039,18 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
   }
 }
+
+void PynativeExecutor::CleanTensorsInValueNode() {
+  // Cleaning is only needed under the ms backend policy, and the session should not be nullptr in the ms backend.
+  if (session == nullptr) {
+    return;
+  }
+  auto useless_tensors = std::make_shared<std::vector<tensor::TensorPtr>>();
+  for (const auto &id_tensor_pair : tensor_id_with_tensor_) {
+    std::copy(id_tensor_pair.second.begin(), id_tensor_pair.second.end(), std::back_inserter(*useless_tensors));
+  }
+  session->CleanUselessTensors(useless_tensors);
+}
 
 AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
   auto &out = graph_info_map_[curr_g_].node_map[obj_id];
   if (out.second.size() == 1 && out.second[0] == -1) {
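CleanTensorsInValueNode only has to flatten the tensor_id_with_tensor_ map into one list before handing it to the session, so the device addresses of every recorded tensor can be dropped in a single synchronous call. The flattening step in isolation, with generic stand-in types (the real map is keyed by tensor id and holds tensor::TensorPtr values):

#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <vector>

using TensorPtr = std::shared_ptr<int>;  // stand-in for tensor::TensorPtr

// Flatten an id -> tensors map into a single list, mirroring how the bprop-recorded tensors
// are gathered before one batched clean-up call.
std::shared_ptr<std::vector<TensorPtr>> Flatten(const std::map<std::string, std::vector<TensorPtr>> &id_to_tensors) {
  auto flat = std::make_shared<std::vector<TensorPtr>>();
  for (const auto &pair : id_to_tensors) {
    std::copy(pair.second.begin(), pair.second.end(), std::back_inserter(*flat));
  }
  return flat;
}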
@@ -2039,6 +2051,7 @@ py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase)
   MS_LOG(DEBUG) << "Eval run " << backend;
   grad_is_running = true;
   BaseRef value = (*run)(arg_list);
+  CleanTensorsInValueNode();
   grad_is_running = false;
   MS_LOG(DEBUG) << "Run end " << value.ToString();
   return BaseRefToPyData(value);
@@ -141,6 +141,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   // Update the abstract and device address info of value node and tensors in bprop graph
   void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
   void SaveTensorsInValueNode(const ResourcePtr &resource);
+  void CleanTensorsInValueNode();
   // construct grad graph
   void PushCurrentGraphToStack();
@@ -381,6 +381,9 @@ class ValueNode : public ANode {
   void set_has_new_value(bool flag) { has_new_value_ = flag; }
   bool has_new_value() const { return has_new_value_; }
+
+  size_t used_graph_count() const { return used_graph_count_; }
+  void set_used_graph_count(size_t used_graph_count) { used_graph_count_ = used_graph_count; }
 
   std::string ToString() const override;
   std::string DebugString(int recursive_level = 1) const override;
   std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
@@ -401,6 +404,7 @@ class ValueNode : public ANode {
  private:
   ValuePtr value_;
   bool has_new_value_ = false;
+  size_t used_graph_count_{0};
 };
 
 template <typename T>