You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kprim.cc 12 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include <memory>
  19. #include <string>
  20. #include <utility>
  21. #include "ir/anf.h"
  22. #include "pybind_api/ir/primitive_py.h"
  23. #include "ir/meta_func_graph.h"
  24. #include "ir/func_graph_cloner.h"
  25. #include "ir/manager.h"
  26. #include "pipeline/jit/resource.h"
  27. #include "pipeline/jit/parse/parse.h"
  28. #include "frontend/optimizer/ad/dfunctor.h"
  29. #include "frontend/operator/ops.h"
  30. #include "frontend/operator/composite/composite.h"
  31. #include "utils/symbolic.h"
  32. #include "utils/primitive_utils.h"
  33. #include "utils/ms_context.h"
  34. #include "utils/info.h"
  35. #include "debug/trace.h"
  36. namespace mindspore {
  37. namespace ad {
  38. using PatternListType = std::initializer_list<BaseRef>;
  39. KPrim g_k_prims;
  40. FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
  41. // Set a child scope named "grad'PrimitiveName'" for the bprop function,
  42. // and add "Gradients" to the front.
  43. static const std::string gradients_scope = "Gradients/";
  44. static const std::string grad_op_child_scope_prefix = "/grad";
  45. MS_EXCEPTION_IF_NULL(prim);
  46. auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
  47. grad_op_child_scope_prefix + prim->name());
  48. ScopeGuard scope_guard(scope);
  49. py::function fn;
  50. if (prim->is_base()) {
  51. fn = GetBpropFunction(prim->name());
  52. } else {
  53. fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
  54. if (py::isinstance<py::none>(fn)) {
  55. fn = GetBpropFunction(prim->name());
  56. }
  57. }
  58. if (!fn || py::isinstance<py::none>(fn)) {
  59. MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
  60. return nullptr;
  61. }
  62. FuncGraphPtr func_graph = parse::ParsePythonCode(fn);
  63. if (func_graph == nullptr) {
  64. MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
  65. return nullptr;
  66. }
  67. return func_graph;
  68. }
  69. FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
  70. static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
  71. std::string func_name = "_fprop_" + prim->name();
  72. py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
  73. auto func_graph = parse::ParsePythonCode(fn);
  74. MS_EXCEPTION_IF_NULL(func_graph);
  75. return BasicClone(func_graph);
  76. }
  77. MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
  78. MS_EXCEPTION_IF_NULL(prim);
  79. auto iter = bprop_registry_meta_.find(prim);
  80. if (iter != bprop_registry_meta_.end()) {
  81. return iter->second;
  82. }
  83. if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
  84. MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
  85. bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
  86. return meta;
  87. }
  88. if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
  89. MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
  90. bprop_registry_meta_[prim::kPrimMakeList] = meta;
  91. return meta;
  92. }
  93. MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
  94. }
  95. FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  96. if (!IsValueNode<Primitive>(value_node)) {
  97. MS_LOG(EXCEPTION) << "Primitive node is not valid.";
  98. }
  99. auto prim = GetValueNode<PrimitivePtr>(value_node);
  100. if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
  101. auto fprop = GetFprop(prim);
  102. fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
  103. return fprop;
  104. } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
  105. return nullptr;
  106. } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
  107. return nullptr;
  108. }
  109. FuncGraphPtr bprop_fg = nullptr;
  110. if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
  111. if (MsContext::GetInstance()->get_param<int>(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) {
  112. MS_LOG(EXCEPTION) << "HookBackward is not supported in graph mode.";
  113. }
  114. bprop_fg = BpropCut(value_node, resources);
  115. } else {
  116. auto iter = bprop_registry_.find(prim);
  117. if (iter != bprop_registry_.end()) {
  118. bprop_fg = iter->second;
  119. }
  120. if (bprop_fg == nullptr) {
  121. bprop_fg = GetBprop(prim);
  122. if (bprop_fg != nullptr) {
  123. // Set bprop_g graph cache
  124. bprop_registry_[prim] = bprop_fg;
  125. } else {
  126. bprop_fg = FakeBprop(value_node, resources);
  127. }
  128. }
  129. }
  130. auto expanded_fg = BpropToK(prim, bprop_fg);
  131. if (expanded_fg == nullptr) {
  132. MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
  133. << " prim bprop function to J expanded func graph. NodeInfo: "
  134. << trace::GetDebugInfo(bprop_fg->debug_info());
  135. }
  136. return expanded_fg;
  137. }
  138. AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg) {
  139. // bprop_fg has been checked in caller
  140. if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
  141. // Set bprop output as (env, dx, dy, dz, ...)
  142. auto cbprop = bprop_fg->output()->cast<CNodePtr>();
  143. auto &inputs = cbprop->inputs();
  144. std::vector<AnfNodePtr> args;
  145. args.push_back(NewValueNode(prim::kPrimMakeTuple));
  146. args.push_back(NewValueNode(newenv));
  147. (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
  148. return NewCNode(args, bprop_fg);
  149. }
  150. // Set bprop output as (env, dx)
  151. std::string model_name("mindspore.ops.composite.multitype_ops.add_impl");
  152. std::string python_ops("_tuple_add");
  153. auto tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg);
  154. return NewCNode({NewValueNode(prim::GetPythonOps(python_ops, model_name)), tuple, bprop_fg->output()}, bprop_fg);
  155. }
  156. void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
  157. std::vector<AnfNodePtr> *const transf_args) {
  158. MS_EXCEPTION_IF_NULL(mng);
  159. // bprop_fg has been checked in caller
  160. // transform except the last 2 parameters: out, dout.
  161. for (size_t i = 0; i < bprop_fg->parameters().size() - 2; ++i) {
  162. auto p = bprop_fg->parameters()[i];
  163. MS_EXCEPTION_IF_NULL(p);
  164. TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(p->debug_info()));
  165. auto transf_p = outer->add_parameter();
  166. TraceManager::EndTrace();
  167. (void)mng->Replace(p, transf_p);
  168. transf_args->push_back(transf_p);
  169. }
  170. }
  171. void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
  172. auto context = MsContext::GetInstance();
  173. MS_EXCEPTION_IF_NULL(context);
  174. bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
  175. // Skip checking if check_bprop not set
  176. if (!check_bprop_flag) {
  177. return;
  178. }
  179. // bprop_fg has been checked in caller
  180. auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
  181. MS_EXCEPTION_IF_NULL(check_bprop_class);
  182. auto check_bprop =
  183. bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
  184. std::vector<AnfNodePtr> inputs;
  185. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  186. inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2);
  187. AnfNodePtr params = bprop_fg->NewCNode(inputs);
  188. inputs.clear();
  189. inputs.push_back(check_bprop);
  190. inputs.push_back(bprop_fg->output());
  191. inputs.push_back(params);
  192. AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
  193. bprop_fg->set_output(bprop_out);
  194. }
  195. FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
  196. MS_EXCEPTION_IF_NULL(bprop_fg);
  197. auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph();
  198. auto expanded_fg = BpropToK(fprop_fg, bprop_fg);
  199. if (expanded_fg == nullptr) {
  200. MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString()
  201. << " Cell bprop function to K expanded func graph. NodeInfo: "
  202. << trace::GetDebugInfo(fprop_fg->debug_info());
  203. }
  204. return expanded_fg;
  205. }
  206. FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  207. auto prim = GetValueNode<PrimitivePtr>(value_node);
  208. MS_EXCEPTION_IF_NULL(prim);
  209. auto &node_users = resources->manager()->node_users();
  210. auto &users = node_users[value_node];
  211. auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int> &user) -> bool {
  212. return IsPrimitiveCNode(user.first, prim);
  213. });
  214. if (cnode == users.end()) {
  215. MS_LOG(EXCEPTION) << "Fail to find cnode.";
  216. }
  217. auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
  218. auto func_graph = std::make_shared<FuncGraph>();
  219. std::vector<AnfNodePtr> outputs;
  220. auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut", py::object());
  221. bprop_cut->CopyHookFunction(prim);
  222. auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
  223. if (cell_id != "") {
  224. (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
  225. (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
  226. }
  227. outputs.push_back(NewValueNode(bprop_cut));
  228. for (size_t i = 0; i < inputs_num; ++i) {
  229. auto param = func_graph->add_parameter();
  230. outputs.push_back(param);
  231. }
  232. auto p1 = func_graph->add_parameter();
  233. auto p2 = func_graph->add_parameter();
  234. outputs.push_back(p1);
  235. outputs.push_back(p2);
  236. func_graph->set_output(func_graph->NewCNode(outputs));
  237. return func_graph;
  238. }
  239. FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  240. auto prim = value_node->value()->cast<PrimitivePtr>();
  241. MS_EXCEPTION_IF_NULL(prim);
  242. auto &node_users = resources->manager()->node_users();
  243. auto &users = node_users[value_node];
  244. auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int> &user) -> bool {
  245. return IsPrimitiveCNode(user.first, prim);
  246. });
  247. if (cnode == users.end()) {
  248. MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
  249. }
  250. auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
  251. auto func_graph = std::make_shared<FuncGraph>();
  252. std::vector<AnfNodePtr> outputs;
  253. outputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  254. auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
  255. (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
  256. for (size_t i = 0; i < inputs_num; ++i) {
  257. // Mock params for inputs
  258. auto param = func_graph->add_parameter();
  259. // Mock derivatives for each inputs
  260. outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param}));
  261. }
  262. // mock params for out and dout
  263. (void)func_graph->add_parameter();
  264. (void)func_graph->add_parameter();
  265. func_graph->set_output(func_graph->NewCNode(outputs));
  266. return func_graph;
  267. }
  268. } // namespace ad
  269. } // namespace mindspore