|
|
@@ -30,9 +30,9 @@ namespace mindspore { |
|
|
namespace opt { |
|
|
namespace opt { |
|
|
namespace { |
|
|
namespace { |
|
|
constexpr auto kGradientsFlag = "Gradients"; |
|
|
constexpr auto kGradientsFlag = "Gradients"; |
|
|
constexpr auto kAttrRecomputed = "recomputed"; |
|
|
|
|
|
constexpr auto kAttrNoRecomputed = "no_recomputed"; |
|
|
|
|
|
bool IsTargetNode(const AnfNodePtr &node) { |
|
|
|
|
|
|
|
|
constexpr auto kAttrRecompute = "recompute"; |
|
|
|
|
|
constexpr auto kAttrNoRecompute = "no_recompute"; |
|
|
|
|
|
bool IsBpropNode(const AnfNodePtr &node) { |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
if (!node->isa<CNode>()) { |
|
|
if (!node->isa<CNode>()) { |
|
|
return false; |
|
|
return false; |
|
|
@@ -40,37 +40,26 @@ bool IsTargetNode(const AnfNodePtr &node) { |
|
|
return node->fullname_with_scope().find(kGradientsFlag) == 0; |
|
|
return node->fullname_with_scope().find(kGradientsFlag) == 0; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool HasNoRecomputedAttr(const AnfNodePtr &node) { |
|
|
|
|
|
auto prim = GetCNodePrimitive(node); |
|
|
|
|
|
if (prim != nullptr) { |
|
|
|
|
|
auto no_recompute_val = prim->GetAttr(kAttrNoRecomputed); |
|
|
|
|
|
if (no_recompute_val != nullptr && no_recompute_val->isa<BoolImm>()) { |
|
|
|
|
|
return GetValue<bool>(no_recompute_val); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool WithRecomputedScope(const AnfNodePtr &node) { |
|
|
bool WithRecomputedScope(const AnfNodePtr &node) { |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
if (!node->isa<CNode>()) { |
|
|
if (!node->isa<CNode>()) { |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
return node->fullname_with_scope().find(kAttrRecomputed) == 0; |
|
|
|
|
|
|
|
|
auto full_name_with_scope = node->fullname_with_scope(); |
|
|
|
|
|
return full_name_with_scope.find(kAttrRecompute) == 0 && |
|
|
|
|
|
full_name_with_scope.find(kAttrNoRecompute) == full_name_with_scope.npos; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool IsSetRecomputed(const AnfNodePtr &node) { |
|
|
|
|
|
auto prim = GetCNodePrimitive(node); |
|
|
|
|
|
if (prim != nullptr) { |
|
|
|
|
|
auto recompute_val = prim->GetAttr(kAttrRecomputed); |
|
|
|
|
|
if (recompute_val != nullptr && recompute_val->isa<BoolImm>()) { |
|
|
|
|
|
return GetValue<bool>(recompute_val); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
bool HasRecomputeCNodeAttr(const AnfNodePtr &node) { |
|
|
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
|
|
if (cnode == nullptr) { |
|
|
|
|
|
return false; |
|
|
} |
|
|
} |
|
|
return false; |
|
|
|
|
|
|
|
|
auto cnode_recompute_val = cnode->GetAttr(kAttrRecompute); |
|
|
|
|
|
return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsTargetNode(node) && IsSetRecomputed(node); } |
|
|
|
|
|
|
|
|
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && HasRecomputeCNodeAttr(node); } |
|
|
|
|
|
|
|
|
std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng, |
|
|
std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng, |
|
|
const std::vector<CNodePtr> &cnodes) { |
|
|
const std::vector<CNodePtr> &cnodes) { |
|
|
@@ -89,12 +78,12 @@ std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mn |
|
|
} |
|
|
} |
|
|
const auto &node_index_set = output_set_iter->second; |
|
|
const auto &node_index_set = output_set_iter->second; |
|
|
if (!std::any_of(node_index_set.begin(), node_index_set.end(), |
|
|
if (!std::any_of(node_index_set.begin(), node_index_set.end(), |
|
|
[](const std::pair<AnfNodePtr, int> &node_index) { return IsTargetNode(node_index.first); })) { |
|
|
|
|
|
|
|
|
[](const std::pair<AnfNodePtr, int> &node_index) { return IsBpropNode(node_index.first); })) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
// Check inputs. |
|
|
// Check inputs. |
|
|
const auto &inputs = cnode->inputs(); |
|
|
const auto &inputs = cnode->inputs(); |
|
|
if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsTargetNode(node); })) { |
|
|
|
|
|
|
|
|
if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsBpropNode(node); })) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
candidate_recomputed_nodes.emplace_back(cnode); |
|
|
candidate_recomputed_nodes.emplace_back(cnode); |
|
|
@@ -166,7 +155,7 @@ void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng, |
|
|
for (const auto &node_index_set : output_set_iter->second) { |
|
|
for (const auto &node_index_set : output_set_iter->second) { |
|
|
auto output_node = node_index_set.first; |
|
|
auto output_node = node_index_set.first; |
|
|
MS_EXCEPTION_IF_NULL(output_node); |
|
|
MS_EXCEPTION_IF_NULL(output_node); |
|
|
if (!IsTargetNode(output_node)) { |
|
|
|
|
|
|
|
|
if (!IsBpropNode(output_node)) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
target_nodes->insert(output_node->cast<CNodePtr>()); |
|
|
target_nodes->insert(output_node->cast<CNodePtr>()); |
|
|
@@ -215,7 +204,7 @@ bool HasGradInputs(const AnfNodePtr &node, std::unordered_map<AnfNodePtr, bool> |
|
|
} |
|
|
} |
|
|
const auto &inputs = cnode->inputs(); |
|
|
const auto &inputs = cnode->inputs(); |
|
|
if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) { |
|
|
if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) { |
|
|
return IsTargetNode(input) || HasGradInputs(input, has_grad_inputs_map); |
|
|
|
|
|
|
|
|
return IsBpropNode(input) || HasGradInputs(input, has_grad_inputs_map); |
|
|
})) { |
|
|
})) { |
|
|
has_grad_inputs_map->insert(std::make_pair(node, true)); |
|
|
has_grad_inputs_map->insert(std::make_pair(node, true)); |
|
|
return true; |
|
|
return true; |
|
|
@@ -232,7 +221,7 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) { |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
for (const auto &node_index_set : output_set_iter->second) { |
|
|
for (const auto &node_index_set : output_set_iter->second) { |
|
|
if (!IsTargetNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) { |
|
|
|
|
|
|
|
|
if (!IsBpropNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) { |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
@@ -255,8 +244,8 @@ void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Set 'recomputed' attr for the nodes according to its scope. |
|
|
|
|
|
// A node set 'recomputed' attr can be the candidate recomputed node. |
|
|
|
|
|
|
|
|
// Set 'recompute' cnode attr for the nodes according to its scope. |
|
|
|
|
|
// A node set 'recompute' cnode attr can become the candidate recomputed node. |
|
|
void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) { |
|
|
void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) { |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
auto mng = graph->manager(); |
|
|
auto mng = graph->manager(); |
|
|
@@ -264,44 +253,44 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o |
|
|
std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map; |
|
|
std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map; |
|
|
for (const auto &node : origin_nodes_topological) { |
|
|
for (const auto &node : origin_nodes_topological) { |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
if (!WithRecomputedScope(node) || HasNoRecomputedAttr(node)) { |
|
|
|
|
|
|
|
|
if (IsBpropNode(node)) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
auto prim = GetCNodePrimitive(node); |
|
|
|
|
|
if (prim == nullptr || prim->name() == prim::kPrimTupleGetItem->name() || |
|
|
|
|
|
prim->name() == prim::kPrimAllGather->name()) { |
|
|
|
|
|
|
|
|
// Do not recompute the communicate op. |
|
|
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimAllGather)) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) { |
|
|
if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Make a new primitive to set attr because some nodes share the same primitive probably. |
|
|
|
|
|
auto new_prim = std::make_shared<Primitive>(prim->name()); |
|
|
|
|
|
new_prim->SetAttrs(prim->attrs()); |
|
|
|
|
|
new_prim->set_prim_type(prim->prim_type()); |
|
|
|
|
|
new_prim->set_attr(kAttrRecomputed, MakeValue(true)); |
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)}; |
|
|
|
|
|
const auto &origin_inputs = node->inputs(); |
|
|
|
|
|
std::copy(origin_inputs.begin() + 1, origin_inputs.end(), std::back_inserter(new_inputs)); |
|
|
|
|
|
auto new_node = graph->NewCNode(new_inputs); |
|
|
|
|
|
new_node->set_abstract(node->abstract()); |
|
|
|
|
|
new_node->set_scope(node->scope()); |
|
|
|
|
|
mng->Replace(node, new_node); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
|
|
auto prim = GetCNodePrimitive(cnode); |
|
|
|
|
|
if (prim == nullptr) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
auto prim_recompute_attr = prim->GetAttr(kAttrRecompute); |
|
|
|
|
|
int prim_recompute_val = -1; |
|
|
|
|
|
if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>()) { |
|
|
|
|
|
prim_recompute_val = GetValue<bool>(prim_recompute_attr); |
|
|
|
|
|
} |
|
|
|
|
|
if ((WithRecomputedScope(node) && prim_recompute_val != 0) || prim_recompute_val == 1) { |
|
|
|
|
|
cnode->AddAttr(kAttrRecompute, MakeValue(true)); |
|
|
|
|
|
} |
|
|
|
|
|
if (!HasRecomputeCNodeAttr(node)) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
// Set attr for the tuple_getitem outputs. |
|
|
// Set attr for the tuple_getitem outputs. |
|
|
std::vector<AnfNodePtr> tuple_getitem_output_nodes; |
|
|
std::vector<AnfNodePtr> tuple_getitem_output_nodes; |
|
|
GetTupleGetItemOutputNodes(mng, new_node, &tuple_getitem_output_nodes); |
|
|
|
|
|
|
|
|
GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes); |
|
|
for (const auto &output_node : tuple_getitem_output_nodes) { |
|
|
for (const auto &output_node : tuple_getitem_output_nodes) { |
|
|
auto new_output_prim = std::make_shared<Primitive>(prim::kPrimTupleGetItem->name()); |
|
|
|
|
|
new_output_prim->set_attr(kAttrRecomputed, MakeValue(true)); |
|
|
|
|
|
std::vector<AnfNodePtr> new_tuple_getitem_inputs{NewValueNode(new_output_prim)}; |
|
|
|
|
|
auto origin_tuple_getitem_inputs = output_node->cast<CNodePtr>()->inputs(); |
|
|
|
|
|
std::copy(origin_tuple_getitem_inputs.begin() + 1, origin_tuple_getitem_inputs.end(), |
|
|
|
|
|
std::back_inserter(new_tuple_getitem_inputs)); |
|
|
|
|
|
auto new_tuple_getitem = graph->NewCNode(new_tuple_getitem_inputs); |
|
|
|
|
|
new_tuple_getitem->set_abstract(output_node->abstract()); |
|
|
|
|
|
mng->Replace(output_node, new_tuple_getitem); |
|
|
|
|
|
|
|
|
auto output_cnode = output_node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_cnode); |
|
|
|
|
|
output_cnode->AddAttr(kAttrRecompute, MakeValue(true)); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
@@ -318,15 +307,9 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod |
|
|
return iter->second; |
|
|
return iter->second; |
|
|
} |
|
|
} |
|
|
MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString(); |
|
|
MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString(); |
|
|
auto prim = GetCNodePrimitive(origin_node); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
|
|
auto new_prim = std::make_shared<Primitive>(prim->name()); |
|
|
|
|
|
new_prim->SetAttrs(prim->attrs()); |
|
|
|
|
|
new_prim->set_attr("duplicated", MakeValue(true)); |
|
|
|
|
|
new_prim->set_prim_type(prim->prim_type()); |
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)}; |
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs; |
|
|
bool has_recomputed_inputs = false; |
|
|
bool has_recomputed_inputs = false; |
|
|
for (size_t i = 1; i < origin_node->size(); ++i) { |
|
|
|
|
|
|
|
|
for (size_t i = 0; i < origin_node->size(); ++i) { |
|
|
auto input = origin_node->input(i); |
|
|
auto input = origin_node->input(i); |
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
if (!input->isa<CNode>()) { |
|
|
if (!input->isa<CNode>()) { |
|
|
@@ -356,6 +339,8 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod |
|
|
new_inputs[1] = depend_node; |
|
|
new_inputs[1] = depend_node; |
|
|
} |
|
|
} |
|
|
auto recomputed_node = graph->NewCNode(new_inputs); |
|
|
auto recomputed_node = graph->NewCNode(new_inputs); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(recomputed_node); |
|
|
|
|
|
recomputed_node->AddAttr("duplicated", MakeValue(true)); |
|
|
recomputed_node->set_abstract(origin_node->abstract()); |
|
|
recomputed_node->set_abstract(origin_node->abstract()); |
|
|
recomputed_node->set_scope(origin_node->scope()); |
|
|
recomputed_node->set_scope(origin_node->scope()); |
|
|
origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node)); |
|
|
origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node)); |
|
|
@@ -374,10 +359,6 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_se |
|
|
MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input"; |
|
|
MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input"; |
|
|
auto target_cnode = target_node->cast<CNodePtr>(); |
|
|
auto target_cnode = target_node->cast<CNodePtr>(); |
|
|
MS_EXCEPTION_IF_NULL(target_cnode); |
|
|
MS_EXCEPTION_IF_NULL(target_cnode); |
|
|
auto prim = GetCNodePrimitive(target_cnode); |
|
|
|
|
|
if (prim != nullptr) { |
|
|
|
|
|
prim->set_attr("target_grad", MakeValue(true)); |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<AnfNodePtr> new_target_inputs; |
|
|
std::vector<AnfNodePtr> new_target_inputs; |
|
|
for (const auto &input : target_cnode->inputs()) { |
|
|
for (const auto &input : target_cnode->inputs()) { |
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
@@ -394,6 +375,7 @@ void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_se |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
auto new_target_node = graph->NewCNode(new_target_inputs); |
|
|
auto new_target_node = graph->NewCNode(new_target_inputs); |
|
|
|
|
|
new_target_node->AddAttr("target_grad", MakeValue(true)); |
|
|
new_target_node->set_abstract(target_node->abstract()); |
|
|
new_target_node->set_abstract(target_node->abstract()); |
|
|
new_target_node->set_scope(target_node->scope()); |
|
|
new_target_node->set_scope(target_node->scope()); |
|
|
mng->Replace(target_node, new_target_node); |
|
|
mng->Replace(target_node, new_target_node); |
|
|
@@ -405,11 +387,9 @@ void InsertRecomputedNodes(const FuncGraphPtr &graph) { |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
auto mng = graph->manager(); |
|
|
auto mng = graph->manager(); |
|
|
MS_EXCEPTION_IF_NULL(mng); |
|
|
MS_EXCEPTION_IF_NULL(mng); |
|
|
std::list<CNodePtr> old_orders = graph->GetOrderedCnodes(); |
|
|
|
|
|
std::vector<CNodePtr> old_nodes_topological(old_orders.begin(), old_orders.end()); |
|
|
|
|
|
SetRecomputedAttr(graph, old_nodes_topological); |
|
|
|
|
|
std::list<CNodePtr> new_orders = graph->GetOrderedCnodes(); |
|
|
|
|
|
std::vector<CNodePtr> origin_nodes_topological(new_orders.begin(), new_orders.end()); |
|
|
|
|
|
|
|
|
std::list<CNodePtr> orders = graph->GetOrderedCnodes(); |
|
|
|
|
|
std::vector<CNodePtr> origin_nodes_topological(orders.begin(), orders.end()); |
|
|
|
|
|
SetRecomputedAttr(graph, origin_nodes_topological); |
|
|
// Get candidate origin recomputed nodes which have no grad inputs and output to at least one grad node directly. |
|
|
// Get candidate origin recomputed nodes which have no grad inputs and output to at least one grad node directly. |
|
|
std::vector<CNodePtr> candidate_recomputed_nodes = FindCandidateRecomputedNodes(mng, origin_nodes_topological); |
|
|
std::vector<CNodePtr> candidate_recomputed_nodes = FindCandidateRecomputedNodes(mng, origin_nodes_topological); |
|
|
std::unordered_set<CNodePtr> visited_nodes; |
|
|
std::unordered_set<CNodePtr> visited_nodes; |
|
|
|