You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

dfunctor.cc 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/ad/dfunctor.h"
  17. #include <map>
  18. #include <memory>
  19. #include <string>
  20. #include "ir/anf.h"
  21. #include "utils/info.h"
  22. #include "ir/func_graph_cloner.h"
  23. #include "ir/manager.h"
  24. #include "pipeline/jit/resource.h"
  25. #include "pipeline/pynative/pynative_execute.h"
  26. #include "frontend/optimizer/ad/adjoint.h"
  27. #include "frontend/operator/ops.h"
  28. #include "utils/symbolic.h"
  29. #include "utils/ms_context.h"
  30. #include "pipeline/jit/action.h"
  31. #include "pipeline/jit/parse/resolve.h"
  32. namespace mindspore {
  33. namespace ad {
// Cache of the DFunctor created for each primal graph; filled by Init(), looked
// up by MapToK/BackPropagate, and reset by Clear().
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
// Global map from a primal node to its adjoint definition; used by FindAdjoint
// to resolve free variables across functors.
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
// Scope of the top functor's primal graph (set in Init when is_top); queried by IsInScope.
FuncGraphSet DFunctor::scope_;
  37. DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
  38. : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
  39. TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
  40. k_graph_ = std::make_shared<FuncGraph>();
  41. if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  42. std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  43. k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
  44. }
  45. // To keep switch_layer's inputs from being inlined
  46. k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
  47. TraceManager::EndTrace();
  48. TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
  49. tape_ = std::make_shared<FuncGraph>();
  50. // Add "_Grad" postfix
  51. if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  52. std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad";
  53. tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
  54. }
  55. TraceManager::EndTrace();
  56. dout_ = tape_->add_parameter();
  57. }
  58. void DFunctor::Init(bool is_top) {
  59. func_graph_to_functor_[primal_graph_] = shared_from_this();
  60. is_top_ = is_top;
  61. if (is_top) {
  62. scope_ = primal_graph_->scope();
  63. }
  64. }
// Finalizes the functor: first fills all dout holes on the tape, then runs
// EliminatePrimalGraph (defined elsewhere in this file). The order matters.
void DFunctor::Finish() {
  CallDoutHoleOnTape();
  EliminatePrimalGraph();
}
  69. void DFunctor::Clear() {
  70. func_graph_to_functor_.clear();
  71. anfnode_to_adjoin_definition_.clear();
  72. scope_.clear();
  73. }
// Accumulates the gradient carried by env `din` into the adjoint of free
// variable `fv`: dfv = env_getitem(din, embed(k(fv)), zeros_like(k(fv))).
// When no adjoint is known yet, a k-hole adjoint is created and cached in
// anfnode_to_adjoin_indirect_fv_ so UpdateAdjoint can patch it later.
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
  auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  if (fv_adjoint == anfnode_to_adjoin_.end()) {
    MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
                  << " " << fv->ToString() << ".";
    // Fall back to the indirect-fv cache.
    fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
    if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
      MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
                    << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
      // Try the global adjoint definitions; if absent, register a k hole.
      auto parent_adjoint = FindAdjoint(fv);
      AdjointPtr adjoint = nullptr;
      if (parent_adjoint != nullptr) {
        adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
      } else {
        MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
        adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
      }
      anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
      fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
    }
  }
  auto fv_node = fv_adjoint->second->k();
  // Reuse one embed/zeros_like pair per k node so each is built and registered
  // as a k user only once.
  auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
  CNodePtr embed_node, default_val_node;
  if (cached_envitem_iter != anfnode_to_envitem_.end()) {
    embed_node = cached_envitem_iter->second.first;
    default_val_node = cached_envitem_iter->second.second;
  } else {
    embed_node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_node});
    default_val_node = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_node});
    fv_adjoint->second->RegisterKUser(embed_node, 1);
    fv_adjoint->second->RegisterKUser(default_val_node, 1);
    anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node);
  }
  auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node});
  MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
                << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << ".";
  fv_adjoint->second->AccumulateDout(dfv);
}
  115. void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
  116. // Take switch_layer as a set of candidate functions.
  117. auto input = cnode_morph->input(2);
  118. if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
  119. MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
  120. }
  121. auto tuple_graphs = input->cast<CNodePtr>();
  122. for (size_t i = 1; i < tuple_graphs->size(); ++i) {
  123. auto graph = tuple_graphs->input(i);
  124. if (!IsValueNode<FuncGraph>(graph)) {
  125. MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString()
  126. << " as the " << i << "th element.";
  127. }
  128. auto func_graph = GetValueNode<FuncGraphPtr>(graph);
  129. auto functor = func_graph_to_functor_.find(func_graph);
  130. if (functor == func_graph_to_functor_.end()) {
  131. MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] "
  132. << func_graph->ToString() << ".";
  133. }
  134. // Consider direct and indirect fvs.
  135. for (auto fv : func_graph->free_variables_nodes()) {
  136. BackPropagateFv(fv, env);
  137. }
  138. for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
  139. MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
  140. << indirect_fv.first->ToString() << ".";
  141. BackPropagateFv(indirect_fv.first, env);
  142. }
  143. }
  144. }
  145. void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
  146. auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
  147. // Call with delimited continuation dout.
  148. auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
  149. node_adjoint->RegisterDoutUser(bprop_app, 1);
  150. // Special case for switch_layer
  151. if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
  152. auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)});
  153. BackPropagateSwitchLayer(cnode_morph, din);
  154. return;
  155. }
  156. for (size_t i = 0; i < cnode_morph->size(); i++) {
  157. auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
  158. auto input = cnode_morph->input(i);
  159. // Backprop sens wrt fvs.
  160. if (IsValueNode<FuncGraph>(input)) {
  161. auto func_graph = GetValueNode<FuncGraphPtr>(input);
  162. auto functor = func_graph_to_functor_.find(func_graph);
  163. if (functor == func_graph_to_functor_.end()) {
  164. MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] "
  165. << func_graph->ToString() << ".";
  166. }
  167. // Consider direct and indirect fvs.
  168. for (auto fv : func_graph->free_variables_nodes()) {
  169. BackPropagateFv(fv, din);
  170. }
  171. for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
  172. MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " "
  173. << indirect_fv.first->ToString() << ".";
  174. BackPropagateFv(indirect_fv.first, din);
  175. }
  176. continue;
  177. }
  178. // Backprop sens wrt inputs.
  179. auto input_adjoint = anfnode_to_adjoin_.find(input);
  180. if (input_adjoint == anfnode_to_adjoin_.end()) {
  181. MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << ".";
  182. }
  183. input_adjoint->second->AccumulateDout(din);
  184. }
  185. }
// Map a morphism.
// Maps a CNode of the primal graph to K: recursively maps its inputs, builds
// the k application, splits off the forward value (element 0), and creates the
// adjoint for `morph`. Returns nullptr for non-CNodes (already handled by
// MapObject) and the cached adjoint for nodes mapped earlier.
AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
  MS_LOG(DEBUG) << "start MapMorphism:" << morph->DebugString(4);
  // MapMorphism All type except CNode should already be mapped by MapObject.
  if (!morph->isa<CNode>()) {
    return nullptr;
  }
  // for free variable, which may be handled in MapValueObject, just return it
  auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
  if (node_adjoint_found != anfnode_to_adjoin_.end()) {
    return node_adjoint_found->second;
  }
  ScopeGuard scope_guard(morph->scope());
  auto cnode_morph = morph->cast<CNodePtr>();
  std::vector<AnfNodePtr> inputs;
  std::vector<AdjointPtr> param_adjoints;
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto node = cnode_morph->input(i);
    auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
    AdjointPtr node_adjoint = nullptr;
    AnfNodePtr k = nullptr;
    if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
      node_adjoint = node_adjoint_iter->second;
    } else {
      // Input might be a CNode that needs to be handled before hand.
      node_adjoint = MapMorphism(node);
    }
    MS_EXCEPTION_IF_NULL(node_adjoint);
    k = node_adjoint->k();
    if (k == nullptr) {
      MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
    }
    inputs.push_back(k);
    param_adjoints.push_back(node_adjoint);
  }
  TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
  auto k_app = k_graph_->NewCNode(inputs);
  TraceManager::EndTrace();
  // Replace the equivalent dout with the recorded forward value (pynative path),
  // then drop the recorded forward so it is not applied twice.
  ReplaceEquivdout(k_app, cnode_morph);
  cnode_morph->set_forward(nullptr, "");
  for (size_t i = 0; i < param_adjoints.size(); ++i) {
    param_adjoints[i]->RegisterKUser(k_app, i);
  }
  // Do forward computation
  auto foward_app = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(0)});
  // K:: cnode -> forward_app
  auto node_adjoint = std::make_shared<Adjoint>(morph, foward_app, tape_);
  UpdateAdjoint(node_adjoint);
  anfnode_to_adjoin_[morph] = node_adjoint;
  // stop_gradient nodes keep their forward mapping but skip backpropagation.
  if (cnode_morph->stop_gradient()) {
    MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
    return node_adjoint;
  }
  // Do sens backpropagation
  BackPropagate(cnode_morph, k_app, node_adjoint);
  MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << ".";
  return node_adjoint;
}
  244. void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) {
  245. MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>();
  246. if (value->isa<tensor::Tensor>()) {
  247. auto tnode = value->cast<tensor::TensorPtr>();
  248. if (tuple_tensors->find(tnode->id()) != tuple_tensors->end()) {
  249. MS_LOG(DEBUG) << "Set tensor" << tnode->device_address();
  250. (*tuple_tensors)[tnode->id()]->set_device_address(tnode->device_address());
  251. }
  252. }
  253. if (value->isa<ValueTuple>()) {
  254. auto tuple = value->cast<ValueTuplePtr>();
  255. for (size_t i = 0; i < tuple->size(); i++) {
  256. MS_LOG(DEBUG) << "Set tuple tensor" << (*tuple)[i]->ToString();
  257. TensorSetAddress((*tuple)[i], tuple_tensors);
  258. }
  259. }
  260. }
  261. ValuePtr GenNewTensorInner(const ValuePtr &value) {
  262. std::vector<ValuePtr> value_list;
  263. if (value->isa<tensor::Tensor>()) {
  264. auto tensor = value->cast<tensor::TensorPtr>();
  265. auto new_tensor = std::make_shared<tensor::Tensor>(*tensor);
  266. new_tensor->set_device_address(nullptr);
  267. return new_tensor;
  268. }
  269. if (value->isa<ValueTuple>()) {
  270. auto tuple = value->cast<ValueTuplePtr>();
  271. for (size_t i = 0; i < tuple->size(); i++) {
  272. value_list.push_back(GenNewTensorInner((*tuple)[i]));
  273. }
  274. return std::make_shared<ValueTuple>(value_list);
  275. }
  276. return value;
  277. }
  278. ValuePtr GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, const ValuePtr &value) {
  279. ValuePtr out = value;
  280. auto ref_size = mng->node_users()[node].size();
  281. if (ref_size < 2) {
  282. out = GenNewTensorInner(value);
  283. }
  284. return out;
  285. }
// Pynative path: replaces the equivalent-dout node and the graph parameters of
// the forward subgraph with the concrete values recorded on `cnode_morph`, then
// runs PynativeElimOpt and restores device addresses on cached tuple tensors.
// No-op unless a forward value was recorded and the cnode has the expected
// (getitem-of-cnode over a subgraph) output shape.
void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
  auto forward = cnode_morph->forward().first;
  auto forward_id = cnode_morph->forward().second;
  if (forward == nullptr) {
    return;
  }
  auto &input = cnode->input(0);
  if (!IsValueNode<FuncGraph>(input)) {
    return;
  }
  auto fg = GetValueNode<FuncGraphPtr>(input);
  auto output = fg->output();
  if (!output->isa<CNode>()) {
    return;
  }
  auto cnode_output = output->cast<CNodePtr>();
  auto &cnode_input = cnode_output->input(1);
  if (!cnode_input->isa<CNode>()) {
    return;
  }
  auto &input_fg = cnode_output->input(2);
  if (!IsValueNode<FuncGraph>(input_fg)) {
    return;
  }
  // Tensors saved by SaveOpForwardValue, keyed by tensor id; their device
  // addresses are re-applied at the end.
  std::map<std::string, tensor::TensorPtr> tuple_tensors;
  auto equivdout = cnode_input->cast<CNodePtr>();
  auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
  auto manager = Manage({fg, func_graph}, false);
  auto ref_size = manager->node_users()[equivdout].size();
  auto forward_value = forward;
  // Multiply-referenced forward values are cached by id before replacement.
  if (!forward_id.empty() && ref_size > 1) {
    auto inst = pynative::PynativeExecutor::GetInstance();
    inst->SaveOpForwardValue(forward_id, forward_value, &tuple_tensors);
  }
  forward_value = GenNewTensor(manager, equivdout, forward);
  MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
  auto value_node = NewValueNode(forward_value);
  value_node->set_has_new_value(true);
  manager->Replace(equivdout, value_node);
  auto paras = fg->parameters();
  auto inputs_value = cnode_morph->inputs_value();
  if (inputs_value.size() == 0) {
    return;
  }
  if (inputs_value.size() != paras.size()) {
    MS_LOG(EXCEPTION) << "Parameter size:" << paras.size() << " is not equal to inputs size:" << inputs_value.size();
  }
  // Substitute each still-referenced parameter with its recorded input value.
  for (size_t i = 0; i < paras.size(); i++) {
    auto para_ref_size = manager->node_users()[paras[i]].size();
    auto input_value = inputs_value[i];
    if (para_ref_size > 0 && input_value.first != nullptr) {
      MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
      auto inst = pynative::PynativeExecutor::GetInstance();
      if (!input_value.second.empty()) {
        inst->SaveOpForwardValue(input_value.second, input_value.first, &tuple_tensors);
      }
      auto input_value_node = NewValueNode(input_value.first);
      input_value_node->set_has_new_value(true);
      manager->Replace(paras[i], input_value_node);
    }
  }
  MS_LOG(DEBUG) << "Start opt node" << fg->output()->DebugString(4);
  auto res = std::make_shared<pipeline::Resource>();
  res->set_manager(manager);
  res->set_func_graph(fg);
  PynativeElimOpt(res);
  // NOTE(review): the cast below is unchecked — if the optimized output is not
  // a CNode, out->input(1) dereferences null. Presumably PynativeElimOpt
  // guarantees a CNode output here; confirm.
  auto out = fg->output()->cast<CNodePtr>();
  auto c_input = out->input(1);
  if (!c_input->isa<ValueNode>()) {
    return;
  }
  auto out_node = c_input->cast<ValueNodePtr>();
  out_node->set_value(GenNewTensor(manager, out_node, out_node->value()));
  cnode_morph->clear_inputs_value();
  // Restore device addresses onto every value node that holds a cached tensor.
  if (tuple_tensors.size() != 0) {
    MS_LOG(DEBUG) << "Start tuple out" << fg->output()->DebugString(4);
    for (auto &g : manager->func_graphs()) {
      for (auto &node : g->value_nodes()) {
        MS_LOG(DEBUG) << "Set Tensor addr" << node.first->ToString();
        auto vnode = node.first->cast<ValueNodePtr>()->value();
        TensorSetAddress(vnode, &tuple_tensors);
      }
    }
  }
  fg->ClearAllManagerInfo();
  func_graph->ClearAllManagerInfo();
  return;
}
  374. bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
  375. // Do not care about non-CNode
  376. if (!node->isa<CNode>()) {
  377. return false;
  378. }
  379. // Do not care about kPrimReturn
  380. if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
  381. return false;
  382. }
  383. auto &users = primal_graph_->manager()->node_users()[node];
  384. // Do not care about isolated morphisms
  385. if (users.empty()) {
  386. return false;
  387. }
  388. // Not free if it's used by some node in primal_graph
  389. bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
  390. auto &user = kv.first;
  391. return user->func_graph() == primal_graph_;
  392. });
  393. return !nonfree;
  394. }
  395. void DFunctor::MapFreeMorphism() {
  396. // Handle cnode not attached to output, that might be refered in other functions.
  397. for (auto &node : primal_graph_->nodes()) {
  398. if (!IsFreeMorphism(node)) {
  399. continue;
  400. }
  401. MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
  402. (void)MapMorphism(node);
  403. }
  404. }
// Threads the gradients of the primal graph's free variables into the env the
// tape returns: new_env = env_setitem(env, embed(k(fv)), dout(fv)) per fv.
// Returns the final env node.
AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
  AnfNodePtr new_grad_fv = grad_fv;
  // Add grads wrt fv.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  for (auto &fv : free_variables_nodes) {
    auto fv_adjoint = anfnode_to_adjoin_.find(fv);
    if (fv_adjoint == anfnode_to_adjoin_.end()) {
      MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
    }
    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
    fv_adjoint->second->RegisterKUser(node, 1);
    auto sens = fv_adjoint->second->dout();
    new_grad_fv = tape_->NewCNode({
      NewValueNode(prim::kPrimEnvSetItem),
      new_grad_fv,
      node,
      sens,
    });
    // The sens occupies input index 3 of the env_setitem node.
    fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), 3);
    MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " "
                  << fv->ToString() << " " << primal_graph_->ToString() << ".";
  }
  return new_grad_fv;
}
// Same env-threading as AttachFvDoutToTape, but for the indirect free variables
// collected in anfnode_to_adjoin_indirect_fv_ during backpropagation.
AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
  AnfNodePtr new_grad_fv = grad_fv;
  // Add indirect fv bprop.
  for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {
    MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " "
                  << primal_graph_->ToString() << ".";
    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()});
    fv_adjoint.second->RegisterKUser(node, 1);
    auto sens = fv_adjoint.second->dout();
    new_grad_fv = tape_->NewCNode({
      NewValueNode(prim::kPrimEnvSetItem),
      new_grad_fv,
      node,
      sens,
    });
    // The sens occupies input index 3 of the env_setitem node.
    fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), 3);
    MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to "
                  << new_grad_fv->ToString() << ".";
  }
  return new_grad_fv;
}
// Top-level mapping: maps all morphisms of the primal graph, then wires up the
// outputs of both the tape (env + per-parameter grads) and the K graph
// (forward value + tape closure), and cross-links primal and grad graphs.
void DFunctor::MapMorphism() {
  // Set stop_gradient before MapMorphism.
  BroadCastStopFlag();
  // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
  MapFreeMorphism();
  // Handle morphism from output.
  (void)MapMorphism(primal_graph_->output());
  // Construct K for primal_graph_
  // NOTE(review): this find is assumed to always succeed after MapMorphism on
  // the output above; it is not checked — confirm.
  auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
  // Attach dout_ parameter to output_adjoint.
  output_adjoint->second->AccumulateDout(dout_);
  // Set output for tape closure.
  auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
  // Add grads wrt inputs.
  std::vector<AdjointPtr> param_adjoints;
  for (auto &param : primal_graph_->parameters()) {
    // NOTE(review): also an unchecked find; parameters are mapped in MapParamObject.
    auto param_adjoint = anfnode_to_adjoin_.find(param);
    inputs.push_back(param_adjoint->second->dout());
    param_adjoints.push_back(param_adjoint->second);
  }
  auto tape_output = tape_->NewCNode(inputs);
  // Dout users start at input index 2: index 0 is make_tuple, index 1 is grad_fv.
  for (size_t i = 0; i < param_adjoints.size(); ++i) {
    param_adjoints[i]->RegisterDoutUser(tape_output, i + 2);
  }
  tape_->set_output(tape_output);
  // Set output for k_graph_, K:: cnode->forward_app.
  auto forward_app = output_adjoint->second->k();
  auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)});
  output_adjoint->second->RegisterKUser(output, 1);
  k_graph_->set_output(output);
  // Cross-link the graphs so later passes can map between primal and grad.
  (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_)));
  (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_)));
}
// K for a user-defined cell bprop: when the primal carries a "bprop" transform,
// expand it via g_k_prims and cache the result as the primal's "grad"
// transform. Returns nullptr when the primal has no user-defined bprop.
FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
  // K user defined cell bprop.
  auto bprop = primal->transforms().find("bprop");
  if (bprop != primal->transforms().end()) {
    FuncGraphPtr bprop_graph = bprop->second.func_graph();
    resources_->manager()->AddFuncGraph(bprop_graph);
    // Free variables are rejected: a user-defined bprop supplies no Parameter grads.
    if (!bprop_graph->free_variables_nodes().empty() || !primal->free_variables_nodes().empty()) {
      MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
                        << primal->output()->scope()->name() << " does not support Parameter data type.";
    }
    auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph);
    if (fg == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope "
                        << primal->output()->scope()->name() << ".";
    }
    // Cache the grad func
    (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg)));
    (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
    // Reset defer_inline to enable successive inlining
    primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
    // Register a functor whose k_graph_ is the expanded bprop so later lookups hit the cache.
    auto functor = std::make_shared<DFunctor>(primal, resources_);
    functor->Init();
    functor->k_graph_ = fg;
    return fg;
  }
  return nullptr;
}
  511. // MapToK(func)
  512. AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
  513. auto f = func_graph_to_functor_.find(primal);
  514. if (f != func_graph_to_functor_.end()) {
  515. MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << ".";
  516. return NewValueNode(f->second->k_graph_);
  517. }
  518. auto k_user_defined = KUserDefined(primal);
  519. if (k_user_defined != nullptr) {
  520. MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << ".";
  521. return NewValueNode(k_user_defined);
  522. }
  523. auto functor = std::make_shared<DFunctor>(primal, resources_);
  524. functor->Init();
  525. functor->MapObject();
  526. functor->MapMorphism();
  527. MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << ".";
  528. return NewValueNode(functor->k_graph_);
  529. }
// Construct representation graph for given node.
// Dispatches on the node kind: primitives map through g_k_prims (k_prim, then
// k_meta), FuncGraph values recurse into MapToK(FuncGraphPtr), Parameters get a
// fresh K-graph parameter, and remaining ValueNodes map to themselves.
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
  ScopeGuard scope_guard(primal->scope());
  // MapToK(prim)
  if (IsValueNode<Primitive>(primal)) {
    auto value_node = primal->cast<ValueNodePtr>();
    auto prim = GetValueNode<PrimitivePtr>(value_node);
    // Seeing stop_gradient marks this graph for cutting (see BroadCastStopFlag).
    if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
      MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
      need_cut_ = true;
    }
    auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
    if (k_prim != nullptr) {
      return NewValueNode(k_prim);
    }
    // When failed to find k_prim, try k_meta.
    auto k_meta = g_k_prims.KMetaFuncGraph(prim);
    if (k_meta != nullptr) {
      return NewValueNode(k_meta);
    }
  }
  // MapToK(func)
  if (IsValueNode<FuncGraph>(primal)) {
    auto func_graph = GetValueNode<FuncGraphPtr>(primal);
    auto k_func = MapToK(func_graph);
    return k_func;
  }
  // A primal Parameter becomes a fresh parameter of the K graph.
  if (primal->isa<Parameter>()) {
    TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal->debug_info()));
    auto ret = k_graph_->add_parameter();
    TraceManager::EndTrace();
    return ret;
  }
  if (!primal->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode.";
  }
  return primal;
}
  568. bool DFunctor::IsInScope(const AnfNodePtr &node) {
  569. return std::any_of(scope_.begin(), scope_.end(),
  570. [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; });
  571. }
// Creates an adjoint for every free variable of the primal graph, reusing the
// parent's K when available, otherwise mapping the fv to itself (top level /
// Parameter) or leaving a k hole to be patched by UpdateAdjoint.
void DFunctor::MapFvObject() {
  // Map free variable.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  for (auto &node : free_variables_nodes) {
    ScopeGuard scope_guard(node->scope());
    MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << ".";
    // Find fv's K from parent.
    AdjointPtr adjoint = nullptr;
    auto parent_adjoint = FindAdjoint(node);
    if (parent_adjoint != nullptr) {
      adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
    } else {
      if (is_top_ || node->isa<Parameter>()) {
        // Out of ad scope, add adjoint for free variables.
        adjoint = std::make_shared<Adjoint>(node, node, tape_);
        UpdateAdjoint(adjoint);
      } else {
        // Parent has not defined this fv yet; leave a k hole for later patching.
        MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << ".";
        adjoint = std::make_shared<Adjoint>(node, nullptr, tape_);
      }
    }
    if (adjoint == nullptr) {
      MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << ".";
    }
    anfnode_to_adjoin_[node] = adjoint;
  }
}
  599. void DFunctor::MapParamObject() {
  600. // Map parameter.
  601. for (auto &p : primal_graph_->parameters()) {
  602. ScopeGuard scope_guard(p->scope());
  603. MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
  604. auto adjoint = std::make_shared<Adjoint>(p, MapToK(p), tape_);
  605. UpdateAdjoint(adjoint);
  606. anfnode_to_adjoin_[p] = adjoint;
  607. }
  608. }
  609. void DFunctor::MapValueObject() {
  610. // Map ValueNode.
  611. auto manager = resources_->manager();
  612. auto &value_nodes = primal_graph_->value_nodes();
  613. for (const auto &value_pair : value_nodes) {
  614. auto node = value_pair.first;
  615. auto parent_adjoint = FindAdjoint(node);
  616. if (parent_adjoint != nullptr) {
  617. auto adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
  618. anfnode_to_adjoin_[node] = adjoint;
  619. continue;
  620. }
  621. // Skip Return.
  622. if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
  623. continue;
  624. }
  625. MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << ".";
  626. auto adjoint = std::make_shared<Adjoint>(node, MapToK(node), tape_);
  627. UpdateAdjoint(adjoint);
  628. anfnode_to_adjoin_[node] = adjoint;
  629. }
  630. }
// Skip morphism.
// Maps all non-CNode objects (free variables, parameters, value nodes) to K
// before any morphism is mapped.
void DFunctor::MapObject() {
  // The order does not matter
  MapFvObject();
  MapParamObject();
  MapValueObject();
}
  638. void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) {
  639. auto primal = adjoint_definition->primal();
  640. if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) {
  641. MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " "
  642. << primal->ToString() << ".";
  643. }
  644. anfnode_to_adjoin_definition_[primal] = adjoint_definition;
  645. // Update k hole for primal.
  646. for (auto &f : func_graph_to_functor_) {
  647. auto adjoint = f.second->anfnode_to_adjoin_.find(primal);
  648. if (adjoint != f.second->anfnode_to_adjoin_.end()) {
  649. adjoint->second->UpdateK(adjoint_definition->k());
  650. }
  651. adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal);
  652. if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) {
  653. adjoint->second->UpdateK(adjoint_definition->k());
  654. }
  655. }
  656. }
  657. AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
  658. auto adjoint = anfnode_to_adjoin_definition_.find(primal);
  659. if (adjoint != anfnode_to_adjoin_definition_.end()) {
  660. MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << ".";
  661. return adjoint->second;
  662. }
  663. MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << ".";
  664. return nullptr;
  665. }
  666. void DFunctor::CallDoutHoleOnTape() {
  667. if (!is_top_) {
  668. return;
  669. }
  670. // Call dout hole of all adjoint.
  671. for (auto &f : func_graph_to_functor_) {
  672. for (auto &adjoint : f.second->anfnode_to_adjoin_) {
  673. adjoint.second->CallDoutHole();
  674. }
  675. for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) {
  676. adjoint.second->CallDoutHole();
  677. }
  678. }
  679. }
// Accessor for the transformed graph (k_graph_) built by this functor.
FuncGraphPtr DFunctor::k_graph() { return k_graph_; }
  681. void DFunctor::BroadCastStopFlag() {
  682. // As stop set expanding, all directly or indirectly stopped CNode will be cut off
  683. while (need_cut_) {
  684. need_cut_ = false;
  685. for (auto &node : primal_graph_->nodes()) {
  686. if (node->isa<CNode>()) {
  687. auto cnode = node->cast<CNodePtr>();
  688. if (!cnode->stop_gradient()) {
  689. // Cut off the cnode only when it's not referred any more
  690. if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || AllReferencesStopped(cnode)) {
  691. MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
  692. cnode->set_stop_gradient(true);
  693. // The stop set changed, more cut required
  694. need_cut_ = true;
  695. }
  696. }
  697. }
  698. }
  699. }
  700. }
  701. bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
  702. auto &users = primal_graph_->manager()->node_users()[node];
  703. // Only care about stop_gradient caused cutting
  704. if (users.empty()) {
  705. return false;
  706. }
  707. for (auto &kv : users) {
  708. auto &user = kv.first;
  709. if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
  710. return false;
  711. }
  712. }
  713. return true;
  714. }
  715. // To replace the primal graph with k graph
  716. void DFunctor::EliminatePrimalGraph() {
  717. auto k_vnode = NewValueNode(k_graph_);
  718. auto idx0 = NewValueNode(SizeToInt(0));
  719. auto imm0 = std::make_shared<Int32Imm>(0);
  720. idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
  721. auto manager = primal_graph_->manager();
  722. auto users = primal_graph_->func_graph_cnodes_index();
  723. for (auto &it : users) {
  724. auto cnode = it.first->first->cast<CNodePtr>();
  725. auto index = it.first->second;
  726. auto vnode = cnode->inputs()[index];
  727. if (index != 0) {
  728. MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}";
  729. continue;
  730. }
  731. cnode->set_input(0, k_vnode); // Replace primal graph with k graph
  732. auto construct_wrapper = cnode->func_graph();
  733. auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0});
  734. manager->Replace(cnode, getitem0);
  735. }
  736. }
  737. } // namespace ad
  738. } // namespace mindspore