You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kprim.cc 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include <memory>
  19. #include <string>
  20. #include <utility>
  21. #include "ir/anf.h"
  22. #include "ir/primitive.h"
  23. #include "ir/meta_func_graph.h"
  24. #include "ir/func_graph_cloner.h"
  25. #include "ir/manager.h"
  26. #include "pipeline/resource.h"
  27. #include "pipeline/parse/parse.h"
  28. #include "optimizer/ad/dfunctor.h"
  29. #include "optimizer/opt.h"
  30. #include "operator/ops.h"
  31. #include "operator/composite/composite.h"
  32. #include "utils/symbolic.h"
  33. #include "utils/primitive_utils.h"
  34. #include "utils/context/ms_context.h"
  35. #include "debug/info.h"
  36. #include "debug/trace.h"
  37. #include "./common.h"
  38. namespace mindspore {
  39. namespace ad {
  40. using PatternListType = std::initializer_list<BaseRef>;
  41. KPrim g_k_prims;
  42. FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
  43. // Set a child scope named "grad'PrimitiveName'" for the bprop function,
  44. // and add "Gradients" to the front.
  45. static const std::string gradients_scope = "Gradients/";
  46. static const std::string grad_op_child_scope_prefix = "/grad";
  47. MS_EXCEPTION_IF_NULL(prim);
  48. auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
  49. grad_op_child_scope_prefix + prim->name());
  50. ScopeGuard scope_guard(scope);
  51. py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast<PrimitivePyPtr>()->GetBpropFunction();
  52. if (fn == nullptr || py::isinstance<py::none>(fn)) {
  53. MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
  54. return nullptr;
  55. }
  56. FuncGraphPtr func_graph = parse::ParsePythonCode(fn);
  57. if (func_graph == nullptr) {
  58. MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
  59. return nullptr;
  60. }
  61. return func_graph;
  62. }
  63. FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
  64. static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
  65. std::string func_name = "_fprop_" + prim->name();
  66. py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
  67. auto func_graph = parse::ParsePythonCode(fn);
  68. MS_EXCEPTION_IF_NULL(func_graph);
  69. return BasicClone(func_graph);
  70. }
  71. MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
  72. MS_EXCEPTION_IF_NULL(prim);
  73. auto iter = bprop_registry_meta_.find(prim);
  74. if (iter != bprop_registry_meta_.end()) {
  75. return iter->second;
  76. }
  77. if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
  78. MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
  79. bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
  80. return meta;
  81. }
  82. MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
  83. }
  84. FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  85. if (!IsValueNode<Primitive>(value_node)) {
  86. MS_LOG(EXCEPTION) << "Primitive node is not valid.";
  87. }
  88. auto prim = GetValueNode<PrimitivePtr>(value_node);
  89. if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
  90. auto fprop = GetFprop(prim);
  91. fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
  92. return fprop;
  93. } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
  94. return nullptr;
  95. }
  96. FuncGraphPtr bprop_fg = nullptr;
  97. if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
  98. bprop_fg = BpropCut(value_node, resources);
  99. } else {
  100. auto iter = bprop_registry_.find(prim);
  101. if (iter != bprop_registry_.end()) {
  102. bprop_fg = iter->second;
  103. }
  104. if (bprop_fg == nullptr) {
  105. bprop_fg = GetBprop(prim);
  106. if (bprop_fg != nullptr) {
  107. // Set bprop_g graph cache
  108. bprop_registry_[prim] = bprop_fg;
  109. } else {
  110. bprop_fg = FakeBprop(value_node, resources);
  111. }
  112. }
  113. }
  114. auto expanded_fg = BpropToK(prim, bprop_fg);
  115. if (expanded_fg == nullptr) {
  116. MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
  117. << " prim bprop function to J expanded func graph. NodeInfo: "
  118. << trace::GetDebugInfo(bprop_fg->debug_info());
  119. }
  120. return expanded_fg;
  121. }
  122. AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg) {
  123. // bprop_fg has been checked in caller
  124. if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
  125. // Set bprop output as (env, dx, dy, dz, ...)
  126. auto cbprop = bprop_fg->output()->cast<CNodePtr>();
  127. auto &inputs = cbprop->inputs();
  128. std::vector<AnfNodePtr> args;
  129. args.push_back(NewValueNode(prim::kPrimMakeTuple));
  130. args.push_back(NewValueNode(newenv));
  131. (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
  132. return NewCNode(args, bprop_fg);
  133. }
  134. // Set bprop output as (env, dx)
  135. std::string model_name("mindspore.ops.composite.multitype_ops.add_impl");
  136. std::string python_ops("_tuple_add");
  137. auto tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg);
  138. return NewCNode({NewValueNode(prim::GetPythonOps(python_ops, model_name)), tuple, bprop_fg->output()}, bprop_fg);
  139. }
  140. void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
  141. std::vector<AnfNodePtr> *const transf_args) {
  142. MS_EXCEPTION_IF_NULL(mng);
  143. // bprop_fg has been checked in caller
  144. // transform except the last 2 parameters: out, dout.
  145. for (size_t i = 0; i < bprop_fg->parameters().size() - 2; ++i) {
  146. auto p = bprop_fg->parameters()[i];
  147. MS_EXCEPTION_IF_NULL(p);
  148. TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(p->debug_info()));
  149. auto transf_p = outer->add_parameter();
  150. TraceManager::EndTrace();
  151. (void)mng->Replace(p, transf_p);
  152. transf_args->push_back(transf_p);
  153. }
  154. }
  155. void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
  156. auto context = MsContext::GetInstance();
  157. MS_EXCEPTION_IF_NULL(context);
  158. bool check_bprop_flag = context->check_bprop_flag();
  159. // Skip checking if check_bprop not set
  160. if (!check_bprop_flag) {
  161. return;
  162. }
  163. // bprop_fg has been checked in caller
  164. auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
  165. MS_EXCEPTION_IF_NULL(check_bprop_class);
  166. auto check_bprop =
  167. bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
  168. std::vector<AnfNodePtr> inputs;
  169. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  170. inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2);
  171. AnfNodePtr params = bprop_fg->NewCNode(inputs);
  172. inputs.clear();
  173. inputs.push_back(check_bprop);
  174. inputs.push_back(bprop_fg->output());
  175. inputs.push_back(params);
  176. AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
  177. bprop_fg->set_output(bprop_out);
  178. }
  179. FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
  180. MS_EXCEPTION_IF_NULL(bprop_fg);
  181. auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph();
  182. auto expanded_fg = BpropToK(fprop_fg, bprop_fg);
  183. if (expanded_fg == nullptr) {
  184. MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString()
  185. << " Cell bprop function to K expanded func graph. NodeInfo: "
  186. << trace::GetDebugInfo(fprop_fg->debug_info());
  187. }
  188. return expanded_fg;
  189. }
  190. FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  191. auto prim = GetValueNode<PrimitivePtr>(value_node);
  192. MS_EXCEPTION_IF_NULL(prim);
  193. auto &node_users = resources->manager()->node_users();
  194. auto &users = node_users[value_node];
  195. auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int> &user) -> bool {
  196. return IsPrimitiveCNode(user.first, prim);
  197. });
  198. if (cnode == users.end()) {
  199. MS_LOG(EXCEPTION) << "Fail to find cnode.";
  200. }
  201. auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
  202. auto func_graph = std::make_shared<FuncGraph>();
  203. std::vector<AnfNodePtr> outputs;
  204. auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut", py::object());
  205. if (!prim->is_base()) {
  206. PrimitivePyPtr prim_py = dyn_cast<PrimitivePy>(prim);
  207. bprop_cut->set_hook(prim_py->hook());
  208. }
  209. auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
  210. if (cell_id != "") {
  211. (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
  212. (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
  213. }
  214. outputs.push_back(NewValueNode(bprop_cut));
  215. for (size_t i = 0; i < inputs_num; ++i) {
  216. auto param = func_graph->add_parameter();
  217. outputs.push_back(param);
  218. }
  219. auto p1 = func_graph->add_parameter();
  220. auto p2 = func_graph->add_parameter();
  221. outputs.push_back(p1);
  222. outputs.push_back(p2);
  223. func_graph->set_output(func_graph->NewCNode(outputs));
  224. return func_graph;
  225. }
  226. FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  227. auto prim = value_node->value()->cast<PrimitivePtr>();
  228. MS_EXCEPTION_IF_NULL(prim);
  229. auto &node_users = resources->manager()->node_users();
  230. auto &users = node_users[value_node];
  231. auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int> &user) -> bool {
  232. return IsPrimitiveCNode(user.first, prim);
  233. });
  234. if (cnode == users.end()) {
  235. MS_LOG(EXCEPTION) << "Fail to find cnode.";
  236. }
  237. auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
  238. auto func_graph = std::make_shared<FuncGraph>();
  239. std::vector<AnfNodePtr> outputs;
  240. outputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  241. auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
  242. (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
  243. for (size_t i = 0; i < inputs_num; ++i) {
  244. // Mock params for inputs
  245. auto param = func_graph->add_parameter();
  246. // Mock derivatives for each inputs
  247. outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param}));
  248. }
  249. // mock params for out and dout
  250. (void)func_graph->add_parameter();
  251. (void)func_graph->add_parameter();
  252. func_graph->set_output(func_graph->NewCNode(outputs));
  253. return func_graph;
  254. }
  255. } // namespace ad
  256. } // namespace mindspore