
!10952 Add cnode attrs for recomputation

From: @ginfung
Reviewed-by: @zh_qh
Signed-off-by: @zh_qh
tags/v1.2.0-rc1
mindspore-ci-bot committed 4 years ago, commit 07ab1c1b4e
6 changed files with 120 additions and 79 deletions

  1. mindspore/ccsrc/debug/anf_ir_dump.cc            (+30, -1)
  2. mindspore/ccsrc/frontend/optimizer/recompute.cc (+55, -75)
  3. mindspore/core/ir/anf.h                         (+16, -0)
  4. mindspore/core/ir/func_graph_cloner.cc          (+1, -0)
  5. mindspore/nn/cell.py                            (+8, -3)
  6. mindspore/ops/primitive.py                      (+10, -0)

mindspore/ccsrc/debug/anf_ir_dump.cc (+30, -1)

@@ -313,7 +313,7 @@ void DumpOperateAttrs(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo
   }
   auto attrs = primitive->attrs();
   if (!attrs.empty()) {
-    gsub->buffer << " {";
+    gsub->buffer << " primitive_attrs: {";
     int i = 0;
     for (const auto &attr : attrs) {
       if (attr.first == PARALLEL_STRATEGY) {
@@ -332,6 +332,32 @@ void DumpOperateAttrs(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo
       gsub->buffer << "}";
     }
   }
 }
 
+void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &gsub) {
+  if (op == nullptr || gsub == nullptr) {
+    return;
+  }
+  if (op->attrs().empty()) {
+    gsub->buffer << std::endl;
+    return;
+  }
+
+  auto attrs = op->attrs();
+  gsub->buffer << " cnode_attrs: {";
+  int i = 0;
+  for (const auto &attr : attrs) {
+    if (i++ != 0) {
+      gsub->buffer << ", ";
+    }
+    gsub->buffer << attr.first << ": ";
+    if (attr.second == nullptr) {
+      gsub->buffer << "null";
+    } else {
+      gsub->buffer << attr.second->ToString();
+    }
+  }
+  gsub->buffer << "}";
+  gsub->buffer << std::endl;
+}
 
@@ -384,6 +410,9 @@ void DumpCNode(const CNodePtr &nd, const FuncGraphPtr &sub_graph, OrderedMap<Anf
   // print operator attrs
   DumpOperateAttrs(op, gsub);
 
+  // print cnode attrs
+  DumpCNodeAttrs(nd, gsub);
+
   // print parallel info
   DumpParallelInfo(nd, gsub);
 


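Note: with this change a dumped CNode carries two attribute groups: the primitive's attrs are labeled 'primitive_attrs: {...}', and any node-level attrs follow as a separate 'cnode_attrs: {...}' group, e.g. 'cnode_attrs: {recompute: true}' for nodes marked by the pass below (format inferred from DumpCNodeAttrs above, not verbatim tool output).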
mindspore/ccsrc/frontend/optimizer/recompute.cc (+55, -75)

@@ -30,9 +30,9 @@ namespace mindspore {
 namespace opt {
 namespace {
 constexpr auto kGradientsFlag = "Gradients";
-constexpr auto kAttrRecomputed = "recomputed";
-constexpr auto kAttrNoRecomputed = "no_recomputed";
-bool IsTargetNode(const AnfNodePtr &node) {
+constexpr auto kAttrRecompute = "recompute";
+constexpr auto kAttrNoRecompute = "no_recompute";
+bool IsBpropNode(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (!node->isa<CNode>()) {
     return false;
@@ -40,37 +40,26 @@ bool IsTargetNode(const AnfNodePtr &node) {
   return node->fullname_with_scope().find(kGradientsFlag) == 0;
 }
 
-bool HasNoRecomputedAttr(const AnfNodePtr &node) {
-  auto prim = GetCNodePrimitive(node);
-  if (prim != nullptr) {
-    auto no_recompute_val = prim->GetAttr(kAttrNoRecomputed);
-    if (no_recompute_val != nullptr && no_recompute_val->isa<BoolImm>()) {
-      return GetValue<bool>(no_recompute_val);
-    }
-  }
-  return false;
-}
-
 bool WithRecomputedScope(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (!node->isa<CNode>()) {
     return false;
   }
-  return node->fullname_with_scope().find(kAttrRecomputed) == 0;
+  auto full_name_with_scope = node->fullname_with_scope();
+  return full_name_with_scope.find(kAttrRecompute) == 0 &&
+         full_name_with_scope.find(kAttrNoRecompute) == full_name_with_scope.npos;
 }
 
-bool IsSetRecomputed(const AnfNodePtr &node) {
-  auto prim = GetCNodePrimitive(node);
-  if (prim != nullptr) {
-    auto recompute_val = prim->GetAttr(kAttrRecomputed);
-    if (recompute_val != nullptr && recompute_val->isa<BoolImm>()) {
-      return GetValue<bool>(recompute_val);
-    }
-  }
-  return false;
-}
+bool HasRecomputeCNodeAttr(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode == nullptr) {
+    return false;
+  }
+  auto cnode_recompute_val = cnode->GetAttr(kAttrRecompute);
+  return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val);
+}
 
-bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsTargetNode(node) && IsSetRecomputed(node); }
+bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && HasRecomputeCNodeAttr(node); }
 
 std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
                                                    const std::vector<CNodePtr> &cnodes) {
@@ -89,12 +78,12 @@ std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mn
     }
     const auto &node_index_set = output_set_iter->second;
     if (!std::any_of(node_index_set.begin(), node_index_set.end(),
-                     [](const std::pair<AnfNodePtr, int> &node_index) { return IsTargetNode(node_index.first); })) {
+                     [](const std::pair<AnfNodePtr, int> &node_index) { return IsBpropNode(node_index.first); })) {
       continue;
     }
     // Check inputs.
     const auto &inputs = cnode->inputs();
-    if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsTargetNode(node); })) {
+    if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsBpropNode(node); })) {
       continue;
     }
     candidate_recomputed_nodes.emplace_back(cnode);
@@ -166,7 +155,7 @@ void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng,
   for (const auto &node_index_set : output_set_iter->second) {
     auto output_node = node_index_set.first;
     MS_EXCEPTION_IF_NULL(output_node);
-    if (!IsTargetNode(output_node)) {
+    if (!IsBpropNode(output_node)) {
       continue;
     }
     target_nodes->insert(output_node->cast<CNodePtr>());
@@ -215,7 +204,7 @@ bool HasGradInputs(const AnfNodePtr &node, std::unordered_map<AnfNodePtr, bool>
   }
   const auto &inputs = cnode->inputs();
   if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) {
-        return IsTargetNode(input) || HasGradInputs(input, has_grad_inputs_map);
+        return IsBpropNode(input) || HasGradInputs(input, has_grad_inputs_map);
       })) {
     has_grad_inputs_map->insert(std::make_pair(node, true));
     return true;
@@ -232,7 +221,7 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
     return false;
   }
   for (const auto &node_index_set : output_set_iter->second) {
-    if (!IsTargetNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) {
+    if (!IsBpropNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) {
       return true;
     }
   }
@@ -255,8 +244,8 @@ void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr
   }
 }
 
-// Set 'recomputed' attr for the nodes according to its scope.
-// A node set 'recomputed' attr can be the candidate recomputed node.
+// Set 'recompute' cnode attr for the nodes according to its scope.
+// A node set 'recompute' cnode attr can become the candidate recomputed node.
 void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) {
   MS_EXCEPTION_IF_NULL(graph);
   auto mng = graph->manager();
@@ -264,44 +253,44 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
   std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map;
   for (const auto &node : origin_nodes_topological) {
     MS_EXCEPTION_IF_NULL(node);
-    if (!WithRecomputedScope(node) || HasNoRecomputedAttr(node)) {
+    if (IsBpropNode(node)) {
       continue;
     }
-    auto prim = GetCNodePrimitive(node);
-    if (prim == nullptr || prim->name() == prim::kPrimTupleGetItem->name() ||
-        prim->name() == prim::kPrimAllGather->name()) {
+    // Do not recompute the communicate op.
+    if (IsPrimitiveCNode(node, prim::kPrimAllGather)) {
       continue;
     }
+    if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
+      continue;
+    }
     if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) {
       continue;
     }
 
-    // Make a new primitive to set attr because some nodes share the same primitive probably.
-    auto new_prim = std::make_shared<Primitive>(prim->name());
-    new_prim->SetAttrs(prim->attrs());
-    new_prim->set_prim_type(prim->prim_type());
-    new_prim->set_attr(kAttrRecomputed, MakeValue(true));
-    std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)};
-    const auto &origin_inputs = node->inputs();
-    std::copy(origin_inputs.begin() + 1, origin_inputs.end(), std::back_inserter(new_inputs));
-    auto new_node = graph->NewCNode(new_inputs);
-    new_node->set_abstract(node->abstract());
-    new_node->set_scope(node->scope());
-    mng->Replace(node, new_node);
-
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    auto prim = GetCNodePrimitive(cnode);
+    if (prim == nullptr) {
+      continue;
+    }
+    auto prim_recompute_attr = prim->GetAttr(kAttrRecompute);
+    int prim_recompute_val = -1;
+    if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>()) {
+      prim_recompute_val = GetValue<bool>(prim_recompute_attr);
+    }
+    if ((WithRecomputedScope(node) && prim_recompute_val != 0) || prim_recompute_val == 1) {
+      cnode->AddAttr(kAttrRecompute, MakeValue(true));
+    }
+    if (!HasRecomputeCNodeAttr(node)) {
+      continue;
+    }
     // Set attr for the tuple_getitem outputs.
     std::vector<AnfNodePtr> tuple_getitem_output_nodes;
-    GetTupleGetItemOutputNodes(mng, new_node, &tuple_getitem_output_nodes);
+    GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes);
     for (const auto &output_node : tuple_getitem_output_nodes) {
-      auto new_output_prim = std::make_shared<Primitive>(prim::kPrimTupleGetItem->name());
-      new_output_prim->set_attr(kAttrRecomputed, MakeValue(true));
-      std::vector<AnfNodePtr> new_tuple_getitem_inputs{NewValueNode(new_output_prim)};
-      auto origin_tuple_getitem_inputs = output_node->cast<CNodePtr>()->inputs();
-      std::copy(origin_tuple_getitem_inputs.begin() + 1, origin_tuple_getitem_inputs.end(),
-                std::back_inserter(new_tuple_getitem_inputs));
-      auto new_tuple_getitem = graph->NewCNode(new_tuple_getitem_inputs);
-      new_tuple_getitem->set_abstract(output_node->abstract());
-      mng->Replace(output_node, new_tuple_getitem);
+      auto output_cnode = output_node->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(output_cnode);
+      output_cnode->AddAttr(kAttrRecompute, MakeValue(true));
     }
   }
 }
@@ -318,15 +307,9 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
     return iter->second;
   }
   MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString();
-  auto prim = GetCNodePrimitive(origin_node);
-  MS_EXCEPTION_IF_NULL(prim);
-  auto new_prim = std::make_shared<Primitive>(prim->name());
-  new_prim->SetAttrs(prim->attrs());
-  new_prim->set_attr("duplicated", MakeValue(true));
-  new_prim->set_prim_type(prim->prim_type());
-  std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)};
+  std::vector<AnfNodePtr> new_inputs;
   bool has_recomputed_inputs = false;
-  for (size_t i = 1; i < origin_node->size(); ++i) {
+  for (size_t i = 0; i < origin_node->size(); ++i) {
     auto input = origin_node->input(i);
     MS_EXCEPTION_IF_NULL(input);
     if (!input->isa<CNode>()) {
@@ -356,6 +339,8 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
     new_inputs[1] = depend_node;
   }
   auto recomputed_node = graph->NewCNode(new_inputs);
+  MS_EXCEPTION_IF_NULL(recomputed_node);
+  recomputed_node->AddAttr("duplicated", MakeValue(true));
   recomputed_node->set_abstract(origin_node->abstract());
   recomputed_node->set_scope(origin_node->scope());
   origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node));
@@ -374,10 +359,6 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_se
     MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input";
     auto target_cnode = target_node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(target_cnode);
-    auto prim = GetCNodePrimitive(target_cnode);
-    if (prim != nullptr) {
-      prim->set_attr("target_grad", MakeValue(true));
-    }
     std::vector<AnfNodePtr> new_target_inputs;
     for (const auto &input : target_cnode->inputs()) {
       MS_EXCEPTION_IF_NULL(input);
@@ -394,6 +375,7 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_se
       }
     }
     auto new_target_node = graph->NewCNode(new_target_inputs);
+    new_target_node->AddAttr("target_grad", MakeValue(true));
    new_target_node->set_abstract(target_node->abstract());
    new_target_node->set_scope(target_node->scope());
    mng->Replace(target_node, new_target_node);
@@ -405,11 +387,9 @@ void InsertRecomputedNodes(const FuncGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto mng = graph->manager();
   MS_EXCEPTION_IF_NULL(mng);
-  std::list<CNodePtr> old_orders = graph->GetOrderedCnodes();
-  std::vector<CNodePtr> old_nodes_topological(old_orders.begin(), old_orders.end());
-  SetRecomputedAttr(graph, old_nodes_topological);
-  std::list<CNodePtr> new_orders = graph->GetOrderedCnodes();
-  std::vector<CNodePtr> origin_nodes_topological(new_orders.begin(), new_orders.end());
+  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
+  std::vector<CNodePtr> origin_nodes_topological(orders.begin(), orders.end());
+  SetRecomputedAttr(graph, origin_nodes_topological);
   // Get candidate origin recomputed nodes which have no grad inputs and output to at least one grad node directly.
   std::vector<CNodePtr> candidate_recomputed_nodes = FindCandidateRecomputedNodes(mng, origin_nodes_topological);
   std::unordered_set<CNodePtr> visited_nodes;
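The precedence SetRecomputedAttr applies between the Cell-level scope and the Primitive-level attr can be summarized as follows; a minimal Python sketch of just that decision rule (illustrative names, not MindSpore API):

    def should_mark_recompute(in_recompute_scope, prim_recompute):
        # in_recompute_scope: the node's fullname starts with the "recompute" scope.
        # prim_recompute: None if the primitive carries no 'recompute' attr,
        # otherwise the bool passed to Primitive.recompute(mode).
        #
        # The scope set by Cell.recompute() applies unless the primitive
        # explicitly opted out with recompute(False)...
        if in_recompute_scope and prim_recompute is not False:
            return True
        # ...and an explicit Primitive.recompute(True) wins even without the scope.
        return prim_recompute is True

In the C++ above, prim_recompute_val encodes the same three states as -1 (unset), 0 (false) and 1 (true).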


mindspore/core/ir/anf.h (+16, -0)

@@ -267,6 +267,21 @@ class CNode : public AnfNode {
 
   VarPtr func_graph_as_var() const { return func_graph_as_var_; }
 
+  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
+  void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
+    for (auto &attr : attrs) {
+      attrs_[attr.first] = attr.second;
+    }
+  }
+
+  void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
+  void EraseAttr(const std::string &name) { (void)attrs_.erase(name); }
+  ValuePtr GetAttr(const std::string &name) const {
+    auto iter = attrs_.find(name);
+    return iter == attrs_.cend() ? nullptr : iter->second;
+  }
+  bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.cend(); }
+
 private:
   std::vector<AnfNodePtr> inputs_;
   VarPtr func_graph_as_var_;
@@ -276,6 +291,7 @@ class CNode : public AnfNode {
   // output_value_ store cnode value and id in pynative mode
   std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
   std::pair<ValuePtr, std::string> output_value_;
+  std::unordered_map<std::string, ValuePtr> attrs_;
 };
 
 // ANode represents the atomic node. It's derived Parameter and ValueNode.
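Unlike Primitive attrs, which are shared by every node referencing the same Primitive instance, this attr table lives on the individual CNode. That is what lets the recompute pass above tag single nodes with 'recompute', 'duplicated' and 'target_grad' without first cloning their primitives, as the old code had to.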


mindspore/core/ir/func_graph_cloner.cc (+1, -0)

@@ -90,6 +90,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
   new_node->set_abstract(old_node->abstract());
   new_node->set_forward(old_node->forward().first, old_node->forward().second);
   new_node->set_inputs_value(old_node->inputs_value());
+  new_node->set_attrs(old_node->attrs());
   ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
   new_node->set_scope(scope);
   if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) {
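Copying the attr table in CloneCNode keeps node-level markers such as 'recompute' intact when a func graph is cloned, so they survive later pipeline stages.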


mindspore/nn/cell.py (+8, -3)

@@ -1095,14 +1095,19 @@ class Cell(Cell_):
             param.comm_fusion = fusion_type
         return self
 
-    def recompute(self):
+    def recompute(self, mode=True):
         """
         Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive feeds into a grad
         node and is set recomputed, we will compute it again for the grad node after the forward computation.
+        Args:
+            mode (bool): Specifies whether the cell is recomputed. Default: True.
         """
-        self._set_scope('recomputed')
+        if mode is True:
+            self._set_scope("recompute")
+        else:
+            self._set_scope("no_recompute")
         for cell in self.cells():
-            cell.recompute()
+            cell.recompute(mode)
 
 
 class GraphKernel(Cell):
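At the user level, the new mode parameter makes the scope switch explicit. A minimal usage sketch, assuming only the r1.2-era mindspore.nn layers; Block and Net are illustrative names:

    import mindspore.nn as nn

    class Block(nn.Cell):
        def __init__(self):
            super(Block, self).__init__()
            self.dense = nn.Dense(16, 16)
            self.relu = nn.ReLU()

        def construct(self, x):
            return self.relu(self.dense(x))

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.block = Block()
            # Give every cell under the block the "recompute" scope; the
            # SetRecomputedAttr pass then turns that scope into 'recompute'
            # cnode attrs on the block's forward nodes.
            self.block.recompute()
            # self.block.recompute(mode=False) would set "no_recompute" instead.

        def construct(self, x):
            return self.block(x)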


mindspore/ops/primitive.py (+10, -0)

@@ -206,6 +206,16 @@ class Primitive(Primitive_):
         """ Whether the primitive will update the value of parameter."""
         return self._update_parameter
 
+    def recompute(self, mode):
+        """
+        Set the primitive recomputed. If a primitive feeds into a grad node and is set recomputed,
+        we will compute it again for the grad node after the forward computation.
+        Args:
+            mode (bool): Specifies whether the primitive is recomputed. Default: True.
+        """
+        self.add_prim_attr("recompute", mode)
+        return self
+
 
 class PrimitiveWithCheck(Primitive):
     """

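The primitive-level switch can also be applied directly; since recompute returns self, it chains onto construction. A small sketch (assuming the operations namespace of the period; note that, unlike Cell.recompute, mode has no default here despite the docstring):

    from mindspore.ops import operations as P

    matmul = P.MatMul().recompute(True)  # always recompute this op for the grads
    relu = P.ReLU().recompute(False)     # opt out, even inside a recompute scope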
