|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "optimizer/ad/dfunctor.h"
-
- #include <memory>
- #include <string>
- #include <utility>
-
- #include "ir/anf.h"
- #include "ir/meta_func_graph.h"
- #include "debug/info.h"
- #include "ir/func_graph_cloner.h"
- #include "ir/manager.h"
- #include "pipeline/resource.h"
- #include "pipeline/parse/parse.h"
- #include "optimizer/ad/adjoint.h"
- #include "optimizer/opt.h"
- #include "operator/ops.h"
- #include "operator/composite/composite.h"
- #include "utils/symbolic.h"
- #include "./common.h"
-
- namespace mindspore {
- namespace ad {
- // Registry of functors registered through DFunctor::Init, keyed by the primal graph each one differentiates.
- // Shared by all functors and emptied by DFunctor::Clear.
- std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_{};
- // Adjoint definitions recorded by DFunctor::UpdateAdjoint (at most one per primal node) and looked up by
- // DFunctor::FindAdjoint. Shared by all functors and emptied by DFunctor::Clear.
- std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_{};
-
- DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
- : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
- TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
- k_graph_ = std::make_shared<FuncGraph>();
- TraceManager::EndTrace();
-
- TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
- tape_ = std::make_shared<FuncGraph>();
- TraceManager::EndTrace();
-
- dout_ = tape_->add_parameter();
- }
-
- // Register `functor` (callers pass the shared pointer owning this object) under primal_graph_ and record
- // whether this functor handles the top graph of the AD pass.
- void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
-   auto &registered = func_graph_to_functor_[primal_graph_];
-   registered = functor;
-   is_top_ = is_top;
- }
-
- // Forget every registered functor and every recorded adjoint definition (both tables are static and
- // shared by all functors).
- void DFunctor::Clear() {
-   func_graph_to_functor_.clear();          // functors registered by Init
-   anfnode_to_adjoin_definition_.clear();   // definitions recorded by UpdateAdjoint
- }
-
- // Backpropagate the environment `din` into free variable `fv`.
- // The adjoint of fv is searched in anfnode_to_adjoin_, then in anfnode_to_adjoin_indirect_fv_. If neither has
- // it, a new indirect-fv adjoint is registered: its K comes from the recorded adjoint definition, or it is a
- // K hole (nullptr) when no definition exists yet. fv's sensitivity is then read from `din` with EnvGetItem
- // (key = Embed(K), default = zeros_like(K)) and accumulated into the adjoint's dout.
- void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
-   AdjointPtr fv_adj = nullptr;
-   auto direct_hit = anfnode_to_adjoin_.find(fv);
-   if (direct_hit != anfnode_to_adjoin_.end()) {
-     fv_adj = direct_hit->second;
-   } else {
-     MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
-                   << " " << fv->ToString() << ".";
-     auto indirect_hit = anfnode_to_adjoin_indirect_fv_.find(fv);
-     if (indirect_hit != anfnode_to_adjoin_indirect_fv_.end()) {
-       fv_adj = indirect_hit->second;
-     } else {
-       MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
-                     << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
-       auto definition = FindAdjoint(fv);
-       if (definition == nullptr) {
-         MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
-                       << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
-         fv_adj = std::make_shared<Adjoint>(fv, nullptr, tape_);
-       } else {
-         fv_adj = std::make_shared<Adjoint>(fv, definition->k(), tape_);
-       }
-       anfnode_to_adjoin_indirect_fv_[fv] = fv_adj;
-     }
-   }
-   // Environment key of fv and the default used by EnvGetItem (presumably when din has no entry for the key).
-   auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adj->k()});
-   fv_adj->RegisterKUser(key, 1);
-   auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adj->k()});
-   fv_adj->RegisterKUser(default_val, 1);
-   auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, key, default_val});
-   MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
-                 << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
-   MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << key->ToString() << ".";
-   fv_adj->AccumulateDout(dfv);
- }
-
- // Emit the backward computation for one primal CNode.
- // bprop = k_app[1] is called on the tape with the node's dout. Element i of that call (din) is the sensitivity
- // for input i. For a FuncGraph input, din is an environment and is propagated into that graph's direct and
- // indirect free variables. For any other input, din is accumulated into the input's adjoint dout.
- // Throws (MS_LOG(EXCEPTION)) if a sub-graph has no registered functor or an input has no adjoint.
- void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
-   auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
-   // Call with delimited continuation dout.
-   auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
-   node_adjoint->RegisterDoutUser(bprop_app, 1);
-   for (size_t i = 0; i < cnode_morph->size(); i++) {
-     auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
-     const auto &input = cnode_morph->input(i);
-     // Backprop sens wrt fvs.
-     if (IsValueNode<FuncGraph>(input)) {
-       auto func_graph = GetValueNode<FuncGraphPtr>(input);
-       auto functor = func_graph_to_functor_.find(func_graph);
-       if (functor == func_graph_to_functor_.end()) {
-         MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] "
-                           << func_graph->ToString() << ".";
-       }
-       // Consider direct and indirect fvs. Bind by reference: copying would bump shared_ptr refcounts
-       // (atomic inc/dec) on every iteration for no benefit.
-       for (const auto &fv : func_graph->free_variables_nodes()) {
-         BackPropagateFv(fv, din);
-       }
-       for (const auto &indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
-         MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " "
-                       << indirect_fv.first->ToString() << ".";
-         BackPropagateFv(indirect_fv.first, din);
-       }
-       continue;
-     }
-     // Backprop sens wrt inputs.
-     auto input_adjoint = anfnode_to_adjoin_.find(input);
-     if (input_adjoint == anfnode_to_adjoin_.end()) {
-       MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << ".";
-     }
-     input_adjoint->second->AccumulateDout(din);
-   }
- }
-
- // Map a morphism.
- // Map one CNode of the primal graph. The K of every input is gathered (unmapped CNode inputs are mapped
- // recursively first), k_app = K(inputs) is built in k_graph_, and k_app[0] becomes the K of `morph`. Unless
- // the node is flagged stop_gradient, its backward part is then emitted on the tape. Non-CNodes yield nullptr
- // because MapObject is expected to have mapped them already.
- AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
-   if (!morph->isa<CNode>()) {
-     return nullptr;
-   }
-   ScopeGuard scope_guard(morph->scope());
-   auto cnode_morph = morph->cast<CNodePtr>();
-
-   std::vector<AnfNodePtr> k_inputs;
-   std::vector<AdjointPtr> input_adjoints;
-   for (size_t idx = 0; idx < cnode_morph->size(); ++idx) {
-     auto input = cnode_morph->input(idx);
-     auto known = anfnode_to_adjoin_.find(input);
-     // Input might be a CNode that needs to be handled before hand.
-     AdjointPtr node_adjoint = (known != anfnode_to_adjoin_.end()) ? known->second : MapMorphism(input);
-     MS_EXCEPTION_IF_NULL(node_adjoint);
-     AnfNodePtr input_k = node_adjoint->k();
-     if (input_k == nullptr) {
-       MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << idx << "] " << input->ToString() << ".";
-     }
-     k_inputs.push_back(input_k);
-     input_adjoints.push_back(node_adjoint);
-   }
-   TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
-   auto k_app = k_graph_->NewCNode(k_inputs);
-   TraceManager::EndTrace();
-   for (size_t idx = 0; idx < input_adjoints.size(); ++idx) {
-     input_adjoints[idx]->RegisterKUser(k_app, idx);
-   }
-
-   // Forward value: K maps the primal cnode to forward_app = k_app[0].
-   auto forward_app = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(0)});
-   auto node_adjoint = std::make_shared<Adjoint>(morph, forward_app, tape_);
-   UpdateAdjoint(node_adjoint);
-   anfnode_to_adjoin_[morph] = node_adjoint;
-   if (cnode_morph->stop_gradient()) {
-     MS_LOG(WARNING) << "MapMorphism node " << morph->ToString() << " is stopped.";
-     return node_adjoint;
-   }
-   // Sensitivity backpropagation on the tape.
-   BackPropagate(cnode_morph, k_app, node_adjoint);
-   MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << ".";
-   return node_adjoint;
- }
-
- // A node is a "free morphism" when it is a CNode other than a Return node, it has at least one user (per the
- // manager of primal_graph_), and none of its users belongs to primal_graph_.
- bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
-   if (!node->isa<CNode>() || IsPrimitiveCNode(node, prim::kPrimReturn)) {
-     return false;
-   }
-   auto &users = primal_graph_->manager()->node_users()[node];
-   if (users.empty()) {
-     // Isolated morphisms are not considered free.
-     return false;
-   }
-   for (const auto &user_entry : users) {
-     if (user_entry.first->func_graph() == primal_graph_) {
-       // Used by a node of primal_graph_, hence not free.
-       return false;
-     }
-   }
-   return true;
- }
-
- // Map the CNodes of primal_graph_ that are not reached from its output but may be referred to by other
- // functions (see IsFreeMorphism for the exact criterion).
- void DFunctor::MapFreeMorphism() {
-   for (auto &node : primal_graph_->nodes()) {
-     if (IsFreeMorphism(node)) {
-       MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
-       (void)MapMorphism(node);
-     }
-   }
- }
-
- // Starting from the environment `grad_fv`, store the dout of every direct free variable of primal_graph_
- // under its Embed(K) key through a chain of EnvSetItem nodes on the tape. Returns the last environment node,
- // or `grad_fv` itself when there are no free variables. Throws if a free variable has no adjoint.
- AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
-   AnfNodePtr env = grad_fv;
-   const auto &fvs = primal_graph_->free_variables_nodes();
-   for (auto &fv : fvs) {
-     auto found = anfnode_to_adjoin_.find(fv);
-     if (found == anfnode_to_adjoin_.end()) {
-       MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
-     }
-     const auto &adjoint = found->second;
-     auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), adjoint->k()});
-     adjoint->RegisterKUser(key, 1);
-     auto sens = adjoint->dout();
-     auto set_item = tape_->NewCNode({NewValueNode(prim::kPrimEnvSetItem), env, key, sens});
-     adjoint->RegisterDoutUser(set_item, 3);
-     env = set_item;
-     MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << env->ToString() << " "
-                   << fv->ToString() << " " << primal_graph_->ToString() << ".";
-   }
-   return env;
- }
-
- // Starting from the environment `grad_fv`, store the dout of every indirect free variable recorded in
- // anfnode_to_adjoin_indirect_fv_ under its Embed(K) key through a chain of EnvSetItem nodes on the tape.
- // Returns the last environment node, or `grad_fv` itself when there are no indirect free variables.
- AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
-   AnfNodePtr env = grad_fv;
-   for (auto &entry : anfnode_to_adjoin_indirect_fv_) {
-     const auto &indirect_fv = entry.first;
-     const auto &adjoint = entry.second;
-     MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << indirect_fv->ToString() << " "
-                   << primal_graph_->ToString() << ".";
-     auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), adjoint->k()});
-     adjoint->RegisterKUser(key, 1);
-     auto sens = adjoint->dout();
-     auto set_item = tape_->NewCNode({NewValueNode(prim::kPrimEnvSetItem), env, key, sens});
-     adjoint->RegisterDoutUser(set_item, 3);
-     env = set_item;
-     MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to "
-                   << env->ToString() << ".";
-   }
-   return env;
- }
-
- // Build k_graph_ and tape_ for primal_graph_ (MapObject must have run first).
- // k_graph_ returns (K(output), tape_). tape_ returns (env of free-variable douts, dout of each parameter).
- // The two graphs are then cross-linked through the "grad"/"primal" transforms.
- // Throws (MS_LOG(EXCEPTION)) if the output or a parameter has no adjoint.
- void DFunctor::MapMorphism() {
-   // Set stop_gradient before MapMorphism.
-   BroadCastStopFlag();
-
-   // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
-   MapFreeMorphism();
-   // Handle morphism from output.
-   (void)MapMorphism(primal_graph_->output());
-
-   // Construct K for primal_graph_. The output is expected to have been mapped above or by MapObject; fail
-   // loudly instead of dereferencing end() if it was not.
-   auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
-   if (output_adjoint == anfnode_to_adjoin_.end()) {
-     MS_LOG(EXCEPTION) << "MapMorphism adjoint for output does not exist " << primal_graph_->output()->ToString()
-                       << ".";
-   }
-   // Attach dout_ parameter to output_adjoint.
-   output_adjoint->second->AccumulateDout(dout_);
-
-   // Environment holding the douts of the direct and indirect free variables.
-   auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
-
-   // Tape output: (grad_fv, dout of parameter 0, dout of parameter 1, ...).
-   std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
-   std::vector<AdjointPtr> param_adjoints;
-   // NOTE: the loop variable is deliberately not named "param": the original text "&param" had been
-   // mangled into "¶m" by HTML legacy-entity decoding ("&para"), which does not compile.
-   for (const auto &primal_param : primal_graph_->parameters()) {
-     auto param_adjoint = anfnode_to_adjoin_.find(primal_param);
-     if (param_adjoint == anfnode_to_adjoin_.end()) {
-       MS_LOG(EXCEPTION) << "MapMorphism adjoint for parameter does not exist " << primal_param->ToString() << ".";
-     }
-     inputs.push_back(param_adjoint->second->dout());
-     param_adjoints.push_back(param_adjoint->second);
-   }
-   auto tape_output = tape_->NewCNode(inputs);
-   // Parameter i's dout is input i + 2 of tape_output (after MakeTuple and grad_fv).
-   for (size_t i = 0; i < param_adjoints.size(); ++i) {
-     param_adjoints[i]->RegisterDoutUser(tape_output, i + 2);
-   }
-   tape_->set_output(tape_output);
-   // Set output for k_graph_, K:: cnode->forward_app.
-   auto forward_app = output_adjoint->second->k();
-   auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)});
-   output_adjoint->second->RegisterKUser(output, 1);
-   k_graph_->set_output(output);
-   (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_)));
-   (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_)));
- }
-
- // If `primal` carries a user-defined Cell "bprop" transform, expand it into a K graph with
- // g_k_prims.KUserDefinedCellBprop. The result is cached as primal's "grad" transform (and fg's "primal"),
- // and a functor whose k_graph_ is that graph is registered. Returns nullptr when there is no "bprop"
- // transform. Throws if either graph has free variables or the expansion fails.
- FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
-   auto bprop_entry = primal->transforms().find("bprop");
-   if (bprop_entry == primal->transforms().end()) {
-     return nullptr;
-   }
-   FuncGraphPtr bprop_graph = bprop_entry->second.func_graph();
-   const size_t param_diff = 1;
-   bool arity_mismatch = bprop_graph->output()->isa<CNode>() &&
-                         bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size();
-   if (arity_mismatch) {
-     // Debugging tip only; the final tangents do not depend on it.
-     MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope "
-                   << primal->output()->scope()->name()
-                   << " output must be a tuple and output number should be the same with inputs.";
-   }
-   resources_->manager()->AddFuncGraph(bprop_graph);
-
-   bool has_fv = bprop_graph->free_variables_nodes().size() > 0 || primal->free_variables_nodes().size() > 0;
-   if (has_fv) {
-     MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
-                       << primal->output()->scope()->name() << " does not support Parameter data type.";
-   }
-   auto k_user_graph = g_k_prims.KUserDefinedCellBprop(bprop_graph);
-   if (k_user_graph == nullptr) {
-     MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope "
-                       << primal->output()->scope()->name() << ".";
-   }
-
-   // Cache the grad func
-   (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_user_graph)));
-   (void)k_user_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
-   // Reset defer_inline to enable successive inlining
-   primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
-
-   // Register a functor for primal whose K graph is the expanded user bprop (no MapObject/MapMorphism here).
-   auto user_functor = std::make_shared<DFunctor>(primal, resources_);
-   user_functor->Init(user_functor);
-   user_functor->k_graph_ = k_user_graph;
-   return k_user_graph;
- }
-
- // MapToK(func)
- // Return a value node holding the K graph of `primal`. Preference order: the K graph of an already
- // registered functor, then a user-defined Cell bprop (KUserDefined), otherwise a new functor that runs
- // MapObject and MapMorphism on `primal`.
- AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
-   auto registered = func_graph_to_functor_.find(primal);
-   if (registered != func_graph_to_functor_.end()) {
-     MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << ".";
-     return NewValueNode(registered->second->k_graph_);
-   }
-
-   auto k_user_defined = KUserDefined(primal);
-   if (k_user_defined == nullptr) {
-     auto new_functor = std::make_shared<DFunctor>(primal, resources_);
-     new_functor->Init(new_functor);
-     new_functor->MapObject();
-     new_functor->MapMorphism();
-     MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << new_functor->k_graph_->ToString()
-                   << ".";
-     return NewValueNode(new_functor->k_graph_);
-   }
-   MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << ".";
-   return NewValueNode(k_user_defined);
- }
-
- // Construct representation graph for given node.
- // - Primitive: a cloned K primitive if one exists, else a K meta func graph. Meeting StopGradient also sets
- //   need_cut_. A primitive with neither falls through and is returned unchanged as a value node.
- // - FuncGraph value node: its K graph (see MapToK(const FuncGraphPtr &)).
- // - Parameter: a fresh parameter of k_graph_, traced back to the primal.
- // - Any other value node is returned as is; any other node kind throws.
- AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
-   ScopeGuard scope_guard(primal->scope());
-   if (IsValueNode<Primitive>(primal)) {
-     auto prim_node = primal->cast<ValueNodePtr>();
-     auto prim = GetValueNode<PrimitivePtr>(prim_node);
-     bool is_stop_gradient =
-       prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name();
-     if (is_stop_gradient) {
-       MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
-       need_cut_ = true;
-     }
-     auto k_prim = g_k_prims.KPrimitive(prim_node, resources_);
-     if (k_prim != nullptr) {
-       k_prim = BasicClone(k_prim);
-       return NewValueNode(k_prim);
-     }
-     // No K primitive: try a K meta func graph instead.
-     auto k_meta = g_k_prims.KMetaFuncGraph(prim);
-     if (k_meta != nullptr) {
-       return NewValueNode(k_meta);
-     }
-   }
-
-   if (IsValueNode<FuncGraph>(primal)) {
-     return MapToK(GetValueNode<FuncGraphPtr>(primal));
-   }
-
-   if (primal->isa<Parameter>()) {
-     TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal->debug_info()));
-     auto k_param = k_graph_->add_parameter();
-     TraceManager::EndTrace();
-     return k_param;
-   }
-
-   if (!primal->isa<ValueNode>()) {
-     MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode.";
-   }
-   return primal;
- }
-
- // Give every free variable of primal_graph_ an adjoint in anfnode_to_adjoin_:
- //  - if an adjoint definition is recorded for it (parent graph), reuse that definition's K;
- //  - else, for the top graph, the free variable is its own K and becomes a new adjoint definition;
- //  - else leave a K hole (nullptr K), which UpdateAdjoint fills once a definition for the node is recorded.
- void DFunctor::MapFvObject() {
-   const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
-   for (auto &node : free_variables_nodes) {
-     ScopeGuard scope_guard(node->scope());
-     MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << ".";
-     AdjointPtr adjoint = nullptr;
-     auto parent_adjoint = FindAdjoint(node);
-     if (parent_adjoint != nullptr) {
-       adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
-     } else if (is_top_) {
-       // Top graph for ad, add adjoint for free variables.
-       adjoint = std::make_shared<Adjoint>(node, node, tape_);
-       UpdateAdjoint(adjoint);
-     } else {
-       MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << ".";
-       adjoint = std::make_shared<Adjoint>(node, nullptr, tape_);
-     }
-     // Every branch assigns a std::make_shared result, which is never null (it throws on failure), so the
-     // former "adjoint == nullptr" check here was unreachable and has been dropped.
-     anfnode_to_adjoin_[node] = adjoint;
-   }
- }
-
- // Pair every parameter of primal_graph_ with its K (a new k_graph_ parameter created by MapToK) and record
- // the resulting adjoint both as the adjoint definition and in anfnode_to_adjoin_.
- void DFunctor::MapParamObject() {
-   for (auto &primal_param : primal_graph_->parameters()) {
-     ScopeGuard scope_guard(primal_param->scope());
-     MS_LOG(DEBUG) << "MapParamObject parameter " << primal_param->ToString() << ".";
-     auto param_k = MapToK(primal_param);
-     auto param_adjoint = std::make_shared<Adjoint>(primal_param, param_k, tape_);
-     UpdateAdjoint(param_adjoint);
-     anfnode_to_adjoin_[primal_param] = param_adjoint;
-   }
- }
-
- // Give every value node the manager tracks for primal_graph_ an adjoint. If a definition is already recorded,
- // its K is reused. Otherwise, except for the Return primitive (skipped), the node is mapped with MapToK and
- // becomes a new adjoint definition.
- void DFunctor::MapValueObject() {
-   auto manager = resources_->manager();
-   auto &value_nodes = manager->valuenodes()[primal_graph_];
-   for (const auto &value_pair : value_nodes) {
-     auto node = value_pair.first;
-     auto parent_adjoint = FindAdjoint(node);
-     if (parent_adjoint != nullptr) {
-       auto reused = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
-       anfnode_to_adjoin_[node] = reused;
-     } else if (!IsValueNode<Primitive>(node) || GetValueNode<PrimitivePtr>(node) != prim::kPrimReturn) {
-       MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << ".";
-       auto defined = std::make_shared<Adjoint>(node, MapToK(node), tape_);
-       UpdateAdjoint(defined);
-       anfnode_to_adjoin_[node] = defined;
-     }
-   }
- }
-
- // Map every non-morphism (non-CNode) object of primal_graph_; CNodes are handled later by MapMorphism.
- void DFunctor::MapObject() {
-   MapFvObject();     // free variables
-   MapParamObject();  // parameters
-   MapValueObject();  // value nodes
- }
-
- // Record `adjoint_definition` as the single definition for its primal node (throws if one already exists).
- // Then fill the K of any adjoint for that primal found in the direct or indirect-fv table of any registered
- // functor.
- void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) {
-   auto primal = adjoint_definition->primal();
-   if (anfnode_to_adjoin_definition_.count(primal) != 0) {
-     MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " "
-                       << primal->ToString() << ".";
-   }
-   anfnode_to_adjoin_definition_[primal] = adjoint_definition;
-   // Update k hole for primal.
-   auto fill_k_hole = [&primal, &adjoint_definition](auto &table) {
-     auto hit = table.find(primal);
-     if (hit != table.end()) {
-       hit->second->UpdateK(adjoint_definition->k());
-     }
-   };
-   for (auto &functor_entry : func_graph_to_functor_) {
-     fill_k_hole(functor_entry.second->anfnode_to_adjoin_);
-     fill_k_hole(functor_entry.second->anfnode_to_adjoin_indirect_fv_);
-   }
- }
-
- // Return the adjoint definition recorded for `primal`, or nullptr if none has been recorded yet.
- AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
-   auto found = anfnode_to_adjoin_definition_.find(primal);
-   if (found == anfnode_to_adjoin_definition_.end()) {
-     MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString()
-                   << ".";
-     return nullptr;
-   }
-   MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << ".";
-   return found->second;
- }
-
- // Call CallDoutHole on every adjoint of every registered functor, first its direct table and then its
- // indirect-fv table.
- void DFunctor::CallDoutHoleOnTape() {
-   auto call_holes = [](auto &table) {
-     for (auto &entry : table) {
-       entry.second->CallDoutHole();
-     }
-   };
-   for (auto &functor_entry : func_graph_to_functor_) {
-     call_holes(functor_entry.second->anfnode_to_adjoin_);
-     call_holes(functor_entry.second->anfnode_to_adjoin_indirect_fv_);
-   }
- }
- // Return the K graph after calling the dout holes of all adjoints of all registered functors.
- FuncGraphPtr DFunctor::k_graph() {
-   CallDoutHoleOnTape();  // must run before the K graph is handed out
-   return k_graph_;
- }
-
- // While need_cut_ is set, sweep primal_graph_ and flag with stop_gradient every CNode that is a StopGradient
- // call or whose users are all already flagged. Each sweep that flags a node requests another sweep, so the
- // loop stops at a fixed point.
- void DFunctor::BroadCastStopFlag() {
-   while (need_cut_) {
-     need_cut_ = false;
-     for (auto &node : primal_graph_->nodes()) {
-       if (!node->isa<CNode>()) {
-         continue;
-       }
-       auto cnode = node->cast<CNodePtr>();
-       if (cnode->stop_gradient()) {
-         continue;
-       }
-       if (!IsPrimitiveCNode(cnode, prim::kPrimStopGradient) && !AllReferencesStopped(cnode)) {
-         continue;
-       }
-       MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
-       cnode->set_stop_gradient(true);
-       need_cut_ = true;
-     }
-   }
- }
-
- // True only when `node` has at least one user (per primal_graph_'s manager) and every user is a CNode already
- // flagged stop_gradient. A node without users is not treated as cut.
- bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
-   auto &users = primal_graph_->manager()->node_users()[node];
-   if (users.empty()) {
-     return false;
-   }
-   for (auto &user_entry : users) {
-     auto &user = user_entry.first;
-     bool user_stopped = user->isa<CNode>() && user->cast<CNodePtr>()->stop_gradient();
-     if (!user_stopped) {
-       return false;
-     }
-   }
-   return true;
- }
- } // namespace ad
- } // namespace mindspore
|