You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

dfunctor.cc 40 kB

5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/ad/dfunctor.h"
  17. #include <map>
  18. #include <memory>
  19. #include <string>
  20. #include "ir/anf.h"
  21. #include "utils/info.h"
  22. #include "ir/func_graph_cloner.h"
  23. #include "ir/manager.h"
  24. #include "pipeline/jit/resource.h"
  25. #include "frontend/optimizer/ad/adjoint.h"
  26. #include "frontend/operator/ops.h"
  27. #include "utils/symbolic.h"
  28. #include "utils/ms_context.h"
  29. #include "pipeline/jit/action.h"
  30. #include "pipeline/jit/parse/resolve.h"
  31. #include "pipeline/pynative/pynative_execute.h"
  32. #include "debug/anf_ir_dump.h"
  33. namespace mindspore {
  34. namespace ad {
  35. mindspore::HashMap<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
  36. mindspore::HashMap<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
  37. bool lift_fv_before_grad = true;
  38. DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
  39. : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
  40. {
  41. TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
  42. k_graph_ = std::make_shared<FuncGraph>();
  43. }
  44. // To keep switch or switch_layer's inputs from being inlined
  45. k_graph_->set_switch_input(primal_graph->switch_input());
  46. k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
  47. k_graph_->set_stage(primal_graph->stage());
  48. {
  49. TraceGuard guard(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
  50. tape_ = std::make_shared<FuncGraph>();
  51. }
  52. tape_->set_stage(primal_graph->stage());
  53. dout_ = tape_->add_parameter();
  54. }
  55. void DFunctor::Init(bool is_top) {
  56. func_graph_to_functor_[primal_graph_] = shared_from_this();
  57. is_top_ = is_top;
  58. }
  59. void DFunctor::Finish() {
  60. CallDoutHoleOnTape();
  61. EliminatePrimalGraph();
  62. }
  63. void DFunctor::Clear() {
  64. func_graph_to_functor_.clear();
  65. anfnode_to_adjoin_definition_.clear();
  66. }
  67. void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
  68. MS_EXCEPTION_IF_NULL(fv);
  69. if (lift_fv_before_grad) {
  70. MS_EXCEPTION_IF_NULL(fv->func_graph());
  71. MS_LOG(EXCEPTION) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv:"
  72. << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  73. }
  74. auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  75. if (fv_adjoint == anfnode_to_adjoin_.end()) {
  76. MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
  77. << " " << fv->ToString() << ".";
  78. if (fv->func_graph() == primal_graph_) {
  79. // If this fv is not mapped by MapMorphism because of cnode order, then map it now.
  80. (void)MapMorphism(fv);
  81. fv_adjoint = anfnode_to_adjoin_.find(fv);
  82. if (fv_adjoint == anfnode_to_adjoin_.end()) {
  83. MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
  84. << fv->ToString() << ".";
  85. }
  86. } else {
  87. fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
  88. if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
  89. MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
  90. << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  91. auto parent_adjoint = FindAdjoint(fv);
  92. AdjointPtr adjoint = nullptr;
  93. if (parent_adjoint != nullptr) {
  94. adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
  95. } else {
  96. MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
  97. << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  98. adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
  99. }
  100. anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
  101. fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
  102. }
  103. }
  104. }
  105. auto fv_node = fv_adjoint->second->k();
  106. auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
  107. CNodePtr embed_node, default_val_node;
  108. if (cached_envitem_iter != anfnode_to_envitem_.end()) {
  109. embed_node = cached_envitem_iter->second.first;
  110. default_val_node = cached_envitem_iter->second.second;
  111. } else {
  112. embed_node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_node});
  113. default_val_node = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_node});
  114. fv_adjoint->second->RegisterKUser(embed_node, 1);
  115. fv_adjoint->second->RegisterKUser(default_val_node, 1);
  116. anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node);
  117. }
  118. auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node});
  119. MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
  120. << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  121. MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << ".";
  122. fv_adjoint->second->AccumulateDout(dfv);
  123. }
  124. void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
  125. // Take switch_layer as a set of candidate functions.
  126. constexpr size_t input_tuple_index = 2;
  127. auto input = cnode_morph->input(input_tuple_index);
  128. if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
  129. MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
  130. }
  131. mindspore::HashMap<AnfNodePtr, FuncGraphPtr> node_to_fg;
  132. auto tuple_graphs = input->cast<CNodePtr>();
  133. for (size_t i = 1; i < tuple_graphs->size(); ++i) {
  134. auto graph = tuple_graphs->input(i);
  135. if (!IsValueNode<FuncGraph>(graph)) {
  136. MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString()
  137. << " as the " << i << "th element.";
  138. }
  139. auto func_graph = GetValueNode<FuncGraphPtr>(graph);
  140. auto functor = func_graph_to_functor_.find(func_graph);
  141. if (functor == func_graph_to_functor_.end()) {
  142. MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] "
  143. << func_graph->ToString() << ".";
  144. }
  145. // Consider direct and indirect fvs.
  146. for (auto fv : func_graph->free_variables_nodes()) {
  147. if (node_to_fg.find(fv) != node_to_fg.end()) {
  148. continue;
  149. }
  150. node_to_fg[fv] = func_graph;
  151. BackPropagateFv(fv, env);
  152. }
  153. for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
  154. MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
  155. << indirect_fv.first->ToString() << ".";
  156. if (node_to_fg.find(indirect_fv.first) != node_to_fg.end()) {
  157. continue;
  158. }
  159. node_to_fg[indirect_fv.first] = func_graph;
  160. BackPropagateFv(indirect_fv.first, env);
  161. }
  162. }
  163. }
  164. static bool HasSideEffectBackProp(const CNodePtr &cnode) {
  165. if (IsPrimitiveCNode(cnode)) {
  166. const auto &prim = GetCNodePrimitive(cnode);
  167. MS_EXCEPTION_IF_NULL(prim);
  168. auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
  169. return bprop_flag;
  170. }
  171. return false;
  172. }
  173. AnfNodePtr HandleRealToComplex(const AnfNodePtr &input, const CNodePtr &din, FuncGraphPtr fg) {
  174. MS_EXCEPTION_IF_NULL(input);
  175. TypePtr input_type = input->Type();
  176. if (input_type == nullptr || !input_type->isa<TensorType>()) {
  177. return din;
  178. }
  179. input_type = input_type->cast<TensorTypePtr>()->element();
  180. MS_EXCEPTION_IF_NULL(input_type);
  181. if (input_type->type_id() == kNumberTypeComplex64 || input_type->type_id() == kNumberTypeComplex128) {
  182. return din;
  183. }
  184. MS_EXCEPTION_IF_NULL(din);
  185. // If we can not get the dtype of din, we insert real op ignoring din's dtype,
  186. // and eliminate it in "real_op_elimiate" pass.
  187. MS_EXCEPTION_IF_NULL(fg);
  188. if (din->abstract() == nullptr) {
  189. return fg->NewCNode({NewValueNode(prim::kPrimReal), din});
  190. }
  191. TypePtr din_type = din->Type();
  192. if (din_type == nullptr || !din_type->isa<TensorType>()) {
  193. return din;
  194. }
  195. din_type = din_type->cast<TensorTypePtr>()->element();
  196. MS_EXCEPTION_IF_NULL(din_type);
  197. if (din_type->type_id() != kNumberTypeComplex64 && din_type->type_id() != kNumberTypeComplex128) {
  198. return din;
  199. }
  200. AnfNodePtr new_din = fg->NewCNode({NewValueNode(prim::kPrimReal), din});
  201. AbstractBasePtr abs = std::make_shared<abstract::AbstractTensor>(
  202. abstract::AbstractTensor(input_type, input->abstract()->GetShapeTrack()));
  203. new_din->set_abstract(abs);
  204. return new_din;
  205. }
  206. void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
  207. auto bprop =
  208. k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(1))});
  209. // Call with delimited continuation dout.
  210. CNodePtr bprop_app;
  211. if (HasSideEffectBackProp(cnode_morph)) {
  212. // as MapMorphism is called recursively, so the order of bprop_app should reversed as visited order.
  213. bprop_app = tape_->NewCNodeInFront({bprop, node_adjoint->dout()});
  214. tape_->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  215. } else {
  216. bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
  217. }
  218. node_adjoint->RegisterDoutUser(bprop_app, 1);
  219. // Special case for switch_layer
  220. if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
  221. auto din =
  222. tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(static_cast<int64_t>(0))});
  223. BackPropagateSwitchLayer(cnode_morph, din);
  224. return;
  225. }
  226. for (size_t i = 0; i < cnode_morph->size(); i++) {
  227. auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
  228. auto input = cnode_morph->input(i);
  229. // Skip HookBackward op
  230. if (IsPrimitiveCNode(input, prim::kPrimHookBackward)) {
  231. auto inp_i = input->cast<CNodePtr>();
  232. input = inp_i->input(1);
  233. }
  234. auto din_with_real = HandleRealToComplex(input, din, tape_);
  235. MS_EXCEPTION_IF_NULL(din_with_real);
  236. din = din_with_real->cast<CNodePtr>();
  237. // Backprop sens wrt fvs.
  238. if (IsValueNode<FuncGraph>(input)) {
  239. auto func_graph = GetValueNode<FuncGraphPtr>(input);
  240. auto functor = func_graph_to_functor_.find(func_graph);
  241. if (functor == func_graph_to_functor_.end()) {
  242. MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] "
  243. << func_graph->ToString() << ".";
  244. }
  245. // Consider direct and indirect fvs.
  246. for (auto fv : func_graph->free_variables_nodes()) {
  247. BackPropagateFv(fv, din);
  248. }
  249. for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
  250. MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " "
  251. << indirect_fv.first->ToString() << ".";
  252. BackPropagateFv(indirect_fv.first, din);
  253. }
  254. continue;
  255. }
  256. // Backprop sens wrt inputs.
  257. auto input_adjoint = anfnode_to_adjoin_.find(input);
  258. if (input_adjoint == anfnode_to_adjoin_.end()) {
  259. MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << ".";
  260. }
  261. input_adjoint->second->AccumulateDout(din);
  262. }
  263. }
  264. // Map a morphism.
  265. AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
  266. MS_LOG(DEBUG) << "start MapMorphism:" << morph->DebugString(4);
  267. // MapMorphism All type except CNode should already be mapped by MapObject.
  268. if (!morph->isa<CNode>()) {
  269. return nullptr;
  270. }
  271. // for free variable, which may be handled in MapValueObject, just return it
  272. auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
  273. if (node_adjoint_found != anfnode_to_adjoin_.end()) {
  274. return node_adjoint_found->second;
  275. }
  276. ScopeGuard scope_guard(morph->scope());
  277. auto cnode_morph = morph->cast<CNodePtr>();
  278. std::vector<AnfNodePtr> inputs;
  279. std::vector<AdjointPtr> param_adjoints;
  280. for (size_t i = 0; i < cnode_morph->size(); i++) {
  281. auto node = cnode_morph->input(i);
  282. // Skip HookBackward op
  283. if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
  284. auto input_i = node->cast<CNodePtr>();
  285. MS_LOG(WARNING)
  286. << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
  287. node = input_i->input(1);
  288. }
  289. AdjointPtr node_adjoint = nullptr;
  290. auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
  291. if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
  292. node_adjoint = node_adjoint_iter->second;
  293. } else {
  294. // Input might be a CNode that needs to be handled previously.
  295. node_adjoint = MapMorphism(node);
  296. }
  297. MS_EXCEPTION_IF_NULL(node_adjoint);
  298. AnfNodePtr k = node_adjoint->k();
  299. if (k == nullptr) {
  300. MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
  301. }
  302. inputs.push_back(k);
  303. param_adjoints.push_back(node_adjoint);
  304. }
  305. CNodePtr k_app = nullptr;
  306. {
  307. TraceGuard guard(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
  308. k_app = k_graph_->NewCNode(inputs);
  309. }
  310. // Run in pynative mode, when @ms_function is used.
  311. if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
  312. auto pynative_exec = pynative::PynativeExecutor::GetInstance();
  313. auto grad_exec = pynative_exec->grad_executor();
  314. if (grad_exec->eliminate_forward()) {
  315. PynativeDFunctor::ReplaceEquivdout(k_app, cnode_morph);
  316. cnode_morph->clear_inputs_value();
  317. }
  318. }
  319. for (size_t i = 0; i < param_adjoints.size(); ++i) {
  320. param_adjoints[i]->RegisterKUser(k_app, i);
  321. }
  322. // Do forward computation
  323. auto forward_app =
  324. k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(0))});
  325. // K:: cnode -> forward_app
  326. auto node_adjoint = std::make_shared<Adjoint>(morph, forward_app, tape_);
  327. UpdateAdjoint(node_adjoint);
  328. anfnode_to_adjoin_[morph] = node_adjoint;
  329. if (cnode_morph->stop_gradient()) {
  330. MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
  331. return node_adjoint;
  332. }
  333. // Do sens backpropagation
  334. BackPropagate(cnode_morph, k_app, node_adjoint);
  335. MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << ".";
  336. return node_adjoint;
  337. }
  338. bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
  339. // Do not care about non-CNode
  340. if (!node->isa<CNode>()) {
  341. return false;
  342. }
  343. // Do not care about kPrimReturn
  344. if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
  345. return false;
  346. }
  347. MS_EXCEPTION_IF_NULL(primal_graph_->manager());
  348. auto &node_users = primal_graph_->manager()->node_users();
  349. auto iter = node_users.find(node);
  350. if (iter == node_users.end()) {
  351. return false;
  352. }
  353. auto &users = iter->second;
  354. // Do not care about isolated morphisms
  355. if (users.empty()) {
  356. return false;
  357. }
  358. // Not free if it's used by some node in primal_graph
  359. bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
  360. auto &user = kv.first;
  361. return user->func_graph() == primal_graph_;
  362. });
  363. return !nonfree;
  364. }
  365. void DFunctor::MapFreeMorphism() {
  366. // Handle cnode not attached to output, that might be referred in other functions.
  367. for (auto &node : primal_graph_->nodes()) {
  368. if (!IsFreeMorphism(node)) {
  369. continue;
  370. }
  371. MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
  372. (void)MapMorphism(node);
  373. }
  374. }
  375. AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
  376. AnfNodePtr new_grad_fv = grad_fv;
  377. // Add grads wrt fv.
  378. const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  379. if (!is_top_ && free_variables_nodes.size() != 0) {
  380. if (lift_fv_before_grad) {
  381. MS_LOG(EXCEPTION) << "direct fv size is: " << free_variables_nodes.size() << " in " << primal_graph_->ToString()
  382. << ".";
  383. }
  384. }
  385. for (auto &fv : free_variables_nodes) {
  386. if (IsPrimitiveCNode(fv, prim::kPrimJ)) { // Ignore if FV is a J CNode.
  387. continue;
  388. }
  389. auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  390. if (fv_adjoint == anfnode_to_adjoin_.end()) {
  391. MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
  392. }
  393. auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
  394. fv_adjoint->second->RegisterKUser(node, 1);
  395. auto sens = fv_adjoint->second->dout();
  396. new_grad_fv = tape_->NewCNode({
  397. NewValueNode(prim::kPrimEnvSetItem),
  398. new_grad_fv,
  399. node,
  400. sens,
  401. });
  402. constexpr size_t sens_index = 3;
  403. fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), sens_index);
  404. MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " "
  405. << fv->ToString() << " " << primal_graph_->ToString() << ".";
  406. }
  407. return new_grad_fv;
  408. }
  409. AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
  410. if (lift_fv_before_grad) {
  411. MS_LOG(EXCEPTION) << "Lift free variable case: AttachIndirectFvDoutToTape backprop indirect fv "
  412. << grad_fv->ToString() << " " << primal_graph_->ToString() << ".";
  413. }
  414. AnfNodePtr new_grad_fv = grad_fv;
  415. // Add indirect fv bprop.
  416. for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {
  417. MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " "
  418. << primal_graph_->ToString() << ".";
  419. auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()});
  420. fv_adjoint.second->RegisterKUser(node, 1);
  421. auto sens = fv_adjoint.second->dout();
  422. new_grad_fv = tape_->NewCNode({
  423. NewValueNode(prim::kPrimEnvSetItem),
  424. new_grad_fv,
  425. node,
  426. sens,
  427. });
  428. constexpr size_t sens_index = 3;
  429. fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), sens_index);
  430. MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to "
  431. << new_grad_fv->ToString() << ".";
  432. }
  433. return new_grad_fv;
  434. }
  435. void DFunctor::MapMorphism() {
  436. // Set stop_gradient before MapMorphism.
  437. BroadCastStopFlag();
  438. // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
  439. MapFreeMorphism();
  440. // Skip HookBackward when it is the output node.
  441. auto output_node = primal_graph_->output();
  442. if (IsPrimitiveCNode(output_node, prim::kPrimHookBackward)) {
  443. auto output_cnode = output_node->cast<CNodePtr>();
  444. MS_LOG(WARNING)
  445. << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
  446. output_node = output_cnode->input(1);
  447. }
  448. // Handle morphism from output.
  449. (void)MapMorphism(output_node);
  450. // Construct K for primal_graph_.
  451. auto output_adjoint = anfnode_to_adjoin_.find(output_node);
  452. // Attach dout_ parameter to output_adjoint.
  453. output_adjoint->second->AccumulateDout(dout_);
  454. // Set output for tape closure.
  455. AnfNodePtr grad_fv;
  456. if (lift_fv_before_grad) {
  457. grad_fv = AttachFvDoutToTape(NewValueNode(newenv));
  458. } else {
  459. grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
  460. }
  461. std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
  462. // Add grads wrt inputs.
  463. std::vector<AdjointPtr> param_adjoints;
  464. for (auto &param : primal_graph_->parameters()) {
  465. auto param_adjoint = anfnode_to_adjoin_.find(param);
  466. inputs.push_back(param_adjoint->second->dout());
  467. param_adjoints.push_back(param_adjoint->second);
  468. }
  469. auto tape_output = tape_->NewCNode(inputs);
  470. constexpr size_t offset_num = 2;
  471. for (size_t i = 0; i < param_adjoints.size(); ++i) {
  472. param_adjoints[i]->RegisterDoutUser(tape_output, i + offset_num);
  473. }
  474. tape_->set_output(tape_output);
  475. // Set output for k_graph_, K:: cnode->forward_app.
  476. auto forward_app = output_adjoint->second->k();
  477. auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)});
  478. output_adjoint->second->RegisterKUser(output, 1);
  479. k_graph_->set_output(output);
  480. (void)primal_graph_->transforms().emplace("grad", FuncGraphTransform(k_graph_));
  481. (void)k_graph_->transforms().emplace("primal", FuncGraphTransform(primal_graph_));
  482. }
  483. FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
  484. // K user defined cell bprop.
  485. auto bprop = primal->transforms().find("bprop");
  486. if (bprop != primal->transforms().end()) {
  487. FuncGraphPtr bprop_graph = bprop->second.func_graph();
  488. resources_->manager()->AddFuncGraph(bprop_graph);
  489. if (!bprop_graph->free_variables_nodes().empty() || !primal->free_variables_nodes().empty()) {
  490. MS_LOG(EXCEPTION) << "The Cell with user defined 'bprop' function in scope " << primal->output()->scope()->name()
  491. << " does not support Parameter data type.\n"
  492. << trace::GetDebugInfo(bprop_graph->debug_info());
  493. }
  494. bprop_graph->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true);
  495. bprop_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  496. auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph, primal);
  497. if (fg == nullptr) {
  498. MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope "
  499. << primal->output()->scope()->name() << ".";
  500. }
  501. // Cache the grad func
  502. (void)primal->transforms().emplace("grad", FuncGraphTransform(fg));
  503. (void)fg->transforms().emplace("primal", FuncGraphTransform(primal));
  504. // Reset defer_inline to enable successive inlining
  505. primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
  506. auto functor = std::make_shared<DFunctor>(primal, resources_);
  507. functor->Init();
  508. functor->k_graph_ = fg;
  509. return fg;
  510. }
  511. return nullptr;
  512. }
  513. // Construct representation graph for {CNode, Index} of Primitive.
  514. AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t index) {
  515. auto primal = primitive_user->input(index);
  516. if (!IsValueNode<Primitive>(primal)) {
  517. MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Primitive.";
  518. }
  519. ScopeGuard scope_guard(primal->scope());
  520. // Map Primitive to K
  521. auto value_node = primal->cast<ValueNodePtr>();
  522. auto prim = GetValueNode<PrimitivePtr>(value_node);
  523. if ((prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) ||
  524. (prim->Hash() == prim::kPrimUpdateState->Hash() && prim->name() == prim::kPrimUpdateState->name())) {
  525. MS_LOG(DEBUG) << "Should stop gradient for " << prim->ToString();
  526. need_cut_ = true;
  527. }
  528. auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_);
  529. if (k_prim != nullptr) {
  530. return NewValueNode(k_prim);
  531. }
  532. // When failed to find k_prim, try k_meta.
  533. auto k_meta = g_k_prims.KMetaFuncGraph(prim);
  534. if (k_meta != nullptr) {
  535. return NewValueNode(k_meta);
  536. }
  537. MS_LOG(EXCEPTION) << "Fail to map Primitive of \"" << primal->ToString() << "\" to K.";
  538. }
  539. // Construct representation graph for ValueNode of FuncGraph.
  540. AnfNodePtr DFunctor::MapFuncGraphToK(const AnfNodePtr &primal) {
  541. if (!IsValueNode<FuncGraph>(primal)) {
  542. MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of FuncGraph.";
  543. }
  544. ScopeGuard scope_guard(primal->scope());
  545. // Map func graph to K
  546. auto func_graph = GetValueNode<FuncGraphPtr>(primal);
  547. auto f = func_graph_to_functor_.find(func_graph);
  548. if (f != func_graph_to_functor_.end()) {
  549. MS_LOG(DEBUG) << "K graph functor already exist " << func_graph->ToString() << ".";
  550. return NewValueNode(f->second->k_graph_);
  551. }
  552. auto k_user_defined = KUserDefined(func_graph);
  553. if (k_user_defined != nullptr) {
  554. MS_LOG(DEBUG) << "K graph functor user defined bprop " << func_graph->ToString() << ".";
  555. return NewValueNode(k_user_defined);
  556. }
  557. auto functor = std::make_shared<DFunctor>(func_graph, resources_);
  558. functor->Init();
  559. functor->MapObject();
  560. functor->MapMorphism();
  561. MS_LOG(DEBUG) << "Map \"" << func_graph->ToString() << "\" to \"" << functor->k_graph_->ToString() << "\"";
  562. return NewValueNode(functor->k_graph_);
  563. }
  564. // Construct for ValueNode of Parameter.
  565. AnfNodePtr DFunctor::MapParameterToK(const AnfNodePtr &primal) {
  566. if (!primal->isa<Parameter>()) {
  567. MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Parameter.";
  568. }
  569. ScopeGuard scope_guard(primal->scope());
  570. // Map Parameter to K
  571. TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
  572. auto ret = k_graph_->add_parameter();
  573. return ret;
  574. }
  575. void DFunctor::MapFvObject() {
  576. // Map free variable.
  577. const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  578. for (auto &node : free_variables_nodes) {
  579. ScopeGuard scope_guard(node->scope());
  580. MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << ".";
  581. // Find fv's K from parent.
  582. AdjointPtr adjoint = nullptr;
  583. auto parent_adjoint = FindAdjoint(node);
  584. if (parent_adjoint != nullptr) {
  585. adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
  586. } else {
  587. if (is_top_ || node->isa<Parameter>()) {
  588. // Out of ad scope, add adjoint for free variables.
  589. adjoint = std::make_shared<Adjoint>(node, node, tape_);
  590. UpdateAdjoint(adjoint);
  591. } else {
  592. MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << ".";
  593. adjoint = std::make_shared<Adjoint>(node, nullptr, tape_);
  594. }
  595. }
  596. if (adjoint == nullptr) {
  597. MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << ".";
  598. }
  599. anfnode_to_adjoin_[node] = adjoint;
  600. }
  601. }
  602. void DFunctor::MapParamObject() {
  603. // Map parameter.
  604. for (auto &p : primal_graph_->parameters()) {
  605. ScopeGuard scope_guard(p->scope());
  606. MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
  607. auto adjoint = std::make_shared<Adjoint>(p, MapParameterToK(p), tape_);
  608. UpdateAdjoint(adjoint);
  609. anfnode_to_adjoin_[p] = adjoint;
  610. }
  611. }
  612. void DFunctor::MapValueObject() {
  613. // Map ValueNode.
  614. auto manager = resources_->manager();
  615. auto &value_nodes = primal_graph_->value_nodes();
  616. for (const auto &value_pair : value_nodes) {
  617. auto node = value_pair.first;
  618. auto parent_adjoint = FindAdjoint(node);
  619. if (parent_adjoint != nullptr) {
  620. auto adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
  621. anfnode_to_adjoin_[node] = adjoint;
  622. continue;
  623. }
  624. AdjointPtr adjoint = nullptr;
  625. if (IsValueNode<Primitive>(node)) { // Primitive.
  626. auto prim = GetValueNode<PrimitivePtr>(node);
  627. if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn ||
  628. (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name())) {
  629. continue;
  630. }
  631. MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
  632. auto &users = manager->node_users()[node];
  633. if (users.size() == 0) {
  634. MS_LOG(ERROR) << "\"" << node->DebugString() << "\" has no user.";
  635. continue;
  636. } else if (users.size() > 1) {
  637. MS_LOG(DEBUG) << "\"" << node->DebugString() << "\" supposed to be used once, but users size: " << users.size();
  638. }
  639. auto cnode = users.begin()->first->cast<CNodePtr>(); // We just use the first user.
  640. auto index = users.begin()->second;
  641. adjoint = std::make_shared<Adjoint>(node, MapPrimitiveToK(cnode, index), tape_);
  642. } else if (IsValueNode<FuncGraph>(node)) { // FuncGraph
  643. MS_LOG(DEBUG) << "Map FuncGraph node " << node->DebugString() << ".";
  644. adjoint = std::make_shared<Adjoint>(node, MapFuncGraphToK(node), tape_);
  645. } else if (node->isa<Parameter>()) { // Parameter, hardly reach here.
  646. MS_LOG(DEBUG) << "Map Parameter node " << node->DebugString() << ".";
  647. adjoint = std::make_shared<Adjoint>(node, MapParameterToK(node), tape_);
  648. } else {
  649. adjoint = std::make_shared<Adjoint>(node, node, tape_);
  650. }
  651. UpdateAdjoint(adjoint);
  652. anfnode_to_adjoin_[node] = adjoint;
  653. }
  654. }
  655. // Skip morphism.
  656. void DFunctor::MapObject() {
  657. // The order does not matter
  658. MapFvObject();
  659. MapParamObject();
  660. MapValueObject();
  661. }
  662. void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) {
  663. auto primal = adjoint_definition->primal();
  664. if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) {
  665. MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " "
  666. << primal->ToString() << ".";
  667. }
  668. anfnode_to_adjoin_definition_[primal] = adjoint_definition;
  669. // Update k hole for primal.
  670. for (auto &f : func_graph_to_functor_) {
  671. auto adjoint = f.second->anfnode_to_adjoin_.find(primal);
  672. if (adjoint != f.second->anfnode_to_adjoin_.end()) {
  673. adjoint->second->UpdateK(adjoint_definition->k());
  674. }
  675. adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal);
  676. if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) {
  677. adjoint->second->UpdateK(adjoint_definition->k());
  678. }
  679. }
  680. }
  681. AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
  682. auto adjoint = anfnode_to_adjoin_definition_.find(primal);
  683. if (adjoint != anfnode_to_adjoin_definition_.end()) {
  684. MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << ".";
  685. return adjoint->second;
  686. }
  687. MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << ".";
  688. return nullptr;
  689. }
  690. void DFunctor::CallDoutHoleOnTape() {
  691. if (!is_top_) {
  692. return;
  693. }
  694. // Call dout hole of all adjoint.
  695. for (auto &f : func_graph_to_functor_) {
  696. for (auto &adjoint : f.second->anfnode_to_adjoin_) {
  697. adjoint.second->CallDoutHole();
  698. }
  699. for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) {
  700. adjoint.second->CallDoutHole();
  701. }
  702. }
  703. }
  704. FuncGraphPtr DFunctor::k_graph() { return k_graph_; }
  705. FuncGraphPtr DFunctor::tape() { return tape_; }
  706. void DFunctor::BroadCastStopFlag() {
  707. // As stop set expanding, all directly or indirectly stopped CNode will be cut off
  708. while (need_cut_) {
  709. need_cut_ = false;
  710. for (auto &node : primal_graph_->nodes()) {
  711. if (node->isa<CNode>()) {
  712. auto cnode = node->cast<CNodePtr>();
  713. if (!cnode->stop_gradient()) {
  714. // Cut off the cnode only when it's not referred any more
  715. if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
  716. AllReferencesStopped(cnode)) {
  717. MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
  718. cnode->set_stop_gradient(true);
  719. // The stop set changed, more cut required
  720. need_cut_ = true;
  721. }
  722. }
  723. }
  724. }
  725. }
  726. }
  727. bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
  728. auto &users = primal_graph_->manager()->node_users()[node];
  729. // Only care about stop_gradient caused cutting
  730. if (users.empty()) {
  731. return false;
  732. }
  733. for (auto &kv : users) {
  734. auto &user = kv.first;
  735. if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
  736. return false;
  737. }
  738. }
  739. return true;
  740. }
  741. CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int index) {
  742. auto it = node_user_map.find(cnode);
  743. if (it == node_user_map.end()) {
  744. MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}";
  745. }
  746. auto &j_users = it->second;
  747. auto size = j_users.size();
  748. if (size != 1) {
  749. bool has_multiple_j_call_user = false;
  750. CNodePtr j_call_user = nullptr;
  751. for (auto &user : j_users) {
  752. // If J CNode is used as a FV, the j_users.size may exceed 1 user. It is allowed.
  753. if (user.second == 0) {
  754. // Real J CNode call user.
  755. if (j_call_user == nullptr) { // First user.
  756. j_call_user = user.first->cast<CNodePtr>();
  757. } else { // More than 1 call user. Not allowed.
  758. has_multiple_j_call_user = true;
  759. }
  760. }
  761. }
  762. if (has_multiple_j_call_user) { // Has multiple J CNode call user.
  763. std::ostringstream user_info;
  764. for (auto &user : j_users) {
  765. user_info << " user: " << user.first->DebugString() << ", index: " << user.second << "\n";
  766. }
  767. #ifdef ENABLE_DUMP_IR
  768. DumpIR("J_User_Ex_" + cnode->func_graph()->ToString() + ".ir", cnode->func_graph());
  769. #endif
  770. MS_LOG(EXCEPTION) << "Incorrect J CNode user size: " << size << ", of {" << cnode->DebugString(2) << "/" << index
  771. << "}\nUser Info:\n"
  772. << user_info.str();
  773. } else {
  774. return j_call_user;
  775. }
  776. }
  777. return j_users.begin()->first->cast<CNodePtr>();
  778. }
  779. CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std::vector<CNodePtr>> &primal_map) {
  780. // Check if J operation has relevant primal call in the same graph.
  781. auto graph = j_user->func_graph();
  782. auto iter = primal_map.find(graph);
  783. if (iter == primal_map.end()) {
  784. MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString()
  785. << ", J user: " << j_user->DebugString();
  786. return nullptr;
  787. }
  788. // Check if there is only one primal call corresponding to the specified j user.
  789. auto primal_users = iter->second;
  790. if (primal_users.size() != 1) {
  791. MS_LOG(WARNING) << "It is recommended to call the forward network only once.";
  792. MS_LOG(INFO) << "There is " << primal_users.size()
  793. << " primal calls for same J operation in the same graph. Func graph: " << graph->ToString()
  794. << ", J operation: " << j_user->DebugString() << ", Primal call: ";
  795. size_t count = 0;
  796. for (const auto &user : primal_users) {
  797. MS_LOG(INFO) << "[ " << ++count << " ] : " << user->DebugString(2) << trace::DumpSourceLines(user);
  798. }
  799. return nullptr;
  800. }
  801. // Check input size.
  802. auto primal_user = primal_users[0];
  803. if (primal_user->size() != j_user->size()) {
  804. MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal_user->DebugString() << " is "
  805. << primal_user->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size();
  806. return nullptr;
  807. }
  808. return primal_user;
  809. }
  810. static mindspore::HashMap<CNodePtr, std::vector<CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
  811. const FuncGraphPtr &primal_graph) {
  812. std::vector<CNodePtr> j_users;
  813. std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map;
  814. const auto &node_user_map = manager->node_users();
  815. // Search primal graph user cnodes.
  816. for (auto &entry : primal_graph->func_graph_cnodes_index()) {
  817. auto cnode = entry.first->first->cast<CNodePtr>();
  818. auto index = entry.first->second;
  819. if (index == 0) {
  820. // To find real calling.
  821. auto fg = cnode->func_graph();
  822. MS_EXCEPTION_IF_NULL(fg);
  823. auto iter = primal_map.find(fg);
  824. if (iter != primal_map.end()) {
  825. iter->second.push_back(cnode);
  826. continue;
  827. }
  828. primal_map[fg] = {cnode};
  829. } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
  830. // To find J user.
  831. j_users.emplace_back(GetJUser(node_user_map, cnode, index));
  832. }
  833. }
  834. mindspore::HashMap<CNodePtr, std::vector<CNodePtr>> primal_user_to_j_users;
  835. for (const auto &j_user : j_users) {
  836. MS_EXCEPTION_IF_NULL(j_user);
  837. auto primal = GetPrimalUser(j_user, primal_map);
  838. if (primal == nullptr) {
  839. continue;
  840. }
  841. MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
  842. << " and J user is: " << j_user->DebugString();
  843. primal_user_to_j_users[primal].emplace_back(j_user);
  844. }
  845. return primal_user_to_j_users;
  846. }
  847. static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
  848. auto &node_users = manager->node_users();
  849. auto iter = node_users.find(primal_call);
  850. if (iter == node_users.end()) {
  851. // Skip if user of primal_call not found.
  852. return;
  853. }
  854. // Find UpdateState nodes after the primal call.
  855. std::vector<CNodePtr> update_states;
  856. for (auto &user : iter->second) {
  857. auto &user_node = user.first;
  858. if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
  859. update_states.emplace_back(user_node->cast<CNodePtr>());
  860. }
  861. }
  862. // Remove UpdateStates by replace them with their monad input.
  863. for (auto &update_state : update_states) {
  864. auto &input_monad = update_state->inputs().at(1);
  865. manager->Replace(update_state, input_monad);
  866. }
  867. }
  868. static bool CopyMonadArguments(const CNodePtr &primal_user, const CNodePtr &j_user) {
  869. auto &primal_inputs = primal_user->inputs();
  870. auto &j_user_inputs = j_user->inputs();
  871. bool has_monad = false;
  872. for (size_t i = 1; i < primal_inputs.size(); ++i) {
  873. auto &input = primal_inputs.at(i);
  874. if (HasAbstractMonad(input)) {
  875. // Copy monad input from primal to j_user.
  876. j_user->set_input(i, input);
  877. has_monad = true;
  878. } else if (input != j_user_inputs.at(i)) {
  879. // Skip if there are different non-monad inputs.
  880. return false;
  881. }
  882. }
  883. return has_monad;
  884. }
  885. //
  886. // To replace the primal graph with k graph.
  887. // Convert:
  888. // x = primal(args, u0)
  889. // u1 = update_state(u0, x)
  890. // ...
  891. // tuple = K(args, u1)
  892. // u2 = update_state(u1, tuple)
  893. // ...
  894. // To:
  895. // tuple = K(args, u0)
  896. // x = get_item(tuple, 0)
  897. // ...
  898. // tuple = K(args, u0)
  899. // u2 = update_state(u0, tuple)
  900. // ...
  901. //
  902. void DFunctor::EliminatePrimalGraph() {
  903. // Find primal user and paired J user cnodes.
  904. auto manager = primal_graph_->manager();
  905. MS_EXCEPTION_IF_NULL(manager);
  906. auto primal_user_to_j_users = FindPrimalJPair(manager, primal_graph_);
  907. for (const auto &iter : primal_user_to_j_users) {
  908. auto primal_user = iter.first;
  909. auto &j_users = iter.second;
  910. MS_EXCEPTION_IF_NULL(primal_user);
  911. if (j_users.size() == 1) {
  912. // If both inputs are same except monads, we copy primal monad args to k graph
  913. // so that they can be combined in CSE (common subexpression elimination) pass.
  914. // Only do this when the size of j_users is 1 in order to keep the execution order.
  915. const bool has_monad = CopyMonadArguments(primal_user, j_users[0]);
  916. // Remove the UpdateState nodes after primal_user if need.
  917. if (has_monad) {
  918. RemovePrimalUpdateStates(manager, primal_user);
  919. }
  920. } else {
  921. MS_LOG(INFO) << "There are multiple j users with the same primal user " << primal_user->DebugString();
  922. }
  923. // Replace primal graph with k graph.
  924. auto k_vnode = NewValueNode(k_graph_);
  925. primal_user->set_input(0, k_vnode);
  926. if (j_users.empty()) {
  927. MS_LOG(EXCEPTION) << "The J nodes for primal graph " << primal_graph_->ToString()
  928. << " should be used by at least one other node.";
  929. }
  930. primal_user->set_abstract(j_users[0]->abstract());
  931. // Insert tuple_getitem after primal user cnode.
  932. auto construct_wrapper = primal_user->func_graph();
  933. auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);
  934. auto imm0 = std::make_shared<Int64Imm>(0);
  935. auto idx0 = NewValueNode(SizeToLong(0));
  936. idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
  937. auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0});
  938. getitem0->CloneCNodeInfo(primal_user);
  939. (void)manager->Replace(primal_user, getitem0);
  940. }
  941. }
  942. } // namespace ad
  943. } // namespace mindspore