You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

dfunctor.cc 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "optimizer/ad/dfunctor.h"
  17. #include <memory>
  18. #include <string>
  19. #include <utility>
  20. #include "ir/anf.h"
  21. #include "ir/meta_func_graph.h"
  22. #include "debug/info.h"
  23. #include "ir/func_graph_cloner.h"
  24. #include "ir/manager.h"
  25. #include "pipeline/resource.h"
  26. #include "pipeline/parse/parse.h"
  27. #include "optimizer/ad/adjoint.h"
  28. #include "optimizer/opt.h"
  29. #include "operator/ops.h"
  30. #include "operator/composite/composite.h"
  31. #include "utils/symbolic.h"
  32. #include "utils/context/ms_context.h"
  33. #include "./common.h"
  34. namespace mindspore {
  35. namespace ad {
  36. std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
  37. std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
  38. FuncGraphSet DFunctor::scope_;
// Construct the K-transform pair for primal_graph: k_graph_ will hold the
// forward (fprop) computation and tape_ the backward (bprop) closure.
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
    : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
  // Scope the fprop graph's debug info to the primal graph (DebugTrace/EndTrace
  // pairs must bracket the corresponding FuncGraph construction).
  TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
  k_graph_ = std::make_shared<FuncGraph>();
  TraceManager::EndTrace();
  // Scope the bprop tape's debug info separately.
  TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
  tape_ = std::make_shared<FuncGraph>();
  TraceManager::EndTrace();
  // dout_ is the sens parameter of the tape closure.
  dout_ = tape_->add_parameter();
}
  49. void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
  50. func_graph_to_functor_[primal_graph_] = functor;
  51. is_top_ = is_top;
  52. if (is_top) {
  53. scope_ = primal_graph_->scope();
  54. }
  55. }
  56. void DFunctor::Clear() {
  57. func_graph_to_functor_.clear();
  58. anfnode_to_adjoin_definition_.clear();
  59. scope_.clear();
  60. }
// Accumulate the sens of free variable fv out of the env din.
// Looks up fv's adjoint in anfnode_to_adjoin_, falling back to the indirect-fv
// table (creating an entry there when missing), then emits
// env_getitem(din, embed(k(fv)), zeros_like(k(fv))) on the tape and
// accumulates the result into the adjoint's dout.
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
  auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  if (fv_adjoint == anfnode_to_adjoin_.end()) {
    MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
                  << " " << fv->ToString() << ".";
    fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
    if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
      MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
                    << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
      auto parent_adjoint = FindAdjoint(fv);
      AdjointPtr adjoint = nullptr;
      if (parent_adjoint != nullptr) {
        adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
      } else {
        // No K definition known yet; leave a k hole to be patched later by
        // UpdateAdjoint when the definition appears.
        MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
        adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
      }
      anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
      fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
    }
  }
  // key = embed(k(fv)) identifies fv's slot inside the env.
  auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
  fv_adjoint->second->RegisterKUser(key, 1);
  // Default to zeros_like(k(fv)) when din carries no entry for the key.
  auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adjoint->second->k()});
  fv_adjoint->second->RegisterKUser(default_val, 1);
  auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, key, default_val});
  MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
                << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << key->ToString() << ".";
  fv_adjoint->second->AccumulateDout(dfv);
}
  93. void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
  94. // Take switch_layer as a set of candidate functions.
  95. auto input = cnode_morph->input(2);
  96. if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
  97. MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
  98. }
  99. auto tuple_graphs = input->cast<CNodePtr>();
  100. for (size_t i = 1; i < tuple_graphs->size(); ++i) {
  101. auto graph = tuple_graphs->input(i);
  102. if (!IsValueNode<FuncGraph>(graph)) {
  103. MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString()
  104. << " as the " << i << "th element.";
  105. }
  106. auto func_graph = GetValueNode<FuncGraphPtr>(graph);
  107. auto functor = func_graph_to_functor_.find(func_graph);
  108. if (functor == func_graph_to_functor_.end()) {
  109. MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] "
  110. << func_graph->ToString() << ".";
  111. }
  112. // Consider direct and indirect fvs.
  113. for (auto fv : func_graph->free_variables_nodes()) {
  114. BackPropagateFv(fv, env);
  115. }
  116. for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
  117. MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
  118. << indirect_fv.first->ToString() << ".";
  119. BackPropagateFv(indirect_fv.first, env);
  120. }
  121. }
  122. }
  123. void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
  124. auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
  125. // Call with delimited continuation dout.
  126. auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
  127. node_adjoint->RegisterDoutUser(bprop_app, 1);
  128. // Special case for switch_layer
  129. if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
  130. auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)});
  131. BackPropagateSwitchLayer(cnode_morph, din);
  132. return;
  133. }
  134. for (size_t i = 0; i < cnode_morph->size(); i++) {
  135. auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
  136. auto input = cnode_morph->input(i);
  137. // Backprop sens wrt fvs.
  138. if (IsValueNode<FuncGraph>(input)) {
  139. auto func_graph = GetValueNode<FuncGraphPtr>(input);
  140. auto functor = func_graph_to_functor_.find(func_graph);
  141. if (functor == func_graph_to_functor_.end()) {
  142. MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] "
  143. << func_graph->ToString() << ".";
  144. }
  145. // Consider direct and indirect fvs.
  146. for (auto fv : func_graph->free_variables_nodes()) {
  147. BackPropagateFv(fv, din);
  148. }
  149. for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
  150. MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " "
  151. << indirect_fv.first->ToString() << ".";
  152. BackPropagateFv(indirect_fv.first, din);
  153. }
  154. continue;
  155. }
  156. // Backprop sens wrt inputs.
  157. auto input_adjoint = anfnode_to_adjoin_.find(input);
  158. if (input_adjoint == anfnode_to_adjoin_.end()) {
  159. MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << ".";
  160. }
  161. input_adjoint->second->AccumulateDout(din);
  162. }
  163. }
// Map a morphism.
// Recursively K-transforms the cnode morph: maps each input to its K value
// (recursing into unmapped CNode inputs), applies them in k_graph_ to get
// (forward_result, bprop), records the adjoint, and back-propagates unless
// gradients are stopped. Returns nullptr for non-CNode inputs, which
// MapObject must already have mapped.
AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
  // MapMorphism All type except CNode should already be mapped by MapObject.
  if (!morph->isa<CNode>()) {
    return nullptr;
  }
  ScopeGuard scope_guard(morph->scope());
  auto cnode_morph = morph->cast<CNodePtr>();
  std::vector<AnfNodePtr> inputs;
  std::vector<AdjointPtr> param_adjoints;
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto node = cnode_morph->input(i);
    auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
    AdjointPtr node_adjoint = nullptr;
    AnfNodePtr k = nullptr;
    if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
      node_adjoint = node_adjoint_iter->second;
    } else {
      // Input might be a CNode that needs to be handled before hand.
      node_adjoint = MapMorphism(node);
    }
    MS_EXCEPTION_IF_NULL(node_adjoint);
    k = node_adjoint->k();
    if (k == nullptr) {
      MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
    }
    inputs.push_back(k);
    param_adjoints.push_back(node_adjoint);
  }
  TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
  auto k_app = k_graph_->NewCNode(inputs);
  TraceManager::EndTrace();
  // Record k_app as a user of every input's K at its input index.
  for (size_t i = 0; i < param_adjoints.size(); ++i) {
    param_adjoints[i]->RegisterKUser(k_app, i);
  }
  // Do forward computation
  auto foward_app = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(0)});
  // K:: cnode -> forward_app
  auto node_adjoint = std::make_shared<Adjoint>(morph, foward_app, tape_);
  UpdateAdjoint(node_adjoint);
  anfnode_to_adjoin_[morph] = node_adjoint;
  // stop_gradient cuts the backward pass but keeps the forward mapping.
  if (cnode_morph->stop_gradient()) {
    MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
    return node_adjoint;
  }
  // Do sens backpropagation
  BackPropagate(cnode_morph, k_app, node_adjoint);
  MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << ".";
  return node_adjoint;
}
  214. bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
  215. // Do not care about non-CNode
  216. if (!node->isa<CNode>()) {
  217. return false;
  218. }
  219. // Do not care about kPrimReturn
  220. if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
  221. return false;
  222. }
  223. auto &users = primal_graph_->manager()->node_users()[node];
  224. // Do not care about isolated morphisms
  225. if (users.empty()) {
  226. return false;
  227. }
  228. // Not free if it's used by some node in primal_graph
  229. bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
  230. auto &user = kv.first;
  231. return user->func_graph() == primal_graph_;
  232. });
  233. return !nonfree;
  234. }
  235. void DFunctor::MapFreeMorphism() {
  236. // Handle cnode not attached to output, that might be refered in other functions.
  237. for (auto &node : primal_graph_->nodes()) {
  238. if (!IsFreeMorphism(node)) {
  239. continue;
  240. }
  241. MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
  242. (void)MapMorphism(node);
  243. }
  244. }
// Chain env_setitem nodes onto grad_fv so each direct free variable's
// accumulated dout is stored in the env under embed(k(fv)).
// Raises when a direct fv has no adjoint (MapFvObject should have created one).
// Returns the head of the extended env chain.
AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
  AnfNodePtr new_grad_fv = grad_fv;
  // Add grads wrt fv.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  for (auto &fv : free_variables_nodes) {
    auto fv_adjoint = anfnode_to_adjoin_.find(fv);
    if (fv_adjoint == anfnode_to_adjoin_.end()) {
      MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
    }
    // The embed key identifies fv's slot inside the env.
    auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
    fv_adjoint->second->RegisterKUser(key, 1);
    auto sens = fv_adjoint->second->dout();
    new_grad_fv = tape_->NewCNode({
      NewValueNode(prim::kPrimEnvSetItem),
      new_grad_fv,
      key,
      sens,
    });
    // sens feeds input index 3 of the env_setitem node.
    fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), 3);
    MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " "
                  << fv->ToString() << " " << primal_graph_->ToString() << ".";
  }
  return new_grad_fv;
}
  269. AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
  270. AnfNodePtr new_grad_fv = grad_fv;
  271. // Add indirect fv bprop.
  272. for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {
  273. MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " "
  274. << primal_graph_->ToString() << ".";
  275. auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()});
  276. fv_adjoint.second->RegisterKUser(key, 1);
  277. auto sens = fv_adjoint.second->dout();
  278. new_grad_fv = tape_->NewCNode({
  279. NewValueNode(prim::kPrimEnvSetItem),
  280. new_grad_fv,
  281. key,
  282. sens,
  283. });
  284. fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), 3);
  285. MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to "
  286. << new_grad_fv->ToString() << ".";
  287. }
  288. return new_grad_fv;
  289. }
  290. void DFunctor::MapMorphism() {
  291. // Set stop_gradient before MapMorphism.
  292. BroadCastStopFlag();
  293. // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
  294. MapFreeMorphism();
  295. // Handle morphism from output.
  296. (void)MapMorphism(primal_graph_->output());
  297. // Construct K for primal_graph_
  298. auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
  299. // Attach dout_ parameter to output_adjoint.
  300. output_adjoint->second->AccumulateDout(dout_);
  301. // Set output for tape closure.
  302. auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
  303. std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
  304. // Add grads wrt inputs.
  305. std::vector<AdjointPtr> param_adjoints;
  306. for (auto &param : primal_graph_->parameters()) {
  307. auto param_adjoint = anfnode_to_adjoin_.find(param);
  308. inputs.push_back(param_adjoint->second->dout());
  309. param_adjoints.push_back(param_adjoint->second);
  310. }
  311. auto tape_output = tape_->NewCNode(inputs);
  312. for (size_t i = 0; i < param_adjoints.size(); ++i) {
  313. param_adjoints[i]->RegisterDoutUser(tape_output, i + 2);
  314. }
  315. tape_->set_output(tape_output);
  316. // Set output for k_graph_, K:: cnode->forward_app.
  317. auto forward_app = output_adjoint->second->k();
  318. auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)});
  319. output_adjoint->second->RegisterKUser(output, 1);
  320. k_graph_->set_output(output);
  321. (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_)));
  322. (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_)));
  323. }
// Expand a user-defined Cell bprop for primal, when one is registered as its
// "bprop" transform. The expanded fprop graph is cached as primal's "grad"
// transform and a functor is registered for it. Returns nullptr when primal
// carries no user-defined bprop.
FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
  // K user defined cell bprop.
  auto bprop = primal->transforms().find("bprop");
  if (bprop != primal->transforms().end()) {
    FuncGraphPtr bprop_graph = bprop->second.func_graph();
    resources_->manager()->AddFuncGraph(bprop_graph);
    // Free variables would need Parameter sens handling, which the
    // user-defined bprop path does not support.
    if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) {
      MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
                        << primal->output()->scope()->name() << " does not support Parameter data type.";
    }
    auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph);
    if (fg == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope "
                        << primal->output()->scope()->name() << ".";
    }
    // Cache the grad func
    (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg)));
    (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
    // Reset defer_inline to enable successive inlining
    primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
    // Register a functor whose K graph is the expanded user bprop.
    auto functor = std::make_shared<DFunctor>(primal, resources_);
    functor->Init(functor);
    functor->k_graph_ = fg;
    return fg;
  }
  return nullptr;
}
  351. // MapToK(func)
  352. AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
  353. auto f = func_graph_to_functor_.find(primal);
  354. if (f != func_graph_to_functor_.end()) {
  355. MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << ".";
  356. return NewValueNode(f->second->k_graph_);
  357. }
  358. auto k_user_defined = KUserDefined(primal);
  359. if (k_user_defined != nullptr) {
  360. MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << ".";
  361. return NewValueNode(k_user_defined);
  362. }
  363. auto functor = std::make_shared<DFunctor>(primal, resources_);
  364. functor->Init(functor);
  365. functor->MapObject();
  366. functor->MapMorphism();
  367. MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << ".";
  368. return NewValueNode(functor->k_graph_);
  369. }
// Construct representation graph for given node.
// Dispatch by node kind: primitives map through g_k_prims (cloned k_prim or
// k_meta), func-graph values recurse into MapToK(FuncGraphPtr), parameters
// become fresh k_graph_ parameters, and remaining value nodes map to
// themselves. A non-ValueNode that reaches the end raises.
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
  ScopeGuard scope_guard(primal->scope());
  // MapToK(prim)
  if (IsValueNode<Primitive>(primal)) {
    auto value_node = primal->cast<ValueNodePtr>();
    auto prim = GetValueNode<PrimitivePtr>(value_node);
    // Hash compared first as a cheap filter before the name comparison.
    if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
      MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
      // Remember to run the stop-gradient cut pass (BroadCastStopFlag) later.
      need_cut_ = true;
    }
    auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
    if (k_prim != nullptr) {
      // Clone so later passes do not mutate the cached graph.
      k_prim = BasicClone(k_prim);
      return NewValueNode(k_prim);
    }
    // When failed to find k_prim, try k_meta.
    auto k_meta = g_k_prims.KMetaFuncGraph(prim);
    if (k_meta != nullptr) {
      return NewValueNode(k_meta);
    }
  }
  // MapToK(func)
  if (IsValueNode<FuncGraph>(primal)) {
    auto func_graph = GetValueNode<FuncGraphPtr>(primal);
    auto k_func = MapToK(func_graph);
    return k_func;
  }
  // Parameters get a fresh k_graph_ parameter with primal's debug info.
  if (primal->isa<Parameter>()) {
    TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal->debug_info()));
    auto ret = k_graph_->add_parameter();
    TraceManager::EndTrace();
    return ret;
  }
  if (!primal->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode.";
  }
  return primal;
}
  409. bool DFunctor::IsInScope(const AnfNodePtr &node) {
  410. return std::any_of(scope_.begin(), scope_.end(),
  411. [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; });
  412. }
// Map free variable.
// Build an adjoint for every free variable of primal_graph_. Prefer the K
// from an already-registered adjoint definition; otherwise either map the fv
// to itself (top level, parameters, or out of ad scope) or leave a k hole to
// be patched later by UpdateAdjoint.
void DFunctor::MapFvObject() {
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  for (auto &node : free_variables_nodes) {
    ScopeGuard scope_guard(node->scope());
    MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << ".";
    // Find fv's K from parent.
    AdjointPtr adjoint = nullptr;
    auto parent_adjoint = FindAdjoint(node);
    if (parent_adjoint != nullptr) {
      adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
    } else {
      if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
        // Out of ad scope, add adjoint for free variables.
        adjoint = std::make_shared<Adjoint>(node, node, tape_);
        UpdateAdjoint(adjoint);
      } else {
        // In ad scope but not yet defined: leave a k hole.
        MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << ".";
        adjoint = std::make_shared<Adjoint>(node, nullptr, tape_);
      }
    }
    if (adjoint == nullptr) {
      MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << ".";
    }
    anfnode_to_adjoin_[node] = adjoint;
  }
}
  440. void DFunctor::MapParamObject() {
  441. // Map parameter.
  442. for (auto &p : primal_graph_->parameters()) {
  443. ScopeGuard scope_guard(p->scope());
  444. MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
  445. auto adjoint = std::make_shared<Adjoint>(p, MapToK(p), tape_);
  446. UpdateAdjoint(adjoint);
  447. anfnode_to_adjoin_[p] = adjoint;
  448. }
  449. }
  450. void DFunctor::MapValueObject() {
  451. // Map ValueNode.
  452. auto manager = resources_->manager();
  453. auto &value_nodes = primal_graph_->value_nodes();
  454. for (const auto &value_pair : value_nodes) {
  455. auto node = value_pair.first;
  456. auto parent_adjoint = FindAdjoint(node);
  457. if (parent_adjoint != nullptr) {
  458. auto adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
  459. anfnode_to_adjoin_[node] = adjoint;
  460. continue;
  461. }
  462. // Skip Return.
  463. if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
  464. continue;
  465. }
  466. MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << ".";
  467. auto adjoint = std::make_shared<Adjoint>(node, MapToK(node), tape_);
  468. UpdateAdjoint(adjoint);
  469. anfnode_to_adjoin_[node] = adjoint;
  470. }
  471. }
  472. // Skip morphism.
  473. void DFunctor::MapObject() {
  474. // The order does not matter
  475. MapFvObject();
  476. MapParamObject();
  477. MapValueObject();
  478. }
// Register adjoint_definition as the canonical adjoint for its primal node
// and patch every functor's pending k hole for that primal with the
// definition's K. Raises if a definition already exists — each primal node
// may be defined exactly once.
void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) {
  auto primal = adjoint_definition->primal();
  if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) {
    MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " "
                      << primal->ToString() << ".";
  }
  anfnode_to_adjoin_definition_[primal] = adjoint_definition;
  // Update k hole for primal.
  for (auto &f : func_graph_to_functor_) {
    // Patch both the direct and the indirect-fv adjoint tables.
    auto adjoint = f.second->anfnode_to_adjoin_.find(primal);
    if (adjoint != f.second->anfnode_to_adjoin_.end()) {
      adjoint->second->UpdateK(adjoint_definition->k());
    }
    adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal);
    if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) {
      adjoint->second->UpdateK(adjoint_definition->k());
    }
  }
}
  498. AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
  499. auto adjoint = anfnode_to_adjoin_definition_.find(primal);
  500. if (adjoint != anfnode_to_adjoin_definition_.end()) {
  501. MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << ".";
  502. return adjoint->second;
  503. }
  504. MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << ".";
  505. return nullptr;
  506. }
  507. void DFunctor::CallDoutHoleOnTape() {
  508. // Call dout hole of all adjoint.
  509. for (auto &f : func_graph_to_functor_) {
  510. for (auto &adjoint : f.second->anfnode_to_adjoin_) {
  511. adjoint.second->CallDoutHole();
  512. }
  513. for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) {
  514. adjoint.second->CallDoutHole();
  515. }
  516. }
  517. }
  518. FuncGraphPtr DFunctor::k_graph() {
  519. CallDoutHoleOnTape();
  520. return k_graph_;
  521. }
  522. void DFunctor::BroadCastStopFlag() {
  523. // As stop set expanding, all directly or indirectly stopped CNode will be cut off
  524. while (need_cut_) {
  525. need_cut_ = false;
  526. for (auto &node : primal_graph_->nodes()) {
  527. if (node->isa<CNode>()) {
  528. auto cnode = node->cast<CNodePtr>();
  529. if (!cnode->stop_gradient()) {
  530. // Cut off the cnode only when it's not referred any more
  531. if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || AllReferencesStopped(cnode)) {
  532. MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
  533. cnode->set_stop_gradient(true);
  534. // The stop set changed, more cut required
  535. need_cut_ = true;
  536. }
  537. }
  538. }
  539. }
  540. }
  541. }
  542. bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
  543. auto &users = primal_graph_->manager()->node_users()[node];
  544. // Only care about stop_gradient caused cutting
  545. if (users.empty()) {
  546. return false;
  547. }
  548. for (auto &kv : users) {
  549. auto &user = kv.first;
  550. if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
  551. return false;
  552. }
  553. }
  554. return true;
  555. }
  556. } // namespace ad
  557. } // namespace mindspore