You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

dfunctor.cc 37 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/ad/dfunctor.h"
  17. #include <map>
  18. #include <memory>
  19. #include <string>
  20. #include "ir/anf.h"
  21. #include "utils/info.h"
  22. #include "ir/func_graph_cloner.h"
  23. #include "ir/manager.h"
  24. #include "pipeline/jit/resource.h"
  25. #include "pipeline/pynative/pynative_execute.h"
  26. #include "frontend/optimizer/ad/adjoint.h"
  27. #include "frontend/operator/ops.h"
  28. #include "utils/symbolic.h"
  29. #include "utils/ms_context.h"
  30. #include "pipeline/jit/action.h"
  31. #include "pipeline/jit/parse/resolve.h"
  32. namespace mindspore {
  33. namespace ad {
// Global caches shared by every DFunctor instance during one grad transform:
// primal func graph -> its functor, node -> its defining adjoint, and the scope
// of the current top-level functor (captured in Init, cleared in Clear).
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
FuncGraphSet DFunctor::scope_;
// Builds the skeletons of the two graphs produced by the K transform of `primal_graph`:
// k_graph_ (the forward/fprop graph) and tape_ (the backprop closure). Both inherit the
// primal graph's stage and, when present, its graph-kernel attribute — the tape's kernel
// name suffixed with "_Grad" to distinguish forward from backward kernels.
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
    : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
  {
    // Scoped TraceGuard: nodes of k_graph_ carry fprop debug provenance of the primal graph.
    TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
    k_graph_ = std::make_shared<FuncGraph>();
  }
  if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
    std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
    k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
  }
  // To keep switch_layer's inputs from being inlined
  k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
  k_graph_->set_stage(primal_graph->stage());
  {
    // Same pattern for the tape, with bprop provenance.
    TraceGuard guard(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
    tape_ = std::make_shared<FuncGraph>();
  }
  tape_->set_stage(primal_graph->stage());
  // Add "_Grad" postfix
  if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
    std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad";
    tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
  }
  // The tape's first parameter is the incoming sensitivity (dout) of the primal output.
  dout_ = tape_->add_parameter();
}
  62. void DFunctor::Init(bool is_top) {
  63. func_graph_to_functor_[primal_graph_] = shared_from_this();
  64. is_top_ = is_top;
  65. if (is_top) {
  66. scope_ = primal_graph_->scope();
  67. }
  68. }
// Finalizes grad construction: closes remaining dout holes on the tape, then
// eliminates the now-redundant primal graph.
void DFunctor::Finish() {
  CallDoutHoleOnTape();
  EliminatePrimalGraph();
}
  73. void DFunctor::Clear() {
  74. func_graph_to_functor_.clear();
  75. anfnode_to_adjoin_definition_.clear();
  76. scope_.clear();
  77. }
// Accumulates the gradient `din` into the adjoint of free variable `fv`.
// Lookup order: direct adjoint map, then indirect-fv map; if neither has it, a new
// adjoint is created from the parent functor's definition (or a "k hole" when no
// definition exists). The contribution is routed through env ops:
//   dfv = env_getitem(din, embed(k(fv)), zeros_like(k(fv))).
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
  auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  if (fv_adjoint == anfnode_to_adjoin_.end()) {
    MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
                  << " " << fv->ToString() << ".";
    fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
    if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
      MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
                    << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
      auto parent_adjoint = FindAdjoint(fv);
      AdjointPtr adjoint = nullptr;
      if (parent_adjoint != nullptr) {
        adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
      } else {
        // No definition anywhere: register an adjoint with a null k ("k hole")
        // to be patched later.
        MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
        adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
      }
      anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
      // Re-find so fv_adjoint points into the indirect map for the code below.
      fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
    }
  }
  auto fv_node = fv_adjoint->second->k();
  // Cache the embed/zeros_like pair per k-node so repeated contributions to the
  // same fv reuse the same env key and default value.
  auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
  CNodePtr embed_node, default_val_node;
  if (cached_envitem_iter != anfnode_to_envitem_.end()) {
    embed_node = cached_envitem_iter->second.first;
    default_val_node = cached_envitem_iter->second.second;
  } else {
    embed_node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_node});
    default_val_node = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_node});
    fv_adjoint->second->RegisterKUser(embed_node, 1);
    fv_adjoint->second->RegisterKUser(default_val_node, 1);
    anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node);
  }
  auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node});
  MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
                << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << ".";
  fv_adjoint->second->AccumulateDout(dfv);
}
  119. void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
  120. // Take switch_layer as a set of candidate functions.
  121. auto input = cnode_morph->input(2);
  122. if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
  123. MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
  124. }
  125. std::unordered_map<AnfNodePtr, FuncGraphPtr> node_to_fg;
  126. auto tuple_graphs = input->cast<CNodePtr>();
  127. for (size_t i = 1; i < tuple_graphs->size(); ++i) {
  128. auto graph = tuple_graphs->input(i);
  129. if (!IsValueNode<FuncGraph>(graph)) {
  130. MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString()
  131. << " as the " << i << "th element.";
  132. }
  133. auto func_graph = GetValueNode<FuncGraphPtr>(graph);
  134. auto functor = func_graph_to_functor_.find(func_graph);
  135. if (functor == func_graph_to_functor_.end()) {
  136. MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] "
  137. << func_graph->ToString() << ".";
  138. }
  139. // Consider direct and indirect fvs.
  140. for (auto fv : func_graph->free_variables_nodes()) {
  141. if (node_to_fg.find(fv) != node_to_fg.end()) {
  142. continue;
  143. }
  144. node_to_fg[fv] = func_graph;
  145. BackPropagateFv(fv, env);
  146. }
  147. for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
  148. MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
  149. << indirect_fv.first->ToString() << ".";
  150. if (node_to_fg.find(indirect_fv.first) != node_to_fg.end()) {
  151. continue;
  152. }
  153. node_to_fg[indirect_fv.first] = func_graph;
  154. BackPropagateFv(indirect_fv.first, env);
  155. }
  156. }
  157. }
  158. static bool HasSideEffectBackProp(const CNodePtr &cnode) {
  159. if (IsPrimitiveCNode(cnode)) {
  160. const auto &prim = GetCNodePrimitive(cnode);
  161. MS_EXCEPTION_IF_NULL(prim);
  162. auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
  163. return bprop_flag;
  164. }
  165. return false;
  166. }
  167. void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
  168. auto bprop =
  169. k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(1))});
  170. // Call with delimited continuation dout.
  171. CNodePtr bprop_app;
  172. if (HasSideEffectBackProp(cnode_morph)) {
  173. // as MapMorphism is called recursively, so the order of bprop_app should reversed as visited order.
  174. bprop_app = tape_->NewCNodeInFront({bprop, node_adjoint->dout()});
  175. tape_->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  176. } else {
  177. bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
  178. }
  179. node_adjoint->RegisterDoutUser(bprop_app, 1);
  180. // Special case for switch_layer
  181. if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
  182. auto din =
  183. tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(static_cast<int64_t>(0))});
  184. BackPropagateSwitchLayer(cnode_morph, din);
  185. return;
  186. }
  187. for (size_t i = 0; i < cnode_morph->size(); i++) {
  188. auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
  189. auto input = cnode_morph->input(i);
  190. // Backprop sens wrt fvs.
  191. if (IsValueNode<FuncGraph>(input)) {
  192. auto func_graph = GetValueNode<FuncGraphPtr>(input);
  193. auto functor = func_graph_to_functor_.find(func_graph);
  194. if (functor == func_graph_to_functor_.end()) {
  195. MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] "
  196. << func_graph->ToString() << ".";
  197. }
  198. // Consider direct and indirect fvs.
  199. for (auto fv : func_graph->free_variables_nodes()) {
  200. BackPropagateFv(fv, din);
  201. }
  202. for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
  203. MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " "
  204. << indirect_fv.first->ToString() << ".";
  205. BackPropagateFv(indirect_fv.first, din);
  206. }
  207. continue;
  208. }
  209. // Backprop sens wrt inputs.
  210. auto input_adjoint = anfnode_to_adjoin_.find(input);
  211. if (input_adjoint == anfnode_to_adjoin_.end()) {
  212. MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << ".";
  213. }
  214. input_adjoint->second->AccumulateDout(din);
  215. }
  216. }
// Map a morphism.
// Maps CNode `morph` to its K-transformed counterpart: recursively maps its inputs,
// emits the K application in k_graph_, records the adjoint, and (unless gradient is
// stopped at this node) schedules sens backpropagation. Returns the node's adjoint,
// or nullptr for non-CNodes (those should already be handled by MapObject).
AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
  MS_LOG(DEBUG) << "start MapMorphism:" << morph->DebugString(4);
  // MapMorphism All type except CNode should already be mapped by MapObject.
  if (!morph->isa<CNode>()) {
    return nullptr;
  }
  // for free variable, which may be handled in MapValueObject, just return it
  auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
  if (node_adjoint_found != anfnode_to_adjoin_.end()) {
    return node_adjoint_found->second;
  }
  ScopeGuard scope_guard(morph->scope());
  auto cnode_morph = morph->cast<CNodePtr>();
  std::vector<AnfNodePtr> inputs;
  std::vector<AdjointPtr> param_adjoints;
  // Map each input first; its k() becomes the corresponding input of the K application.
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto node = cnode_morph->input(i);
    AdjointPtr node_adjoint = nullptr;
    auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
    if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
      node_adjoint = node_adjoint_iter->second;
    } else {
      // Input might be a CNode that needs to be handled previously.
      node_adjoint = MapMorphism(node);
    }
    MS_EXCEPTION_IF_NULL(node_adjoint);
    AnfNodePtr k = node_adjoint->k();
    if (k == nullptr) {
      MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
    }
    inputs.push_back(k);
    param_adjoints.push_back(node_adjoint);
  }
  CNodePtr k_app = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
    k_app = k_graph_->NewCNode(inputs);
  }
  // PyNative mode: substitute cached forward results into the k graph, then drop
  // the cached values from the morph node.
  ReplaceEquivdout(k_app, cnode_morph);
  cnode_morph->clear_inputs_value();
  cnode_morph->set_forward(nullptr, "");
  for (size_t i = 0; i < param_adjoints.size(); ++i) {
    param_adjoints[i]->RegisterKUser(k_app, i);
  }
  // Do forward computation
  auto foward_app =
    k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(0))});
  // K:: cnode -> forward_app
  auto node_adjoint = std::make_shared<Adjoint>(morph, foward_app, tape_);
  UpdateAdjoint(node_adjoint);
  anfnode_to_adjoin_[morph] = node_adjoint;
  // stop_gradient cuts the backward pass at this node; the forward mapping still exists.
  if (cnode_morph->stop_gradient()) {
    MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
    return node_adjoint;
  }
  // Do sens backpropagation
  BackPropagate(cnode_morph, k_app, node_adjoint);
  MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << ".";
  return node_adjoint;
}
  278. ValuePtr DFunctor::GenNewTensorInner(const ValuePtr &value) {
  279. std::vector<ValuePtr> value_list;
  280. if (value->isa<tensor::Tensor>()) {
  281. auto tensor = value->cast<tensor::TensorPtr>();
  282. auto new_tensor = std::make_shared<tensor::Tensor>(*tensor);
  283. new_tensor->set_device_address(nullptr);
  284. return new_tensor;
  285. }
  286. if (value->isa<ValueTuple>()) {
  287. auto tuple = value->cast<ValueTuplePtr>();
  288. for (size_t i = 0; i < tuple->size(); i++) {
  289. value_list.push_back(GenNewTensorInner((*tuple)[i]));
  290. }
  291. return std::make_shared<ValueTuple>(value_list);
  292. }
  293. return value;
  294. }
  295. ValuePtr DFunctor::GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, const ValuePtr &value,
  296. bool need_replace_forward) {
  297. ValuePtr out = value;
  298. auto ref_size = mng->node_users()[node].size();
  299. if (ref_size < 2) {
  300. if (need_replace_forward) {
  301. out = GenNewTensorInner(value);
  302. } else {
  303. auto tensor = value->cast<tensor::TensorPtr>();
  304. tensor->set_device_address(nullptr);
  305. return tensor;
  306. }
  307. }
  308. return out;
  309. }
// PyNative-mode helper: replaces the forward output (equivdout) of the K-applied
// subgraph, and the subgraph's parameters, with value nodes holding the results cached
// on `cnode_morph`, then runs PynativeElimOpt to fold the graph to constants.
// No-op unless the morph node carries a forward value and the k-app calls a func graph
// shaped {prim::maketuple, forward_output, bprop_graph}.
void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
  auto forward = cnode_morph->forward().first;
  if (forward == nullptr) {
    return;
  }
  auto &input = cnode->input(0);
  if (!IsValueNode<FuncGraph>(input)) {
    return;
  }
  auto fg = GetValueNode<FuncGraphPtr>(input);
  // {prim::maketuple, forward_output, bprop_graph}
  auto output = fg->output();
  if (!output->isa<CNode>()) {
    return;
  }
  auto cnode_output = output->cast<CNodePtr>();
  auto &cnode_input = cnode_output->input(1);
  if (!cnode_input->isa<CNode>()) {
    return;
  }
  auto &input_fg = cnode_output->input(2);
  if (!IsValueNode<FuncGraph>(input_fg)) {
    return;
  }
  // replace forward output with value node
  auto equivdout = cnode_input->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(equivdout);
  auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
  MS_EXCEPTION_IF_NULL(func_graph);
  // Local manager over the two graphs; torn down implicitly via ClearAllManagerInfo below.
  auto manager = Manage({fg, func_graph}, false);
  auto need_replace_forward = pynative::PynativeExecutor::GetInstance()->need_replace_forward();
  auto forward_value = GenNewTensor(manager, equivdout, forward, need_replace_forward);
  if (!need_replace_forward) {
    // Grad does not reuse the forward results: just drop the cached inputs.
    cnode_morph->clear_inputs_value();
    MS_LOG(DEBUG) << "No need replace forward result";
    return;
  }
  MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
  auto value_node = NewValueNode(forward_value);
  value_node->set_has_new_value(true);
  manager->Replace(equivdout, value_node);
  // replace input object with value node
  auto paras = fg->parameters();
  auto inputs_value = cnode_morph->inputs_value();
  if (inputs_value.empty()) {
    return;
  }
  if (inputs_value.size() > paras.size()) {
    MS_LOG(EXCEPTION) << "Parameter size:" << paras.size() << " but inputs size:" << inputs_value.size();
  }
  for (size_t i = 0; i < inputs_value.size(); i++) {
    auto para_ref_size = manager->node_users()[paras[i]].size();
    auto input_value = inputs_value[i];
    // Only substitute parameters that are actually used and have a cached value.
    if (para_ref_size > 0 && input_value.first != nullptr) {
      MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
      auto input_value_node = NewValueNode(input_value.first);
      input_value_node->set_has_new_value(true);
      input_value_node->set_used_graph_count(para_ref_size);
      manager->Replace(paras[i], input_value_node);
    }
  }
  MS_LOG(DEBUG) << "Start opt node" << fg->output()->DebugString(4);
  // Fold the graph now that output and parameters are constants.
  auto res = std::make_shared<pipeline::Resource>();
  res->set_manager(manager);
  res->set_func_graph(fg);
  PynativeElimOpt(res);
  auto out = fg->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(out);
  auto c_input = out->input(1);
  MS_EXCEPTION_IF_NULL(c_input);
  if (!c_input->isa<ValueNode>()) {
    return;
  }
  // Detach the folded output value from device memory as well.
  auto out_node = c_input->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(out_node);
  out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), need_replace_forward));
  // clear resource
  fg->ClearAllManagerInfo();
  func_graph->ClearAllManagerInfo();
}
  390. bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
  391. // Do not care about non-CNode
  392. if (!node->isa<CNode>()) {
  393. return false;
  394. }
  395. // Do not care about kPrimReturn
  396. if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
  397. return false;
  398. }
  399. auto &users = primal_graph_->manager()->node_users()[node];
  400. // Do not care about isolated morphisms
  401. if (users.empty()) {
  402. return false;
  403. }
  404. // Not free if it's used by some node in primal_graph
  405. bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
  406. auto &user = kv.first;
  407. return user->func_graph() == primal_graph_;
  408. });
  409. return !nonfree;
  410. }
  411. void DFunctor::MapFreeMorphism() {
  412. // Handle cnode not attached to output, that might be referred in other functions.
  413. for (auto &node : primal_graph_->nodes()) {
  414. if (!IsFreeMorphism(node)) {
  415. continue;
  416. }
  417. MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
  418. (void)MapMorphism(node);
  419. }
  420. }
// Extends the env node `grad_fv` with one env_setitem per direct free variable of the
// primal graph, keyed by embed(k(fv)) and valued with the fv's accumulated dout.
// Returns the extended env node; raises if any fv has no adjoint.
AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
  AnfNodePtr new_grad_fv = grad_fv;
  // Add grads wrt fv.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  for (auto &fv : free_variables_nodes) {
    auto fv_adjoint = anfnode_to_adjoin_.find(fv);
    if (fv_adjoint == anfnode_to_adjoin_.end()) {
      MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
    }
    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
    fv_adjoint->second->RegisterKUser(node, 1);
    auto sens = fv_adjoint->second->dout();
    // new_env = env_setitem(old_env, embed(k(fv)), dout(fv))
    new_grad_fv = tape_->NewCNode({
      NewValueNode(prim::kPrimEnvSetItem),
      new_grad_fv,
      node,
      sens,
    });
    // The sens sits at input index 3 of the env_setitem node.
    fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), 3);
    MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " "
                  << fv->ToString() << " " << primal_graph_->ToString() << ".";
  }
  return new_grad_fv;
}
// Same as AttachFvDoutToTape but for indirect free variables (those discovered while
// back-propagating through subgraphs): chains an env_setitem per indirect fv onto
// `grad_fv` and returns the extended env node.
AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
  AnfNodePtr new_grad_fv = grad_fv;
  // Add indirect fv bprop.
  for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {
    MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " "
                  << primal_graph_->ToString() << ".";
    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()});
    fv_adjoint.second->RegisterKUser(node, 1);
    auto sens = fv_adjoint.second->dout();
    // new_env = env_setitem(old_env, embed(k(fv)), dout(fv))
    new_grad_fv = tape_->NewCNode({
      NewValueNode(prim::kPrimEnvSetItem),
      new_grad_fv,
      node,
      sens,
    });
    // The sens sits at input index 3 of the env_setitem node.
    fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), 3);
    MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to "
                  << new_grad_fv->ToString() << ".";
  }
  return new_grad_fv;
}
  466. void DFunctor::MapMorphism() {
  467. // Set stop_gradient before MapMorphism.
  468. BroadCastStopFlag();
  469. // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
  470. MapFreeMorphism();
  471. // Handle morphism from output.
  472. (void)MapMorphism(primal_graph_->output());
  473. // Construct K for primal_graph_
  474. auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
  475. // Attach dout_ parameter to output_adjoint.
  476. output_adjoint->second->AccumulateDout(dout_);
  477. // Set output for tape closure.
  478. auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
  479. std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
  480. // Add grads wrt inputs.
  481. std::vector<AdjointPtr> param_adjoints;
  482. for (auto &param : primal_graph_->parameters()) {
  483. auto param_adjoint = anfnode_to_adjoin_.find(param);
  484. inputs.push_back(param_adjoint->second->dout());
  485. param_adjoints.push_back(param_adjoint->second);
  486. }
  487. auto tape_output = tape_->NewCNode(inputs);
  488. for (size_t i = 0; i < param_adjoints.size(); ++i) {
  489. param_adjoints[i]->RegisterDoutUser(tape_output, i + 2);
  490. }
  491. tape_->set_output(tape_output);
  492. // Set output for k_graph_, K:: cnode->forward_app.
  493. auto forward_app = output_adjoint->second->k();
  494. auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)});
  495. output_adjoint->second->RegisterKUser(output, 1);
  496. k_graph_->set_output(output);
  497. (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_)));
  498. (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_)));
  499. }
// If `primal` carries a "bprop" transform (a user-written Cell bprop), expands it into
// a grad func graph, caches it under the "grad"/"primal" transform pair, registers a
// functor whose k_graph_ is the expansion, and returns it. Returns nullptr when the
// primal has no user-defined bprop.
FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
  // K user defined cell bprop.
  auto bprop = primal->transforms().find("bprop");
  if (bprop != primal->transforms().end()) {
    FuncGraphPtr bprop_graph = bprop->second.func_graph();
    resources_->manager()->AddFuncGraph(bprop_graph);
    // Free variables (Parameter inputs) are not supported in user-defined bprop.
    if (!bprop_graph->free_variables_nodes().empty() || !primal->free_variables_nodes().empty()) {
      MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
                        << primal->output()->scope()->name() << " does not support Parameter data type.";
    }
    bprop_graph->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true);
    bprop_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
    auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph, primal);
    if (fg == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope "
                        << primal->output()->scope()->name() << ".";
    }
    // Cache the grad func
    (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg)));
    (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
    // Reset defer_inline to enable successive inlining
    primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
    // Register a functor so later MapFuncGraphToK lookups reuse fg as this primal's k_graph_.
    auto functor = std::make_shared<DFunctor>(primal, resources_);
    functor->Init();
    functor->k_graph_ = fg;
    return fg;
  }
  return nullptr;
}
  529. // Construct representation graph for {CNode, Index} of Primitive.
  530. AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t index) {
  531. auto primal = primitive_user->input(index);
  532. if (!IsValueNode<Primitive>(primal)) {
  533. MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Primitive.";
  534. }
  535. ScopeGuard scope_guard(primal->scope());
  536. // Map Primitive to K
  537. auto value_node = primal->cast<ValueNodePtr>();
  538. auto prim = GetValueNode<PrimitivePtr>(value_node);
  539. if ((prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) ||
  540. (prim->Hash() == prim::kPrimUpdateState->Hash() && prim->name() == prim::kPrimUpdateState->name())) {
  541. MS_LOG(DEBUG) << "Should stop gradient for " << prim->ToString();
  542. need_cut_ = true;
  543. }
  544. auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_);
  545. if (k_prim != nullptr) {
  546. return NewValueNode(k_prim);
  547. }
  548. // When failed to find k_prim, try k_meta.
  549. auto k_meta = g_k_prims.KMetaFuncGraph(prim);
  550. if (k_meta != nullptr) {
  551. return NewValueNode(k_meta);
  552. }
  553. MS_LOG(EXCEPTION) << "Fail to map Primitive of \"" << primal->ToString() << "\" to K.";
  554. }
  555. // Construct representation graph for ValueNode of FuncGraph.
  556. AnfNodePtr DFunctor::MapFuncGraphToK(const AnfNodePtr &primal) {
  557. if (!IsValueNode<FuncGraph>(primal)) {
  558. MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of FuncGraph.";
  559. }
  560. ScopeGuard scope_guard(primal->scope());
  561. // Map func graph to K
  562. auto func_graph = GetValueNode<FuncGraphPtr>(primal);
  563. auto f = func_graph_to_functor_.find(func_graph);
  564. if (f != func_graph_to_functor_.end()) {
  565. MS_LOG(DEBUG) << "K graph functor already exist " << func_graph->ToString() << ".";
  566. return NewValueNode(f->second->k_graph_);
  567. }
  568. auto k_user_defined = KUserDefined(func_graph);
  569. if (k_user_defined != nullptr) {
  570. MS_LOG(DEBUG) << "K graph functor user defined bprop " << func_graph->ToString() << ".";
  571. return NewValueNode(k_user_defined);
  572. }
  573. auto functor = std::make_shared<DFunctor>(func_graph, resources_);
  574. functor->Init();
  575. functor->MapObject();
  576. functor->MapMorphism();
  577. MS_LOG(DEBUG) << "Map \"" << func_graph->ToString() << "\" to \"" << functor->k_graph_->ToString() << "\"";
  578. return NewValueNode(functor->k_graph_);
  579. }
  580. // Construct for ValueNode of Parameter.
  581. AnfNodePtr DFunctor::MapParameterToK(const AnfNodePtr &primal) {
  582. if (!primal->isa<Parameter>()) {
  583. MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Parameter.";
  584. }
  585. ScopeGuard scope_guard(primal->scope());
  586. // Map Parameter to K
  587. TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
  588. auto ret = k_graph_->add_parameter();
  589. return ret;
  590. }
  591. bool DFunctor::IsInScope(const AnfNodePtr &node) {
  592. return std::any_of(scope_.begin(), scope_.end(),
  593. [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; });
  594. }
// Creates an adjoint for every free variable of the primal graph. The k comes from the
// parent functor when available; at the top level (or for Parameters) the fv maps to
// itself; otherwise a placeholder adjoint with a null k is recorded.
void DFunctor::MapFvObject() {
  // Map free variable.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  for (auto &node : free_variables_nodes) {
    ScopeGuard scope_guard(node->scope());
    MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << ".";
    // Find fv's K from parent.
    AdjointPtr adjoint = nullptr;
    auto parent_adjoint = FindAdjoint(node);
    if (parent_adjoint != nullptr) {
      adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
    } else {
      if (is_top_ || node->isa<Parameter>()) {
        // Out of ad scope, add adjoint for free variables.
        adjoint = std::make_shared<Adjoint>(node, node, tape_);
        UpdateAdjoint(adjoint);
      } else {
        // Non-top fv with no parent adjoint: record a k hole to be resolved later.
        MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << ".";
        adjoint = std::make_shared<Adjoint>(node, nullptr, tape_);
      }
    }
    // Defensive check; all branches above assign a non-null adjoint.
    if (adjoint == nullptr) {
      MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << ".";
    }
    anfnode_to_adjoin_[node] = adjoint;
  }
}
  622. void DFunctor::MapParamObject() {
  623. // Map parameter.
  624. for (auto &p : primal_graph_->parameters()) {
  625. ScopeGuard scope_guard(p->scope());
  626. MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
  627. auto adjoint = std::make_shared<Adjoint>(p, MapParameterToK(p), tape_);
  628. UpdateAdjoint(adjoint);
  629. anfnode_to_adjoin_[p] = adjoint;
  630. }
  631. }
  632. void DFunctor::MapValueObject() {
  633. // Map ValueNode.
  634. auto manager = resources_->manager();
  635. auto &value_nodes = primal_graph_->value_nodes();
  636. for (const auto &value_pair : value_nodes) {
  637. auto node = value_pair.first;
  638. auto parent_adjoint = FindAdjoint(node);
  639. if (parent_adjoint != nullptr) {
  640. auto adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
  641. anfnode_to_adjoin_[node] = adjoint;
  642. continue;
  643. }
  644. AdjointPtr adjoint = nullptr;
  645. if (IsValueNode<Primitive>(node)) { // Primitive.
  646. if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
  647. continue;
  648. }
  649. MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
  650. auto &users = manager->node_users()[node];
  651. if (users.size() == 0) {
  652. MS_LOG(ERROR) << "\"" << node->DebugString() << "\" has no user.";
  653. continue;
  654. } else if (users.size() > 1) {
  655. MS_LOG(DEBUG) << "\"" << node->DebugString() << "\" supposed to be used once, but users size: " << users.size();
  656. }
  657. auto cnode = users.begin()->first->cast<CNodePtr>(); // We just use the first user.
  658. auto index = users.begin()->second;
  659. adjoint = std::make_shared<Adjoint>(node, MapPrimitiveToK(cnode, index), tape_);
  660. } else if (IsValueNode<FuncGraph>(node)) { // FuncGraph
  661. MS_LOG(DEBUG) << "Map FuncGraph node " << node->DebugString() << ".";
  662. adjoint = std::make_shared<Adjoint>(node, MapFuncGraphToK(node), tape_);
  663. } else if (node->isa<Parameter>()) { // Parameter, hardly reach here.
  664. MS_LOG(DEBUG) << "Map Parameter node " << node->DebugString() << ".";
  665. adjoint = std::make_shared<Adjoint>(node, MapParameterToK(node), tape_);
  666. } else {
  667. adjoint = std::make_shared<Adjoint>(node, node, tape_);
  668. }
  669. UpdateAdjoint(adjoint);
  670. anfnode_to_adjoin_[node] = adjoint;
  671. }
  672. }
// Skip morphism.
// Populate the adjoint tables for every object of the primal graph (free
// variables, formal parameters and value nodes) before CNodes are morphed.
void DFunctor::MapObject() {
  // The order does not matter
  MapFvObject();
  MapParamObject();
  MapValueObject();
}
  680. void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) {
  681. auto primal = adjoint_definition->primal();
  682. if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) {
  683. MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " "
  684. << primal->ToString() << ".";
  685. }
  686. anfnode_to_adjoin_definition_[primal] = adjoint_definition;
  687. // Update k hole for primal.
  688. for (auto &f : func_graph_to_functor_) {
  689. auto adjoint = f.second->anfnode_to_adjoin_.find(primal);
  690. if (adjoint != f.second->anfnode_to_adjoin_.end()) {
  691. adjoint->second->UpdateK(adjoint_definition->k());
  692. }
  693. adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal);
  694. if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) {
  695. adjoint->second->UpdateK(adjoint_definition->k());
  696. }
  697. }
  698. }
  699. AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
  700. auto adjoint = anfnode_to_adjoin_definition_.find(primal);
  701. if (adjoint != anfnode_to_adjoin_definition_.end()) {
  702. MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << ".";
  703. return adjoint->second;
  704. }
  705. MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << ".";
  706. return nullptr;
  707. }
  708. void DFunctor::CallDoutHoleOnTape() {
  709. if (!is_top_) {
  710. return;
  711. }
  712. // Call dout hole of all adjoint.
  713. for (auto &f : func_graph_to_functor_) {
  714. for (auto &adjoint : f.second->anfnode_to_adjoin_) {
  715. adjoint.second->CallDoutHole();
  716. }
  717. for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) {
  718. adjoint.second->CallDoutHole();
  719. }
  720. }
  721. }
// Accessor for the K-transformed (forward) graph built by this functor.
FuncGraphPtr DFunctor::k_graph() { return k_graph_; }
// Accessor for the tape (backward) graph built by this functor.
FuncGraphPtr DFunctor::tape() { return tape_; }
  724. void DFunctor::BroadCastStopFlag() {
  725. // As stop set expanding, all directly or indirectly stopped CNode will be cut off
  726. while (need_cut_) {
  727. need_cut_ = false;
  728. for (auto &node : primal_graph_->nodes()) {
  729. if (node->isa<CNode>()) {
  730. auto cnode = node->cast<CNodePtr>();
  731. if (!cnode->stop_gradient()) {
  732. // Cut off the cnode only when it's not referred any more
  733. if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
  734. AllReferencesStopped(cnode)) {
  735. MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
  736. cnode->set_stop_gradient(true);
  737. // The stop set changed, more cut required
  738. need_cut_ = true;
  739. }
  740. }
  741. }
  742. }
  743. }
  744. }
  745. bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
  746. auto &users = primal_graph_->manager()->node_users()[node];
  747. // Only care about stop_gradient caused cutting
  748. if (users.empty()) {
  749. return false;
  750. }
  751. for (auto &kv : users) {
  752. auto &user = kv.first;
  753. if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
  754. return false;
  755. }
  756. }
  757. return true;
  758. }
// Search the call sites of `primal_graph` for the pair of cnodes produced by
// the grad transform:
//   - primal_user: the cnode that calls the primal graph directly
//     (the graph appears at input index 0, i.e. as the callee);
//   - j_user: the single user of the J(primal_graph) cnode.
// Either element of the returned pair may be nullptr when not found.
static std::pair<CNodePtr, CNodePtr> FindPrimalJPair(const FuncGraphManagerPtr &manager,
const FuncGraphPtr &primal_graph) {
CNodePtr primal_user = nullptr;
CNodePtr j_user = nullptr;
auto &node_user_map = manager->node_users();
// Search primal graph user cnodes.
for (auto &entry : primal_graph->func_graph_cnodes_index()) {
auto cnode = entry.first->first->cast<CNodePtr>();
auto index = entry.first->second;
if (index == 0) {
// To find real calling.
primal_user = cnode;
} else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
// To find J user.
auto it = node_user_map.find(cnode);
if (it == node_user_map.end()) {
MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}";
}
auto &j_users = it->second;
auto size = j_users.size();
// The J call is expected to have exactly one user.
if (size != 1) {
MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}";
}
j_user = j_users.begin()->first->cast<CNodePtr>();
}
// Both found: stop searching early.
if (j_user != nullptr && primal_user != nullptr) {
break;
}
}
return {primal_user, j_user};
}
  790. static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
  791. auto &node_users = manager->node_users();
  792. auto iter = node_users.find(primal_call);
  793. if (iter == node_users.end()) {
  794. // Skip if user of primal_call not found.
  795. return;
  796. }
  797. // Find UpdateState nodes after the primal call.
  798. std::vector<CNodePtr> update_states;
  799. for (auto &user : iter->second) {
  800. auto &user_node = user.first;
  801. if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
  802. update_states.emplace_back(user_node->cast<CNodePtr>());
  803. }
  804. }
  805. // Remove UpdateStates by replace them with their monad input.
  806. for (auto &update_state : update_states) {
  807. auto &input_monad = update_state->inputs().at(1);
  808. manager->Replace(update_state, input_monad);
  809. }
  810. }
  811. static bool CopyMonadArguments(const CNodePtr &primal_user, const CNodePtr &j_user) {
  812. auto &primal_inputs = primal_user->inputs();
  813. auto &j_user_inputs = j_user->inputs();
  814. bool has_monad = false;
  815. for (size_t i = 1; i < primal_inputs.size(); ++i) {
  816. auto &input = primal_inputs.at(i);
  817. if (HasAbstractMonad(input)) {
  818. // Copy monad input from primal to j_user.
  819. j_user->set_input(i, input);
  820. has_monad = true;
  821. } else if (input != j_user_inputs.at(i)) {
  822. // Skip if there are different non-monad inputs.
  823. return false;
  824. }
  825. }
  826. return has_monad;
  827. }
//
// To replace the primal graph with k graph.
// Convert:
//   x = primal(args, u0)
//   u1 = update_state(u0, x)
//   ...
//   tuple = K(args, u1)
//   u2 = update_state(u1, tuple)
//   ...
// To:
//   tuple = K(args, u0)
//   x = get_item(tuple, 0)
//   ...
//   tuple = K(args, u0)
//   u2 = update_state(u0, tuple)
//   ...
//
void DFunctor::EliminatePrimalGraph() {
// Find primal user and paired J user cnodes.
auto manager = primal_graph_->manager();
MS_EXCEPTION_IF_NULL(manager);
auto [primal_user, j_user] = FindPrimalJPair(manager, primal_graph_);
if (primal_user == nullptr || j_user == nullptr) {
// Skip if one of them not found.
return;
}
// Check input size: the primal call and the J-user call must take the same
// arguments for the replacement to be valid.
if (primal_user->size() != j_user->size()) {
MS_LOG(WARNING) << "Input size incorrect, primal:" << primal_user->DebugString()
<< " juser:" << j_user->DebugString();
return;
}
// Replace primal graph with k graph: retarget the primal call's callee to
// k_graph_ and give the call the J user's (tuple) abstract.
auto k_vnode = NewValueNode(k_graph_);
auto primal_abs = primal_user->abstract();
primal_user->set_input(0, k_vnode);
primal_user->set_abstract(j_user->abstract());
// If both inputs are same except monads, we copy primal monad args to k graph
// so that they can be combined in CSE (common subexpression elimination) pass.
const bool has_monad = CopyMonadArguments(primal_user, j_user);
// Remove the UpdateState nodes after primal_user if need.
if (has_monad) {
RemovePrimalUpdateStates(manager, primal_user);
}
// Insert tuple_getitem after primal user cnode to recover the original
// forward output (element 0 of the K-graph result tuple) and keep the
// primal call's original abstract on it.
auto construct_wrapper = primal_user->func_graph();
auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);
auto imm0 = std::make_shared<Int64Imm>(0);
auto idx0 = NewValueNode(SizeToLong(0));
idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0});
getitem0->set_abstract(primal_abs);
manager->Replace(primal_user, getitem0);
}
  882. } // namespace ad
  883. } // namespace mindspore