Browse Source

!10952 Add cnode attrs for recomputation

From: @ginfung
Reviewed-by: @zh_qh
Signed-off-by: @zh_qh
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
07ab1c1b4e
6 changed files with 120 additions and 79 deletions
  1. +30
    -1
      mindspore/ccsrc/debug/anf_ir_dump.cc
  2. +55
    -75
      mindspore/ccsrc/frontend/optimizer/recompute.cc
  3. +16
    -0
      mindspore/core/ir/anf.h
  4. +1
    -0
      mindspore/core/ir/func_graph_cloner.cc
  5. +8
    -3
      mindspore/nn/cell.py
  6. +10
    -0
      mindspore/ops/primitive.py

+ 30
- 1
mindspore/ccsrc/debug/anf_ir_dump.cc View File

@@ -313,7 +313,7 @@ void DumpOperateAttrs(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo
} }
auto attrs = primitive->attrs(); auto attrs = primitive->attrs();
if (!attrs.empty()) { if (!attrs.empty()) {
gsub->buffer << " {";
gsub->buffer << " primitive_attrs: {";
int i = 0; int i = 0;
for (const auto &attr : attrs) { for (const auto &attr : attrs) {
if (attr.first == PARALLEL_STRATEGY) { if (attr.first == PARALLEL_STRATEGY) {
@@ -332,6 +332,32 @@ void DumpOperateAttrs(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo
gsub->buffer << "}"; gsub->buffer << "}";
} }
} }
}

// Appends the cnode-level attrs of `op` to the dump buffer as
// " cnode_attrs: {name: value, ...}" and then terminates the current line.
// A null attr value is printed as "null"; other values use their ToString().
// When the node has no attrs only the line terminator is written.
// Nothing at all is written when `op` or `gsub` is null.
void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &gsub) {
  if (op == nullptr || gsub == nullptr) {
    return;
  }
  // attrs() hands back a const reference; bind it once instead of copying the whole map.
  const auto &attrs = op->attrs();
  if (attrs.empty()) {
    gsub->buffer << std::endl;
    return;
  }

  gsub->buffer << " cnode_attrs: {";
  bool is_first = true;
  for (const auto &attr : attrs) {
    // Separate entries with ", " but do not emit a leading separator.
    if (!is_first) {
      gsub->buffer << ", ";
    }
    is_first = false;
    gsub->buffer << attr.first << ": ";
    if (attr.second == nullptr) {
      gsub->buffer << "null";
    } else {
      gsub->buffer << attr.second->ToString();
    }
  }
  gsub->buffer << "}";
  gsub->buffer << std::endl;
}


@@ -384,6 +410,9 @@ void DumpCNode(const CNodePtr &nd, const FuncGraphPtr &sub_graph, OrderedMap<Anf
// print operator attrs // print operator attrs
DumpOperateAttrs(op, gsub); DumpOperateAttrs(op, gsub);


// print cnode attrs
DumpCNodeAttrs(nd, gsub);

// print parallel info // print parallel info
DumpParallelInfo(nd, gsub); DumpParallelInfo(nd, gsub);




+ 55
- 75
mindspore/ccsrc/frontend/optimizer/recompute.cc View File

@@ -30,9 +30,9 @@ namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
constexpr auto kGradientsFlag = "Gradients"; constexpr auto kGradientsFlag = "Gradients";
constexpr auto kAttrRecomputed = "recomputed";
constexpr auto kAttrNoRecomputed = "no_recomputed";
bool IsTargetNode(const AnfNodePtr &node) {
constexpr auto kAttrRecompute = "recompute";
constexpr auto kAttrNoRecompute = "no_recompute";
bool IsBpropNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
return false; return false;
@@ -40,37 +40,26 @@ bool IsTargetNode(const AnfNodePtr &node) {
return node->fullname_with_scope().find(kGradientsFlag) == 0; return node->fullname_with_scope().find(kGradientsFlag) == 0;
} }


bool HasNoRecomputedAttr(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
if (prim != nullptr) {
auto no_recompute_val = prim->GetAttr(kAttrNoRecomputed);
if (no_recompute_val != nullptr && no_recompute_val->isa<BoolImm>()) {
return GetValue<bool>(no_recompute_val);
}
}
return false;
}

bool WithRecomputedScope(const AnfNodePtr &node) { bool WithRecomputedScope(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
return false; return false;
} }
return node->fullname_with_scope().find(kAttrRecomputed) == 0;
auto full_name_with_scope = node->fullname_with_scope();
return full_name_with_scope.find(kAttrRecompute) == 0 &&
full_name_with_scope.find(kAttrNoRecompute) == full_name_with_scope.npos;
} }


bool IsSetRecomputed(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
if (prim != nullptr) {
auto recompute_val = prim->GetAttr(kAttrRecomputed);
if (recompute_val != nullptr && recompute_val->isa<BoolImm>()) {
return GetValue<bool>(recompute_val);
}
bool HasRecomputeCNodeAttr(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return false;
} }
return false;
auto cnode_recompute_val = cnode->GetAttr(kAttrRecompute);
return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val);
} }


bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsTargetNode(node) && IsSetRecomputed(node); }
// A node is a recompute candidate when it is not a bprop node (per IsBpropNode) and it
// carries a true 'recompute' cnode attr (per HasRecomputeCNodeAttr).
bool IsCandidateRecomputedNode(const CNodePtr &node) {
  if (IsBpropNode(node)) {
    return false;
  }
  return HasRecomputeCNodeAttr(node);
}


std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng, std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
const std::vector<CNodePtr> &cnodes) { const std::vector<CNodePtr> &cnodes) {
@@ -89,12 +78,12 @@ std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mn
} }
const auto &node_index_set = output_set_iter->second; const auto &node_index_set = output_set_iter->second;
if (!std::any_of(node_index_set.begin(), node_index_set.end(), if (!std::any_of(node_index_set.begin(), node_index_set.end(),
[](const std::pair<AnfNodePtr, int> &node_index) { return IsTargetNode(node_index.first); })) {
[](const std::pair<AnfNodePtr, int> &node_index) { return IsBpropNode(node_index.first); })) {
continue; continue;
} }
// Check inputs. // Check inputs.
const auto &inputs = cnode->inputs(); const auto &inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsTargetNode(node); })) {
if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsBpropNode(node); })) {
continue; continue;
} }
candidate_recomputed_nodes.emplace_back(cnode); candidate_recomputed_nodes.emplace_back(cnode);
@@ -166,7 +155,7 @@ void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng,
for (const auto &node_index_set : output_set_iter->second) { for (const auto &node_index_set : output_set_iter->second) {
auto output_node = node_index_set.first; auto output_node = node_index_set.first;
MS_EXCEPTION_IF_NULL(output_node); MS_EXCEPTION_IF_NULL(output_node);
if (!IsTargetNode(output_node)) {
if (!IsBpropNode(output_node)) {
continue; continue;
} }
target_nodes->insert(output_node->cast<CNodePtr>()); target_nodes->insert(output_node->cast<CNodePtr>());
@@ -215,7 +204,7 @@ bool HasGradInputs(const AnfNodePtr &node, std::unordered_map<AnfNodePtr, bool>
} }
const auto &inputs = cnode->inputs(); const auto &inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) { if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) {
return IsTargetNode(input) || HasGradInputs(input, has_grad_inputs_map);
return IsBpropNode(input) || HasGradInputs(input, has_grad_inputs_map);
})) { })) {
has_grad_inputs_map->insert(std::make_pair(node, true)); has_grad_inputs_map->insert(std::make_pair(node, true));
return true; return true;
@@ -232,7 +221,7 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
return false; return false;
} }
for (const auto &node_index_set : output_set_iter->second) { for (const auto &node_index_set : output_set_iter->second) {
if (!IsTargetNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) {
if (!IsBpropNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) {
return true; return true;
} }
} }
@@ -255,8 +244,8 @@ void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr
} }
} }


// Set 'recomputed' attr for the nodes according to its scope.
// A node set 'recomputed' attr can be the candidate recomputed node.
// Set the 'recompute' cnode attr on nodes according to their scope or their primitive's 'recompute' attr.
// A node with the 'recompute' cnode attr set can become a candidate recomputed node.
void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) { void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
auto mng = graph->manager(); auto mng = graph->manager();
@@ -264,44 +253,44 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map; std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map;
for (const auto &node : origin_nodes_topological) { for (const auto &node : origin_nodes_topological) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!WithRecomputedScope(node) || HasNoRecomputedAttr(node)) {
if (IsBpropNode(node)) {
continue; continue;
} }
auto prim = GetCNodePrimitive(node);
if (prim == nullptr || prim->name() == prim::kPrimTupleGetItem->name() ||
prim->name() == prim::kPrimAllGather->name()) {
// Do not recompute the communication op.
if (IsPrimitiveCNode(node, prim::kPrimAllGather)) {
continue;
}
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
continue; continue;
} }
if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) { if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) {
continue; continue;
} }


// Make a new primitive to set attr because some nodes share the same primitive probably.
auto new_prim = std::make_shared<Primitive>(prim->name());
new_prim->SetAttrs(prim->attrs());
new_prim->set_prim_type(prim->prim_type());
new_prim->set_attr(kAttrRecomputed, MakeValue(true));
std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)};
const auto &origin_inputs = node->inputs();
std::copy(origin_inputs.begin() + 1, origin_inputs.end(), std::back_inserter(new_inputs));
auto new_node = graph->NewCNode(new_inputs);
new_node->set_abstract(node->abstract());
new_node->set_scope(node->scope());
mng->Replace(node, new_node);

auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
continue;
}
auto prim_recompute_attr = prim->GetAttr(kAttrRecompute);
int prim_recompute_val = -1;
if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>()) {
prim_recompute_val = GetValue<bool>(prim_recompute_attr);
}
if ((WithRecomputedScope(node) && prim_recompute_val != 0) || prim_recompute_val == 1) {
cnode->AddAttr(kAttrRecompute, MakeValue(true));
}
if (!HasRecomputeCNodeAttr(node)) {
continue;
}
// Set attr for the tuple_getitem outputs. // Set attr for the tuple_getitem outputs.
std::vector<AnfNodePtr> tuple_getitem_output_nodes; std::vector<AnfNodePtr> tuple_getitem_output_nodes;
GetTupleGetItemOutputNodes(mng, new_node, &tuple_getitem_output_nodes);
GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes);
for (const auto &output_node : tuple_getitem_output_nodes) { for (const auto &output_node : tuple_getitem_output_nodes) {
auto new_output_prim = std::make_shared<Primitive>(prim::kPrimTupleGetItem->name());
new_output_prim->set_attr(kAttrRecomputed, MakeValue(true));
std::vector<AnfNodePtr> new_tuple_getitem_inputs{NewValueNode(new_output_prim)};
auto origin_tuple_getitem_inputs = output_node->cast<CNodePtr>()->inputs();
std::copy(origin_tuple_getitem_inputs.begin() + 1, origin_tuple_getitem_inputs.end(),
std::back_inserter(new_tuple_getitem_inputs));
auto new_tuple_getitem = graph->NewCNode(new_tuple_getitem_inputs);
new_tuple_getitem->set_abstract(output_node->abstract());
mng->Replace(output_node, new_tuple_getitem);
auto output_cnode = output_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
output_cnode->AddAttr(kAttrRecompute, MakeValue(true));
} }
} }
} }
@@ -318,15 +307,9 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
return iter->second; return iter->second;
} }
MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString(); MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString();
auto prim = GetCNodePrimitive(origin_node);
MS_EXCEPTION_IF_NULL(prim);
auto new_prim = std::make_shared<Primitive>(prim->name());
new_prim->SetAttrs(prim->attrs());
new_prim->set_attr("duplicated", MakeValue(true));
new_prim->set_prim_type(prim->prim_type());
std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)};
std::vector<AnfNodePtr> new_inputs;
bool has_recomputed_inputs = false; bool has_recomputed_inputs = false;
for (size_t i = 1; i < origin_node->size(); ++i) {
for (size_t i = 0; i < origin_node->size(); ++i) {
auto input = origin_node->input(i); auto input = origin_node->input(i);
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
if (!input->isa<CNode>()) { if (!input->isa<CNode>()) {
@@ -356,6 +339,8 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
new_inputs[1] = depend_node; new_inputs[1] = depend_node;
} }
auto recomputed_node = graph->NewCNode(new_inputs); auto recomputed_node = graph->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(recomputed_node);
recomputed_node->AddAttr("duplicated", MakeValue(true));
recomputed_node->set_abstract(origin_node->abstract()); recomputed_node->set_abstract(origin_node->abstract());
recomputed_node->set_scope(origin_node->scope()); recomputed_node->set_scope(origin_node->scope());
origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node)); origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node));
@@ -374,10 +359,6 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_se
MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input"; MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input";
auto target_cnode = target_node->cast<CNodePtr>(); auto target_cnode = target_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(target_cnode); MS_EXCEPTION_IF_NULL(target_cnode);
auto prim = GetCNodePrimitive(target_cnode);
if (prim != nullptr) {
prim->set_attr("target_grad", MakeValue(true));
}
std::vector<AnfNodePtr> new_target_inputs; std::vector<AnfNodePtr> new_target_inputs;
for (const auto &input : target_cnode->inputs()) { for (const auto &input : target_cnode->inputs()) {
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
@@ -394,6 +375,7 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_se
} }
} }
auto new_target_node = graph->NewCNode(new_target_inputs); auto new_target_node = graph->NewCNode(new_target_inputs);
new_target_node->AddAttr("target_grad", MakeValue(true));
new_target_node->set_abstract(target_node->abstract()); new_target_node->set_abstract(target_node->abstract());
new_target_node->set_scope(target_node->scope()); new_target_node->set_scope(target_node->scope());
mng->Replace(target_node, new_target_node); mng->Replace(target_node, new_target_node);
@@ -405,11 +387,9 @@ void InsertRecomputedNodes(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
auto mng = graph->manager(); auto mng = graph->manager();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
std::list<CNodePtr> old_orders = graph->GetOrderedCnodes();
std::vector<CNodePtr> old_nodes_topological(old_orders.begin(), old_orders.end());
SetRecomputedAttr(graph, old_nodes_topological);
std::list<CNodePtr> new_orders = graph->GetOrderedCnodes();
std::vector<CNodePtr> origin_nodes_topological(new_orders.begin(), new_orders.end());
std::list<CNodePtr> orders = graph->GetOrderedCnodes();
std::vector<CNodePtr> origin_nodes_topological(orders.begin(), orders.end());
SetRecomputedAttr(graph, origin_nodes_topological);
// Get candidate origin recomputed nodes which have no grad inputs and output to at least one grad node directly. // Get candidate origin recomputed nodes which have no grad inputs and output to at least one grad node directly.
std::vector<CNodePtr> candidate_recomputed_nodes = FindCandidateRecomputedNodes(mng, origin_nodes_topological); std::vector<CNodePtr> candidate_recomputed_nodes = FindCandidateRecomputedNodes(mng, origin_nodes_topological);
std::unordered_set<CNodePtr> visited_nodes; std::unordered_set<CNodePtr> visited_nodes;


+ 16
- 0
mindspore/core/ir/anf.h View File

@@ -267,6 +267,21 @@ class CNode : public AnfNode {


VarPtr func_graph_as_var() const { return func_graph_as_var_; } VarPtr func_graph_as_var() const { return func_graph_as_var_; }


// Read-only view of every attr stored on this cnode, keyed by attr name.
const std::unordered_map<std::string, ValuePtr> &attrs() const {
  return attrs_;
}
void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
attrs_[attr.first] = attr.second;
}
}

// Stores `attr` under `name`, replacing any value previously stored under that name.
void AddAttr(const std::string &name, const ValuePtr &attr) {
  auto &slot = attrs_[name];
  slot = attr;
}
// Removes the attr stored under `name`; does nothing when no such attr exists.
void EraseAttr(const std::string &name) {
  const auto target = attrs_.find(name);
  if (target != attrs_.end()) {
    (void)attrs_.erase(target);
  }
}
// Returns the value stored under `name`, or nullptr when this cnode has no such attr.
ValuePtr GetAttr(const std::string &name) const {
  const auto found = attrs_.find(name);
  if (found != attrs_.cend()) {
    return found->second;
  }
  return nullptr;
}
// Reports whether an attr is stored under `name` (a stored null value still counts).
bool HasAttr(const std::string &name) const {
  return attrs_.count(name) != 0;
}

private: private:
std::vector<AnfNodePtr> inputs_; std::vector<AnfNodePtr> inputs_;
VarPtr func_graph_as_var_; VarPtr func_graph_as_var_;
@@ -276,6 +291,7 @@ class CNode : public AnfNode {
// output_value_ store cnode value and id in pynative mode // output_value_ store cnode value and id in pynative mode
std::vector<std::pair<ValuePtr, std::string>> inputs_value_; std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
std::pair<ValuePtr, std::string> output_value_; std::pair<ValuePtr, std::string> output_value_;
std::unordered_map<std::string, ValuePtr> attrs_;
}; };


// ANode represents the atomic node. It's derived Parameter and ValueNode. // ANode represents the atomic node. It's derived Parameter and ValueNode.


+ 1
- 0
mindspore/core/ir/func_graph_cloner.cc View File

@@ -90,6 +90,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
new_node->set_abstract(old_node->abstract()); new_node->set_abstract(old_node->abstract());
new_node->set_forward(old_node->forward().first, old_node->forward().second); new_node->set_forward(old_node->forward().first, old_node->forward().second);
new_node->set_inputs_value(old_node->inputs_value()); new_node->set_inputs_value(old_node->inputs_value());
new_node->set_attrs(old_node->attrs());
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_node->set_scope(scope); new_node->set_scope(scope);
if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) { if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) {


+ 8
- 3
mindspore/nn/cell.py View File

@@ -1095,14 +1095,19 @@ class Cell(Cell_):
param.comm_fusion = fusion_type param.comm_fusion = fusion_type
return self return self


def recompute(self):
def recompute(self, mode=True):
""" """
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive feeds into a grad Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive feeds into a grad
node and is set recomputed, we will compute it again for the grad node after the forward computation. node and is set recomputed, we will compute it again for the grad node after the forward computation.
Args:
mode (bool): Specifies whether the cell is recomputed. Default: True.
""" """
self._set_scope('recomputed')
if mode is True:
self._set_scope("recompute")
else:
self._set_scope("no_recompute")
for cell in self.cells(): for cell in self.cells():
cell.recompute()
cell.recompute(mode)




class GraphKernel(Cell): class GraphKernel(Cell):


+ 10
- 0
mindspore/ops/primitive.py View File

@@ -206,6 +206,16 @@ class Primitive(Primitive_):
""" Whether the primitive will update the value of parameter.""" """ Whether the primitive will update the value of parameter."""
return self._update_parameter return self._update_parameter


def recompute(self, mode=True):
    """
    Set the primitive recomputed. If a primitive that is set recomputed feeds into a grad node,
    it will be computed again for the grad node after the forward computation.

    Args:
        mode (bool): Specifies whether the primitive is recomputed. Default: True.

    Returns:
        The primitive itself, so that calls can be chained.
    """
    # The flag is stored as the "recompute" primitive attr.
    self.add_prim_attr("recompute", mode)
    return self



class PrimitiveWithCheck(Primitive): class PrimitiveWithCheck(Primitive):
""" """


Loading…
Cancel
Save