From ba311cf57f64e71b5f05fd774130f84c80eb7e0d Mon Sep 17 00:00:00 2001
From: yujianfeng
Date: Mon, 4 Jan 2021 16:02:17 +0800
Subject: [PATCH] Add cnode attr for recomputation

---
 mindspore/ccsrc/debug/anf_ir_dump.cc          |  31 ++++-
 .../ccsrc/frontend/optimizer/recompute.cc     | 130 ++++++++----------
 mindspore/core/ir/anf.h                       |  16 +++
 mindspore/core/ir/func_graph_cloner.cc        |   1 +
 mindspore/nn/cell.py                          |  11 +-
 mindspore/ops/primitive.py                    |  10 ++
 6 files changed, 120 insertions(+), 79 deletions(-)

diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc
index b990a6f561..6b55bb5492 100644
--- a/mindspore/ccsrc/debug/anf_ir_dump.cc
+++ b/mindspore/ccsrc/debug/anf_ir_dump.cc
@@ -313,7 +313,7 @@ void DumpOperateAttrs(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &gsub) {
     auto primitive = GetValueNode<PrimitivePtr>(op);
     auto attrs = primitive->attrs();
     if (!attrs.empty()) {
-      gsub->buffer << " {";
+      gsub->buffer << " primitive_attrs: {";
       int i = 0;
       for (const auto &attr : attrs) {
         if (attr.first == PARALLEL_STRATEGY) {
@@ -332,6 +332,32 @@ void DumpOperateAttrs(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &gsub) {
       gsub->buffer << "}";
     }
   }
+}
+
+void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &gsub) {
+  if (op == nullptr || gsub == nullptr) {
+    return;
+  }
+  if (op->attrs().empty()) {
+    gsub->buffer << std::endl;
+    return;
+  }
+
+  auto attrs = op->attrs();
+  gsub->buffer << " cnode_attrs: {";
+  int i = 0;
+  for (const auto &attr : attrs) {
+    if (i++ != 0) {
+      gsub->buffer << ", ";
+    }
+    gsub->buffer << attr.first << ": ";
+    if (attr.second == nullptr) {
+      gsub->buffer << "null";
+    } else {
+      gsub->buffer << attr.second->ToString();
+    }
+  }
+  gsub->buffer << "}";
   gsub->buffer << std::endl;
 }
 
@@ -384,6 +410,9 @@ void DumpCNode(const CNodePtr &nd, const FuncGraphPtr &sub_graph, OrderedMap<AnfNodePtr, int32_t> *para_map,
   // print operator attrs
   DumpOperateAttrs(op, gsub);
 
+  // print cnode attrs
+  DumpCNodeAttrs(nd, gsub);
+
   // print parallel info
   DumpParallelInfo(nd, gsub);
 
diff --git a/mindspore/ccsrc/frontend/optimizer/recompute.cc b/mindspore/ccsrc/frontend/optimizer/recompute.cc
--- a/mindspore/ccsrc/frontend/optimizer/recompute.cc
+++ b/mindspore/ccsrc/frontend/optimizer/recompute.cc
@@ -29,10 +29,10 @@
 namespace opt {
 namespace {
 constexpr auto kGradientsFlag = "Gradients";
-constexpr auto kAttrRecomputed = "recomputed";
-constexpr auto kAttrNoRecomputed = "no_recomputed";
+constexpr auto kAttrRecompute = "recompute";
+constexpr auto kAttrNoRecompute = "no_recompute";
 
-bool IsTargetNode(const AnfNodePtr &node) {
+bool IsBpropNode(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (!node->isa<CNode>()) {
     return false;
@@ -40,37 +40,26 @@ bool IsTargetNode(const AnfNodePtr &node) {
   return node->fullname_with_scope().find(kGradientsFlag) == 0;
 }
 
-bool HasNoRecomputedAttr(const AnfNodePtr &node) {
-  auto prim = GetCNodePrimitive(node);
-  if (prim != nullptr) {
-    auto no_recompute_val = prim->GetAttr(kAttrNoRecomputed);
-    if (no_recompute_val != nullptr && no_recompute_val->isa<BoolImm>()) {
-      return GetValue<bool>(no_recompute_val);
-    }
-  }
-  return false;
-}
-
 bool WithRecomputedScope(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (!node->isa<CNode>()) {
     return false;
   }
-  return node->fullname_with_scope().find(kAttrRecomputed) == 0;
+  auto full_name_with_scope = node->fullname_with_scope();
+  return full_name_with_scope.find(kAttrRecompute) == 0 &&
+         full_name_with_scope.find(kAttrNoRecompute) == full_name_with_scope.npos;
 }
 
-bool IsSetRecomputed(const AnfNodePtr &node) {
-  auto prim = GetCNodePrimitive(node);
-  if (prim != nullptr) {
-    auto recompute_val = prim->GetAttr(kAttrRecomputed);
-    if (recompute_val != nullptr && recompute_val->isa<BoolImm>()) {
-      return GetValue<bool>(recompute_val);
-    }
+bool HasRecomputeCNodeAttr(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode == nullptr) {
+    return false;
   }
-  return false;
+  auto cnode_recompute_val = cnode->GetAttr(kAttrRecompute);
+  return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val);
 }
 
-bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsTargetNode(node) && IsSetRecomputed(node); }
+bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && HasRecomputeCNodeAttr(node); }
 
 std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
                                                    const std::vector<CNodePtr> &cnodes) {
@@ -89,12 +78,12 @@ std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
     }
     const auto &node_index_set = output_set_iter->second;
     if (!std::any_of(node_index_set.begin(), node_index_set.end(),
-                     [](const std::pair<AnfNodePtr, int> &node_index) { return IsTargetNode(node_index.first); })) {
+                     [](const std::pair<AnfNodePtr, int> &node_index) { return IsBpropNode(node_index.first); })) {
       continue;
     }
     // Check inputs.
     const auto &inputs = cnode->inputs();
-    if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsTargetNode(node); })) {
+    if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsBpropNode(node); })) {
       continue;
     }
     candidate_recomputed_nodes.emplace_back(cnode);
@@ -166,7 +155,7 @@ void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng,
   for (const auto &node_index_set : output_set_iter->second) {
     auto output_node = node_index_set.first;
     MS_EXCEPTION_IF_NULL(output_node);
-    if (!IsTargetNode(output_node)) {
+    if (!IsBpropNode(output_node)) {
       continue;
     }
     target_nodes->insert(output_node->cast<CNodePtr>());
@@ -215,7 +204,7 @@ bool HasGradInputs(const AnfNodePtr &node, std::unordered_map<AnfNodePtr, bool> *has_grad_inputs_map) {
   }
   const auto &inputs = cnode->inputs();
   if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) {
-        return IsTargetNode(input) || HasGradInputs(input, has_grad_inputs_map);
+        return IsBpropNode(input) || HasGradInputs(input, has_grad_inputs_map);
       })) {
     has_grad_inputs_map->insert(std::make_pair(node, true));
     return true;
@@ -232,7 +221,7 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
     return false;
   }
   for (const auto &node_index_set : output_set_iter->second) {
-    if (!IsTargetNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) {
+    if (!IsBpropNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) {
       return true;
     }
   }
@@ -255,8 +244,8 @@ void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node,
   }
 }
 
-// Set 'recomputed' attr for the nodes according to its scope.
-// A node set 'recomputed' attr can be the candidate recomputed node.
+// Set the 'recompute' cnode attr for nodes according to their scope.
+// A node with the 'recompute' cnode attr set can become a candidate recomputed node.
 void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) {
   MS_EXCEPTION_IF_NULL(graph);
   auto mng = graph->manager();
@@ -264,44 +253,44 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) {
   std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map;
   for (const auto &node : origin_nodes_topological) {
     MS_EXCEPTION_IF_NULL(node);
-    if (!WithRecomputedScope(node) || HasNoRecomputedAttr(node)) {
+    if (IsBpropNode(node)) {
       continue;
     }
-    auto prim = GetCNodePrimitive(node);
-    if (prim == nullptr || prim->name() == prim::kPrimTupleGetItem->name() ||
-        prim->name() == prim::kPrimAllGather->name()) {
+    // Do not recompute the communication op.
+    if (IsPrimitiveCNode(node, prim::kPrimAllGather)) {
+      continue;
+    }
+    if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
       continue;
     }
     if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) {
       continue;
     }
-    // Make a new primitive to set attr because some nodes share the same primitive probably.
-    auto new_prim = std::make_shared<Primitive>(prim->name());
-    new_prim->SetAttrs(prim->attrs());
-    new_prim->set_prim_type(prim->prim_type());
-    new_prim->set_attr(kAttrRecomputed, MakeValue(true));
-    std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)};
-    const auto &origin_inputs = node->inputs();
-    std::copy(origin_inputs.begin() + 1, origin_inputs.end(), std::back_inserter(new_inputs));
-    auto new_node = graph->NewCNode(new_inputs);
-    new_node->set_abstract(node->abstract());
-    new_node->set_scope(node->scope());
-    mng->Replace(node, new_node);
-
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    auto prim = GetCNodePrimitive(cnode);
+    if (prim == nullptr) {
+      continue;
+    }
+    auto prim_recompute_attr = prim->GetAttr(kAttrRecompute);
+    int prim_recompute_val = -1;
+    if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>()) {
+      prim_recompute_val = GetValue<bool>(prim_recompute_attr);
+    }
+    if ((WithRecomputedScope(node) && prim_recompute_val != 0) || prim_recompute_val == 1) {
+      cnode->AddAttr(kAttrRecompute, MakeValue(true));
+    }
+    if (!HasRecomputeCNodeAttr(node)) {
+      continue;
+    }
     // Set attr for the tuple_getitem outputs.
     std::vector<AnfNodePtr> tuple_getitem_output_nodes;
-    GetTupleGetItemOutputNodes(mng, new_node, &tuple_getitem_output_nodes);
+    GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes);
     for (const auto &output_node : tuple_getitem_output_nodes) {
-      auto new_output_prim = std::make_shared<Primitive>(prim::kPrimTupleGetItem->name());
-      new_output_prim->set_attr(kAttrRecomputed, MakeValue(true));
-      std::vector<AnfNodePtr> new_tuple_getitem_inputs{NewValueNode(new_output_prim)};
-      auto origin_tuple_getitem_inputs = output_node->cast<CNodePtr>()->inputs();
-      std::copy(origin_tuple_getitem_inputs.begin() + 1, origin_tuple_getitem_inputs.end(),
-                std::back_inserter(new_tuple_getitem_inputs));
-      auto new_tuple_getitem = graph->NewCNode(new_tuple_getitem_inputs);
-      new_tuple_getitem->set_abstract(output_node->abstract());
-      mng->Replace(output_node, new_tuple_getitem);
+      auto output_cnode = output_node->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(output_cnode);
+      output_cnode->AddAttr(kAttrRecompute, MakeValue(true));
     }
   }
 }
 
@@ -318,15 +307,9 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node,
     return iter->second;
   }
   MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString();
-  auto prim = GetCNodePrimitive(origin_node);
-  MS_EXCEPTION_IF_NULL(prim);
-  auto new_prim = std::make_shared<Primitive>(prim->name());
-  new_prim->SetAttrs(prim->attrs());
-  new_prim->set_attr("duplicated", MakeValue(true));
-  new_prim->set_prim_type(prim->prim_type());
-  std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)};
+  std::vector<AnfNodePtr> new_inputs;
   bool has_recomputed_inputs = false;
-  for (size_t i = 1; i < origin_node->size(); ++i) {
+  for (size_t i = 0; i < origin_node->size(); ++i) {
     auto input = origin_node->input(i);
     MS_EXCEPTION_IF_NULL(input);
     if (!input->isa<CNode>()) {
@@ -356,6 +339,8 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node,
     new_inputs[1] = depend_node;
   }
   auto recomputed_node = graph->NewCNode(new_inputs);
+  MS_EXCEPTION_IF_NULL(recomputed_node);
+  recomputed_node->AddAttr("duplicated", MakeValue(true));
   recomputed_node->set_abstract(origin_node->abstract());
   recomputed_node->set_scope(origin_node->scope());
   origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node));
@@ -374,10 +359,6 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_set<CNodePtr> &target_nodes,
     MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input";
     auto target_cnode = target_node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(target_cnode);
-    auto prim = GetCNodePrimitive(target_cnode);
-    if (prim != nullptr) {
-      prim->set_attr("target_grad", MakeValue(true));
-    }
     std::vector<AnfNodePtr> new_target_inputs;
     for (const auto &input : target_cnode->inputs()) {
       MS_EXCEPTION_IF_NULL(input);
@@ -394,6 +375,7 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_set<CNodePtr> &target_nodes,
       }
     }
     auto new_target_node = graph->NewCNode(new_target_inputs);
+    new_target_node->AddAttr("target_grad", MakeValue(true));
     new_target_node->set_abstract(target_node->abstract());
     new_target_node->set_scope(target_node->scope());
     mng->Replace(target_node, new_target_node);
@@ -405,11 +387,9 @@ void InsertRecomputedNodes(const FuncGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto mng = graph->manager();
   MS_EXCEPTION_IF_NULL(mng);
-  std::list<CNodePtr> old_orders = graph->GetOrderedCnodes();
-  std::vector<CNodePtr> old_nodes_topological(old_orders.begin(), old_orders.end());
-  SetRecomputedAttr(graph, old_nodes_topological);
-  std::list<CNodePtr> new_orders = graph->GetOrderedCnodes();
-  std::vector<CNodePtr> origin_nodes_topological(new_orders.begin(), new_orders.end());
+  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
+  std::vector<CNodePtr> origin_nodes_topological(orders.begin(), orders.end());
+  SetRecomputedAttr(graph, origin_nodes_topological);
   // Get candidate origin recomputed nodes which have no grad inputs and output to at least one grad node directly.
   std::vector<CNodePtr> candidate_recomputed_nodes = FindCandidateRecomputedNodes(mng, origin_nodes_topological);
   std::unordered_set<CNodePtr> visited_nodes;
diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h
index 2000f97ba7..d7960151ed 100644
--- a/mindspore/core/ir/anf.h
+++ b/mindspore/core/ir/anf.h
@@ -267,6 +267,21 @@ class CNode : public AnfNode {
 
   VarPtr func_graph_as_var() const { return func_graph_as_var_; }
 
+  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
+  void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
+    for (auto &attr : attrs) {
+      attrs_[attr.first] = attr.second;
+    }
+  }
+
+  void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
+  void EraseAttr(const std::string &name) { (void)attrs_.erase(name); }
+  ValuePtr GetAttr(const std::string &name) const {
+    auto iter = attrs_.find(name);
+    return iter == attrs_.cend() ? nullptr : iter->second;
+  }
+  bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.cend(); }
+
  private:
   std::vector<AnfNodePtr> inputs_;
   VarPtr func_graph_as_var_;
@@ -276,6 +291,7 @@ class CNode : public AnfNode {
   // output_value_ store cnode value and id in pynative mode
   std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
   std::pair<ValuePtr, std::string> output_value_;
+  std::unordered_map<std::string, ValuePtr> attrs_;
 };
 
 // ANode represents the atomic node. It's derived Parameter and ValueNode.
diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc
index f981e9f2c3..52f883dcf2 100644
--- a/mindspore/core/ir/func_graph_cloner.cc
+++ b/mindspore/core/ir/func_graph_cloner.cc
@@ -90,6 +90,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
   new_node->set_abstract(old_node->abstract());
   new_node->set_forward(old_node->forward().first, old_node->forward().second);
   new_node->set_inputs_value(old_node->inputs_value());
+  new_node->set_attrs(old_node->attrs());
   ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
   new_node->set_scope(scope);
   if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) {
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index b15da3e581..cfe31eee9e 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -1095,14 +1095,19 @@ class Cell(Cell_):
                 param.comm_fusion = fusion_type
         return self
 
-    def recompute(self):
+    def recompute(self, mode=True):
         """
         Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
         feeds into a grad node and is set recomputed, we will compute it again for the grad node after the forward
         computation.
+        Args:
+            mode (bool): Specifies whether the cell is recomputed. Default: True.
         """
-        self._set_scope('recomputed')
+        if mode is True:
+            self._set_scope("recompute")
+        else:
+            self._set_scope("no_recompute")
         for cell in self.cells():
-            cell.recompute()
+            cell.recompute(mode)
 
 class GraphKernel(Cell):
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index 096b4baf30..79c3fc2e2c 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -206,6 +206,16 @@ class Primitive(Primitive_):
         """ Whether the primitive will update the value of parameter."""
         return self._update_parameter
 
+    def recompute(self, mode=True):
+        """
+        Set the primitive recomputed. If a primitive feeds into a grad node and is set recomputed,
+        we will compute it again for the grad node after the forward computation.
+
+        Args:
+            mode (bool): Specifies whether the primitive is recomputed. Default: True.
+        """
+        self.add_prim_attr("recompute", mode)
+        return self
 
 class PrimitiveWithCheck(Primitive):
     """
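
Usage sketch for reviewers (not part of the diff): the pass is driven from Python through the two
recompute() entry points added above. This is a minimal sketch assuming a build with this patch
applied; the Net cell and its layers are hypothetical and exist only to illustrate the API. Note that
because the flags now live on the CNode instead of on freshly cloned Primitives, operators that share
one Primitive object no longer need the make_shared copies the old pass created.

import mindspore.nn as nn
import mindspore.ops.operations as P


class Net(nn.Cell):
    """Hypothetical cell, used only to show the new recompute switches."""

    def __init__(self):
        super(Net, self).__init__()
        self.block = nn.Dense(16, 16)
        self.mul = P.Mul()
        # Cell level: rewrites the scope of every node in this cell to start
        # with "recompute" ("no_recompute" when mode is False); the pass in
        # recompute.cc later turns that scope into the 'recompute' cnode attr.
        self.block.recompute()
        # Primitive level: stores the bool as the prim attr "recompute", which
        # SetRecomputedAttr reads back as prim_recompute_val; it overrides the
        # scope-based decision for this single operator.
        self.mul.recompute(False)

    def construct(self, x):
        return self.mul(self.block(x), x)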
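
The effect is observable in the textual IR that anf_ir_dump.cc writes: each CNode line now carries a
"cnode_attrs: {...}" group after "primitive_attrs: {...}", so forward nodes chosen for recomputation
show recompute: true, the duplicated nodes show duplicated: true, and the rebuilt grad consumers show
target_grad: true. A sketch for producing the dumps, assuming the standard save_graphs context options
(the output path is arbitrary):

from mindspore import context

# Save the *.ir files under ./ir; the dumped CNode lines can then be checked
# for the new cnode_attrs groups introduced by DumpCNodeAttrs.
context.set_context(mode=context.GRAPH_MODE, save_graphs=True, save_graphs_path="./ir")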