You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

dfunctor.cc 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "optimizer/ad/dfunctor.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "ir/meta_func_graph.h"
#include "debug/info.h"
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
#include "pipeline/resource.h"
#include "pipeline/parse/parse.h"
#include "optimizer/ad/adjoint.h"
#include "optimizer/opt.h"
#include "operator/ops.h"
#include "operator/composite/composite.h"
#include "utils/symbolic.h"
#include "./common.h"
  33. namespace mindspore {
  34. namespace ad {
  35. std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
  36. std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
  37. DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
  38. : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
  39. TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
  40. k_graph_ = std::make_shared<FuncGraph>();
  41. TraceManager::EndTrace();
  42. TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
  43. tape_ = std::make_shared<FuncGraph>();
  44. TraceManager::EndTrace();
  45. dout_ = tape_->add_parameter();
  46. }
  47. void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
  48. func_graph_to_functor_[primal_graph_] = functor;
  49. is_top_ = is_top;
  50. }
  51. void DFunctor::Clear() {
  52. func_graph_to_functor_.clear();
  53. anfnode_to_adjoin_definition_.clear();
  54. }
  55. void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
  56. auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  57. if (fv_adjoint == anfnode_to_adjoin_.end()) {
  58. MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
  59. << " " << fv->ToString() << ".";
  60. fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
  61. if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
  62. MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
  63. << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  64. auto parent_adjoint = FindAdjoint(fv);
  65. AdjointPtr adjoint = nullptr;
  66. if (parent_adjoint != nullptr) {
  67. adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
  68. } else {
  69. MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
  70. << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  71. adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
  72. }
  73. anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
  74. fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
  75. }
  76. }
  77. auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
  78. fv_adjoint->second->RegisterKUser(key, 1);
  79. auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adjoint->second->k()});
  80. fv_adjoint->second->RegisterKUser(default_val, 1);
  81. auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, key, default_val});
  82. MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
  83. << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  84. MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << key->ToString() << ".";
  85. fv_adjoint->second->AccumulateDout(dfv);
  86. }
  87. void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
  88. auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
  89. // Call with delimited continuation dout.
  90. auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
  91. node_adjoint->RegisterDoutUser(bprop_app, 1);
  92. for (size_t i = 0; i < cnode_morph->size(); i++) {
  93. auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
  94. auto input = cnode_morph->input(i);
  95. // Backprop sens wrt fvs.
  96. if (IsValueNode<FuncGraph>(input)) {
  97. auto func_graph = GetValueNode<FuncGraphPtr>(input);
  98. auto functor = func_graph_to_functor_.find(func_graph);
  99. if (functor == func_graph_to_functor_.end()) {
  100. MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] "
  101. << func_graph->ToString() << ".";
  102. }
  103. // Consider direct and indirect fvs.
  104. for (auto fv : func_graph->free_variables_nodes()) {
  105. BackPropagateFv(fv, din);
  106. }
  107. for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
  108. MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " "
  109. << indirect_fv.first->ToString() << ".";
  110. BackPropagateFv(indirect_fv.first, din);
  111. }
  112. continue;
  113. }
  114. // Backprop sens wrt inputs.
  115. auto input_adjoint = anfnode_to_adjoin_.find(input);
  116. if (input_adjoint == anfnode_to_adjoin_.end()) {
  117. MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << ".";
  118. }
  119. input_adjoint->second->AccumulateDout(din);
  120. }
  121. }
  122. // Map a morphism.
  123. AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
  124. // MapMorphism All type except CNode should already be mapped by MapObject.
  125. if (!morph->isa<CNode>()) {
  126. return nullptr;
  127. }
  128. ScopeGuard scope_guard(morph->scope());
  129. auto cnode_morph = morph->cast<CNodePtr>();
  130. std::vector<AnfNodePtr> inputs;
  131. std::vector<AdjointPtr> param_adjoints;
  132. for (size_t i = 0; i < cnode_morph->size(); i++) {
  133. auto node = cnode_morph->input(i);
  134. auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
  135. AdjointPtr node_adjoint = nullptr;
  136. AnfNodePtr k = nullptr;
  137. if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
  138. node_adjoint = node_adjoint_iter->second;
  139. } else {
  140. // Input might be a CNode that needs to be handled before hand.
  141. node_adjoint = MapMorphism(node);
  142. }
  143. MS_EXCEPTION_IF_NULL(node_adjoint);
  144. k = node_adjoint->k();
  145. if (k == nullptr) {
  146. MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
  147. }
  148. inputs.push_back(k);
  149. param_adjoints.push_back(node_adjoint);
  150. }
  151. TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
  152. auto k_app = k_graph_->NewCNode(inputs);
  153. TraceManager::EndTrace();
  154. for (size_t i = 0; i < param_adjoints.size(); ++i) {
  155. param_adjoints[i]->RegisterKUser(k_app, i);
  156. }
  157. // Do forward computation
  158. auto foward_app = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(0)});
  159. // K:: cnode -> forward_app
  160. auto node_adjoint = std::make_shared<Adjoint>(morph, foward_app, tape_);
  161. UpdateAdjoint(node_adjoint);
  162. anfnode_to_adjoin_[morph] = node_adjoint;
  163. if (cnode_morph->stop_gradient()) {
  164. MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
  165. return node_adjoint;
  166. }
  167. // Do sens backpropagation
  168. BackPropagate(cnode_morph, k_app, node_adjoint);
  169. MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << ".";
  170. return node_adjoint;
  171. }
  172. bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
  173. // Do not care about non-CNode
  174. if (!node->isa<CNode>()) {
  175. return false;
  176. }
  177. // Do not care about kPrimReturn
  178. if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
  179. return false;
  180. }
  181. auto &users = primal_graph_->manager()->node_users()[node];
  182. // Do not care about isolated morphisms
  183. if (users.empty()) {
  184. return false;
  185. }
  186. // Not free if it's used by some node in primal_graph
  187. bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
  188. auto &user = kv.first;
  189. return user->func_graph() == primal_graph_;
  190. });
  191. return !nonfree;
  192. }
  193. void DFunctor::MapFreeMorphism() {
  194. // Handle cnode not attached to output, that might be refered in other functions.
  195. for (auto &node : primal_graph_->nodes()) {
  196. if (!IsFreeMorphism(node)) {
  197. continue;
  198. }
  199. MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
  200. (void)MapMorphism(node);
  201. }
  202. }
  203. AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
  204. AnfNodePtr new_grad_fv = grad_fv;
  205. // Add grads wrt fv.
  206. const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  207. for (auto &fv : free_variables_nodes) {
  208. auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  209. if (fv_adjoint == anfnode_to_adjoin_.end()) {
  210. MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
  211. }
  212. auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
  213. fv_adjoint->second->RegisterKUser(key, 1);
  214. auto sens = fv_adjoint->second->dout();
  215. new_grad_fv = tape_->NewCNode({
  216. NewValueNode(prim::kPrimEnvSetItem),
  217. new_grad_fv,
  218. key,
  219. sens,
  220. });
  221. fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), 3);
  222. MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " "
  223. << fv->ToString() << " " << primal_graph_->ToString() << ".";
  224. }
  225. return new_grad_fv;
  226. }
  227. AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
  228. AnfNodePtr new_grad_fv = grad_fv;
  229. // Add indirect fv bprop.
  230. for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {
  231. MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " "
  232. << primal_graph_->ToString() << ".";
  233. auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()});
  234. fv_adjoint.second->RegisterKUser(key, 1);
  235. auto sens = fv_adjoint.second->dout();
  236. new_grad_fv = tape_->NewCNode({
  237. NewValueNode(prim::kPrimEnvSetItem),
  238. new_grad_fv,
  239. key,
  240. sens,
  241. });
  242. fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), 3);
  243. MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to "
  244. << new_grad_fv->ToString() << ".";
  245. }
  246. return new_grad_fv;
  247. }
  248. void DFunctor::MapMorphism() {
  249. // Set stop_gradient before MapMorphism.
  250. BroadCastStopFlag();
  251. // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
  252. MapFreeMorphism();
  253. // Handle morphism from output.
  254. (void)MapMorphism(primal_graph_->output());
  255. // Construct K for primal_graph_
  256. auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
  257. // Attach dout_ parameter to output_adjoint.
  258. output_adjoint->second->AccumulateDout(dout_);
  259. // Set output for tape closure.
  260. auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
  261. std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
  262. // Add grads wrt inputs.
  263. std::vector<AdjointPtr> param_adjoints;
  264. for (auto &param : primal_graph_->parameters()) {
  265. auto param_adjoint = anfnode_to_adjoin_.find(param);
  266. inputs.push_back(param_adjoint->second->dout());
  267. param_adjoints.push_back(param_adjoint->second);
  268. }
  269. auto tape_output = tape_->NewCNode(inputs);
  270. for (size_t i = 0; i < param_adjoints.size(); ++i) {
  271. param_adjoints[i]->RegisterDoutUser(tape_output, i + 2);
  272. }
  273. tape_->set_output(tape_output);
  274. // Set output for k_graph_, K:: cnode->forward_app.
  275. auto forward_app = output_adjoint->second->k();
  276. auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)});
  277. output_adjoint->second->RegisterKUser(output, 1);
  278. k_graph_->set_output(output);
  279. (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_)));
  280. (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_)));
  281. }
  282. FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
  283. // K user defined cell bprop.
  284. auto bprop = primal->transforms().find("bprop");
  285. if (bprop != primal->transforms().end()) {
  286. FuncGraphPtr bprop_graph = bprop->second.func_graph();
  287. resources_->manager()->AddFuncGraph(bprop_graph);
  288. if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) {
  289. MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
  290. << primal->output()->scope()->name() << " does not support Parameter data type.";
  291. }
  292. auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph);
  293. if (fg == nullptr) {
  294. MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope "
  295. << primal->output()->scope()->name() << ".";
  296. }
  297. // Cache the grad func
  298. (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg)));
  299. (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
  300. // Reset defer_inline to enable successive inlining
  301. primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
  302. auto functor = std::make_shared<DFunctor>(primal, resources_);
  303. functor->Init(functor);
  304. functor->k_graph_ = fg;
  305. return fg;
  306. }
  307. return nullptr;
  308. }
  309. // MapToK(func)
  310. AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
  311. auto f = func_graph_to_functor_.find(primal);
  312. if (f != func_graph_to_functor_.end()) {
  313. MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << ".";
  314. return NewValueNode(f->second->k_graph_);
  315. }
  316. auto k_user_defined = KUserDefined(primal);
  317. if (k_user_defined != nullptr) {
  318. MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << ".";
  319. return NewValueNode(k_user_defined);
  320. }
  321. auto functor = std::make_shared<DFunctor>(primal, resources_);
  322. functor->Init(functor);
  323. functor->MapObject();
  324. functor->MapMorphism();
  325. MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << ".";
  326. return NewValueNode(functor->k_graph_);
  327. }
  328. // Construct representation graph for given node.
  329. AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
  330. ScopeGuard scope_guard(primal->scope());
  331. // MapToK(prim)
  332. if (IsValueNode<Primitive>(primal)) {
  333. auto value_node = primal->cast<ValueNodePtr>();
  334. auto prim = GetValueNode<PrimitivePtr>(value_node);
  335. if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
  336. MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
  337. need_cut_ = true;
  338. }
  339. auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
  340. if (k_prim != nullptr) {
  341. k_prim = BasicClone(k_prim);
  342. return NewValueNode(k_prim);
  343. }
  344. // When failed to find k_prim, try k_meta.
  345. auto k_meta = g_k_prims.KMetaFuncGraph(prim);
  346. if (k_meta != nullptr) {
  347. return NewValueNode(k_meta);
  348. }
  349. }
  350. // MapToK(func)
  351. if (IsValueNode<FuncGraph>(primal)) {
  352. auto func_graph = GetValueNode<FuncGraphPtr>(primal);
  353. auto k_func = MapToK(func_graph);
  354. return k_func;
  355. }
  356. if (primal->isa<Parameter>()) {
  357. TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal->debug_info()));
  358. auto ret = k_graph_->add_parameter();
  359. TraceManager::EndTrace();
  360. return ret;
  361. }
  362. if (!primal->isa<ValueNode>()) {
  363. MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode.";
  364. }
  365. return primal;
  366. }
  367. void DFunctor::MapFvObject() {
  368. // Map free variable.
  369. const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  370. for (auto &node : free_variables_nodes) {
  371. ScopeGuard scope_guard(node->scope());
  372. MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << ".";
  373. // Find fv's K from parent.
  374. AdjointPtr adjoint = nullptr;
  375. auto parent_adjoint = FindAdjoint(node);
  376. if (parent_adjoint != nullptr) {
  377. adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
  378. } else {
  379. if (is_top_) {
  380. // Top graph for ad, add adjoint for free variables.
  381. adjoint = std::make_shared<Adjoint>(node, node, tape_);
  382. UpdateAdjoint(adjoint);
  383. } else {
  384. MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << ".";
  385. adjoint = std::make_shared<Adjoint>(node, nullptr, tape_);
  386. }
  387. }
  388. if (adjoint == nullptr) {
  389. MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << ".";
  390. }
  391. anfnode_to_adjoin_[node] = adjoint;
  392. }
  393. }
  394. void DFunctor::MapParamObject() {
  395. // Map parameter.
  396. for (auto &p : primal_graph_->parameters()) {
  397. ScopeGuard scope_guard(p->scope());
  398. MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
  399. auto adjoint = std::make_shared<Adjoint>(p, MapToK(p), tape_);
  400. UpdateAdjoint(adjoint);
  401. anfnode_to_adjoin_[p] = adjoint;
  402. }
  403. }
  404. void DFunctor::MapValueObject() {
  405. // Map ValueNode.
  406. auto manager = resources_->manager();
  407. auto &value_nodes = manager->valuenodes()[primal_graph_];
  408. for (const auto &value_pair : value_nodes) {
  409. auto node = value_pair.first;
  410. auto parent_adjoint = FindAdjoint(node);
  411. if (parent_adjoint != nullptr) {
  412. auto adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
  413. anfnode_to_adjoin_[node] = adjoint;
  414. continue;
  415. }
  416. // Skip Return.
  417. if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
  418. continue;
  419. }
  420. MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << ".";
  421. auto adjoint = std::make_shared<Adjoint>(node, MapToK(node), tape_);
  422. UpdateAdjoint(adjoint);
  423. anfnode_to_adjoin_[node] = adjoint;
  424. }
  425. }
  426. // Skip morphism.
  427. void DFunctor::MapObject() {
  428. // The order does not matter
  429. MapFvObject();
  430. MapParamObject();
  431. MapValueObject();
  432. }
  433. void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) {
  434. auto primal = adjoint_definition->primal();
  435. if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) {
  436. MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " "
  437. << primal->ToString() << ".";
  438. }
  439. anfnode_to_adjoin_definition_[primal] = adjoint_definition;
  440. // Update k hole for primal.
  441. for (auto &f : func_graph_to_functor_) {
  442. auto adjoint = f.second->anfnode_to_adjoin_.find(primal);
  443. if (adjoint != f.second->anfnode_to_adjoin_.end()) {
  444. adjoint->second->UpdateK(adjoint_definition->k());
  445. }
  446. adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal);
  447. if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) {
  448. adjoint->second->UpdateK(adjoint_definition->k());
  449. }
  450. }
  451. }
  452. AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
  453. auto adjoint = anfnode_to_adjoin_definition_.find(primal);
  454. if (adjoint != anfnode_to_adjoin_definition_.end()) {
  455. MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << ".";
  456. return adjoint->second;
  457. }
  458. MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << ".";
  459. return nullptr;
  460. }
  461. void DFunctor::CallDoutHoleOnTape() {
  462. // Call dout hole of all adjoint.
  463. for (auto &f : func_graph_to_functor_) {
  464. for (auto &adjoint : f.second->anfnode_to_adjoin_) {
  465. adjoint.second->CallDoutHole();
  466. }
  467. for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) {
  468. adjoint.second->CallDoutHole();
  469. }
  470. }
  471. }
  472. FuncGraphPtr DFunctor::k_graph() {
  473. CallDoutHoleOnTape();
  474. return k_graph_;
  475. }
  476. void DFunctor::BroadCastStopFlag() {
  477. // As stop set expanding, all directly or indirectly stopped CNode will be cut off
  478. while (need_cut_) {
  479. need_cut_ = false;
  480. for (auto &node : primal_graph_->nodes()) {
  481. if (node->isa<CNode>()) {
  482. auto cnode = node->cast<CNodePtr>();
  483. if (!cnode->stop_gradient()) {
  484. // Cut off the cnode only when it's not referred any more
  485. if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || AllReferencesStopped(cnode)) {
  486. MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
  487. cnode->set_stop_gradient(true);
  488. // The stop set changed, more cut required
  489. need_cut_ = true;
  490. }
  491. }
  492. }
  493. }
  494. }
  495. }
  496. bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
  497. auto &users = primal_graph_->manager()->node_users()[node];
  498. // Only care about stop_gradient caused cutting
  499. if (users.empty()) {
  500. return false;
  501. }
  502. for (auto &kv : users) {
  503. auto &user = kv.first;
  504. if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
  505. return false;
  506. }
  507. }
  508. return true;
  509. }
  510. } // namespace ad
  511. } // namespace mindspore