|
|
|
@@ -33,7 +33,8 @@ namespace { |
|
|
|
constexpr auto kGradientsFlag = "Gradients"; |
|
|
|
|
|
|
|
bool CanNotRecomputed(const CNodePtr &node) { |
|
|
|
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask}; |
|
|
|
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask, |
|
|
|
prim::kPrimLoad, prim::kPrimTupleGetItem}; |
|
|
|
|
|
|
|
return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(), |
|
|
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); |
|
|
|
@@ -56,16 +57,26 @@ bool WithRecomputedScope(const AnfNodePtr &node) { |
|
|
|
return full_name_with_scope.find(kAttrRecompute) == 0; |
|
|
|
} |
|
|
|
|
|
|
|
bool HasRecomputeCNodeAttr(const AnfNodePtr &node) { |
|
|
|
ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr) { |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto cnode_recompute_val = cnode->GetAttr(kAttrRecompute); |
|
|
|
return cnode->GetAttr(kAttrRecompute); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsSetNoRecomputeCNodeAttr(const AnfNodePtr &node) { |
|
|
|
auto cnode_recompute_val = GetRecomputeCNodeAttr(node); |
|
|
|
return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && !GetValue<bool>(cnode_recompute_val); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsSetRecomputeCNodeAttr(const AnfNodePtr &node) { |
|
|
|
auto cnode_recompute_val = GetRecomputeCNodeAttr(node); |
|
|
|
return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && HasRecomputeCNodeAttr(node); } |
|
|
|
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && IsSetRecomputeCNodeAttr(node); } |
|
|
|
|
|
|
|
std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng, |
|
|
|
const std::vector<CNodePtr> &cnodes) { |
|
|
|
@@ -213,11 +224,15 @@ bool HasGradInputs(const AnfNodePtr &node, std::unordered_map<AnfNodePtr, bool> |
|
|
|
return false; |
|
|
|
} |
|
|
|
const auto &inputs = cnode->inputs(); |
|
|
|
if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) { |
|
|
|
return IsBpropNode(input) || HasGradInputs(input, has_grad_inputs_map); |
|
|
|
})) { |
|
|
|
has_grad_inputs_map->insert(std::make_pair(node, true)); |
|
|
|
return true; |
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
// For the pipeline split case, the forward pass may depend on the backward pass. |
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && i == kDependAttachNodeIndex) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsBpropNode(inputs[i]) || HasGradInputs(inputs[i], has_grad_inputs_map)) { |
|
|
|
has_grad_inputs_map->insert(std::make_pair(node, true)); |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
has_grad_inputs_map->insert(std::make_pair(node, false)); |
|
|
|
return false; |
|
|
|
@@ -265,6 +280,10 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o |
|
|
|
std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map; |
|
|
|
for (const auto &node : origin_nodes_topological) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
// The node may be set the non-recomputed before such as the cell outputs. |
|
|
|
if (IsSetNoRecomputeCNodeAttr(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsBpropNode(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -272,9 +291,6 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o |
|
|
|
if (CanNotRecomputed(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -293,7 +309,7 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o |
|
|
|
if ((SetRecomputedScope(cnode) && prim_recompute_val != 0) || prim_recompute_val == 1) { |
|
|
|
cnode->AddAttr(kAttrRecompute, MakeValue(true)); |
|
|
|
} |
|
|
|
if (!HasRecomputeCNodeAttr(node)) { |
|
|
|
if (!IsSetRecomputeCNodeAttr(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Set attr for the tuple_getitem outputs. |
|
|
|
|