You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

kprim.cc 18 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "pybind_api/ir/primitive_py.h"
#include "ir/meta_func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/parse.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "utils/ms_context.h"
#include "utils/info.h"
#include "debug/trace.h"
namespace mindspore {
namespace ad {
// Shorthand for the pattern lists used by BaseRef-based matchers.
using PatternListType = std::initializer_list<BaseRef>;
// Global KPrim instance holding the bprop/meta-bprop registries.
KPrim g_k_prims;
  40. FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
  41. // Set a child scope named "grad'PrimitiveName'" for the bprop function,
  42. // and add "Gradients" to the front.
  43. static const std::string gradients_scope = "Gradients/";
  44. static const std::string grad_op_child_scope_prefix = "/grad";
  45. MS_EXCEPTION_IF_NULL(prim);
  46. auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
  47. grad_op_child_scope_prefix + prim->name());
  48. ScopeGuard scope_guard(scope);
  49. py::function fn;
  50. if (prim->is_base()) {
  51. fn = GetBpropFunction(prim->name());
  52. } else {
  53. fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
  54. if (py::isinstance<py::none>(fn)) {
  55. fn = GetBpropFunction(prim->name());
  56. }
  57. }
  58. if (!fn || py::isinstance<py::none>(fn)) {
  59. MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
  60. return nullptr;
  61. }
  62. FuncGraphPtr func_graph = parse::ParsePythonCode(fn);
  63. if (func_graph == nullptr) {
  64. MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
  65. return nullptr;
  66. }
  67. auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
  68. if (bprop_flag) {
  69. func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  70. }
  71. return func_graph;
  72. }
  73. FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
  74. static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
  75. std::string func_name = "_fprop_" + prim->name();
  76. py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
  77. auto func_graph = parse::ParsePythonCode(fn);
  78. MS_EXCEPTION_IF_NULL(func_graph);
  79. return BasicClone(func_graph);
  80. }
  81. MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
  82. MS_EXCEPTION_IF_NULL(prim);
  83. auto iter = bprop_registry_meta_.find(prim);
  84. if (iter != bprop_registry_meta_.end()) {
  85. return iter->second;
  86. }
  87. if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
  88. MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
  89. bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
  90. return meta;
  91. }
  92. if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
  93. MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
  94. bprop_registry_meta_[prim::kPrimMakeList] = meta;
  95. return meta;
  96. }
  97. MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
  98. }
  99. static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &monad) {
  100. const auto &output = bprop_fg->output();
  101. MS_EXCEPTION_IF_NULL(output);
  102. auto output_cnode = output->cast<CNodePtr>();
  103. if (output_cnode != nullptr) {
  104. // If output_cnode has the form like (make_tuple, x, y).
  105. output_cnode->add_input(monad);
  106. return;
  107. }
  108. // If output is an empty tuple, create a (make_tuple, monad) as the new output.
  109. auto make_tuple = NewValueNode(prim::kPrimMakeTuple);
  110. output_cnode = bprop_fg->NewCNode({make_tuple, monad});
  111. bprop_fg->set_output(output_cnode);
  112. }
  113. // Append U or/and IO monad to output of Bprop funcgraph.
  114. static void AdjustForAutoMonad(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
  115. auto effect_info = GetPrimEffectInfo(prim);
  116. if (effect_info.memory) {
  117. MS_LOG(DEBUG) << "Append U monad for Bprop FuncGraph of Primitive " << prim->ToString();
  118. auto u = NewValueNode(kUMonad);
  119. u->set_abstract(kUMonad->ToAbstract());
  120. AppendMonadOutput(bprop_fg, u);
  121. }
  122. if (effect_info.io) {
  123. MS_LOG(DEBUG) << "Append IO monad for Bprop FuncGraph of Primitive " << prim->ToString();
  124. auto io = NewValueNode(kIOMonad);
  125. io->set_abstract(kIOMonad->ToAbstract());
  126. AppendMonadOutput(bprop_fg, io);
  127. }
  128. }
FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
                               const pipeline::ResourceBasePtr &resources) {
  // Build the K-transformed (fprop) FuncGraph for the primitive held in
  // `value_node`. Returns nullptr for make_tuple/make_list (which are handled
  // as MetaFuncGraphs, see KMetaFuncGraph); throws on invalid input or when
  // expansion fails.
  if (!IsValueNode<Primitive>(value_node)) {
    MS_LOG(EXCEPTION) << "Primitive node is not valid.";
  }
  auto prim = GetValueNode<PrimitivePtr>(value_node);
  // switch_layer has a hand-written python fprop (see GetFprop) instead of a
  // bprop; tag it with its primal so later passes can relate the two.
  if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
    auto fprop = GetFprop(prim);
    fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
    return fprop;
  } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
    return nullptr;
  } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
    return nullptr;
  }
  FuncGraphPtr bprop_fg = nullptr;
  if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
    // HookBackward is rejected outright in graph mode.
    if (MsContext::GetInstance()->get_param<int>(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) {
      MS_LOG(EXCEPTION) << "HookBackward is not supported in graph mode.";
    }
    bprop_fg = BpropCut(value_node, resources);
  } else {
    // Consult the bprop graph cache first; on miss, parse the python bprop
    // (caching on success) or fall back to a fake bprop.
    auto iter = bprop_registry_.find(prim);
    if (iter != bprop_registry_.end()) {
      bprop_fg = iter->second;
    }
    if (bprop_fg == nullptr) {
      bprop_fg = GetBprop(prim);
      if (bprop_fg != nullptr) {
        // Set bprop_g graph cache
        bprop_registry_[prim] = bprop_fg;
      } else {
        bprop_fg = FakeBprop(value_node, resources);
      }
    }
  }
  // Append U/IO monad outputs for side-effect primitives before expansion.
  AdjustForAutoMonad(prim, bprop_fg);
  auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode);
  if (expanded_fg == nullptr) {
    MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
                      << " prim bprop function to J expanded func graph. NodeInfo: "
                      << trace::GetDebugInfo(bprop_fg->debug_info());
  }
  return expanded_fg;
}
AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
  // Build the fprop-style output tuple (env, dx, dy, ...) from bprop_fg's
  // output, padding with zeros_like gradients for any extra (monad)
  // parameters the primal graph has but the bprop does not cover.
  // current_primal_fg may have extra parameters like u_monad, io_monad
  std::vector<AnfNodePtr> extra_args;
  // caller had checked size() - 2 is greater than 0.
  auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
  if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) {
    auto current_primal_fg_param_size = current_primal_fg->parameters().size();
    MS_LOG(DEBUG) << "Current Primal FuncGraph may have extra parameters(U or IO monad) which bprop don't define, so "
                     "Insert it. Extra parameters size: "
                  << current_primal_fg_param_size - bprop_fg_param_size;
    for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) {
      const auto &primal_node = current_primal_fg->parameters()[i];
      // The gradient for each uncovered primal parameter is zeros_like(param).
      auto extra_node = bprop_fg->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), primal_node});
      extra_args.push_back(extra_node);
      MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString();
    }
  }
  // bprop_fg has been checked in caller
  if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
    // Set bprop output as (env, dx, dy, dz, ...)
    auto cbprop = bprop_fg->output()->cast<CNodePtr>();
    auto &inputs = cbprop->inputs();
    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(prim::kPrimMakeTuple));
    args.push_back(NewValueNode(newenv));
    // Copy the existing (dx, dy, ...) entries, skipping the make_tuple head.
    (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
    if (!extra_args.empty()) {
      args.insert(args.end(), extra_args.cbegin(), extra_args.cend());
    }
    return NewCNode(args, bprop_fg);
  }
  // Set bprop output as (env, dx)
  // The output is not a make_tuple CNode, so (env,) is concatenated with the
  // existing output (and any extra gradients) via the python _tuple_add op.
  std::string model_name("mindspore.ops.composite.multitype_ops.add_impl");
  std::string python_ops("_tuple_add");
  auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg);
  auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
  if (!extra_args.empty()) {
    // Result: tuple_add((env,), tuple_add(old_output, (extra...))).
    extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple));
    auto extra_tuple = NewCNode(extra_args, bprop_fg);
    auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg);
    return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg);
  }
  return NewCNode({tuple_add_ops, tuple_env, bprop_fg->output()}, bprop_fg);
}
  218. static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
  219. std::vector<AnfNodePtr> *const transf_args) {
  220. // bprop_fg has been checked in caller
  221. // transform except the last 2 parameters: out, dout.
  222. auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
  223. for (size_t i = 0; i < bprop_fg_param_size; ++i) {
  224. auto p = bprop_fg->parameters()[i];
  225. MS_EXCEPTION_IF_NULL(p);
  226. TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
  227. auto transf_p = outer->add_parameter();
  228. (void)mng->Replace(p, transf_p);
  229. transf_args->push_back(transf_p);
  230. }
  231. }
  232. void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
  233. const PrimitivePtr &primitive, const FuncGraphPtr &outer,
  234. std::vector<AnfNodePtr> *const transf_args) {
  235. MS_EXCEPTION_IF_NULL(mng);
  236. TransformNormalArgs(mng, bprop_fg, outer, transf_args);
  237. // Fprop_fg for Primitive with side effect should append extra U or IO monad parameter.
  238. auto effect_info = GetPrimEffectInfo(primitive);
  239. if (effect_info.memory) {
  240. MS_LOG(DEBUG) << "Append U monad to Fprop FuncGraph for Primitive " << primitive->ToString();
  241. auto transf_p = outer->add_parameter();
  242. transf_args->push_back(transf_p);
  243. }
  244. if (effect_info.io) {
  245. MS_LOG(DEBUG) << "Append IO monad to Fprop FuncGraph for Primitive " << primitive->ToString();
  246. auto transf_p = outer->add_parameter();
  247. transf_args->push_back(transf_p);
  248. }
  249. }
template <typename T>
void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
                                      const T &current_primal_fg, const FuncGraphPtr &outer,
                                      std::vector<AnfNodePtr> *const transf_args) {
  // Transform bprop_fg's normal parameters into parameters of `outer`, then
  // append parameters for any extra monad parameters AutoMonad added to the
  // primal graph that the user-written bprop does not declare.
  MS_EXCEPTION_IF_NULL(mng);
  TransformNormalArgs(mng, bprop_fg, outer, transf_args);
  auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
  // current_primal_fg may have extra parameters after AutoMonad
  const auto &current_primal_fg_params = current_primal_fg->parameters();
  if (bprop_fg_param_size < current_primal_fg_params.size()) {
    for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) {
      auto p = current_primal_fg_params[i];
      MS_EXCEPTION_IF_NULL(p);
      // extra parameters should be Monad.
      if (!HasAbstractMonad(p)) {
        continue;
      }
      MS_LOG(DEBUG) << "Function " << current_primal_fg->ToString()
                    << ", has extra monad parameter: " << p->DebugString()
                    << ", abstract: " << p->abstract()->ToString();
      TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
      auto transf_p = outer->add_parameter();
      // NOTE(review): unlike TransformNormalArgs, `p` here belongs to the
      // primal graph, yet it is still replaced with outer's new parameter —
      // presumably to redirect the primal monad users; confirm against the
      // caller's (BpropToK) expectations.
      (void)mng->Replace(p, transf_p);
      transf_args->push_back(transf_p);
    }
  }
  if (transf_args->size() != current_primal_fg_params.size()) {
    // The reported bprop count excludes the trailing (out, dout) parameters.
    MS_EXCEPTION(TypeError) << "Function " << current_primal_fg->ToString()
                            << ", The number of parameter of this primal function is "
                            << current_primal_fg_params.size() << ", but the number of parameters of bprop is "
                            << bprop_fg_param_size;
  }
}
  283. void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
  284. auto context = MsContext::GetInstance();
  285. MS_EXCEPTION_IF_NULL(context);
  286. bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
  287. // Skip checking if check_bprop not set
  288. if (!check_bprop_flag) {
  289. return;
  290. }
  291. // bprop_fg has been checked in caller
  292. auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
  293. MS_EXCEPTION_IF_NULL(check_bprop_class);
  294. auto check_bprop =
  295. bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
  296. std::vector<AnfNodePtr> inputs;
  297. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  298. inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2);
  299. AnfNodePtr params = bprop_fg->NewCNode(inputs);
  300. inputs.clear();
  301. inputs.push_back(check_bprop);
  302. inputs.push_back(bprop_fg->output());
  303. inputs.push_back(params);
  304. AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
  305. bprop_fg->set_output(bprop_out);
  306. }
  307. FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
  308. MS_EXCEPTION_IF_NULL(bprop_fg);
  309. // primal_fg is FuncGraph just after convert. Refer ConvertCellObjToFuncGraph.
  310. // current_primal_fg is specalized and AutoMoaded primal_fg;
  311. auto primal_fg = bprop_fg->transforms().find("primal")->second.func_graph();
  312. auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr);
  313. if (expanded_fg == nullptr) {
  314. MS_LOG(EXCEPTION) << "Failed convert " << primal_fg->ToString()
  315. << " Cell bprop function to K expanded func graph. NodeInfo: "
  316. << trace::GetDebugInfo(primal_fg->debug_info());
  317. }
  318. return expanded_fg;
  319. }
  320. FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  321. auto prim = GetValueNode<PrimitivePtr>(value_node);
  322. MS_EXCEPTION_IF_NULL(prim);
  323. auto &node_users = resources->manager()->node_users();
  324. auto &users = node_users[value_node];
  325. auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
  326. return IsPrimitiveCNode(user.first, prim);
  327. });
  328. if (cnode == users.end()) {
  329. MS_LOG(EXCEPTION) << "Fail to find cnode.";
  330. }
  331. auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
  332. auto func_graph = std::make_shared<FuncGraph>();
  333. std::vector<AnfNodePtr> outputs;
  334. auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut", py::object());
  335. bprop_cut->CopyHookFunction(prim);
  336. auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
  337. if (cell_id != "") {
  338. (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
  339. (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
  340. }
  341. outputs.push_back(NewValueNode(bprop_cut));
  342. for (size_t i = 0; i < inputs_num; ++i) {
  343. auto param = func_graph->add_parameter();
  344. outputs.push_back(param);
  345. }
  346. auto p1 = func_graph->add_parameter();
  347. auto p2 = func_graph->add_parameter();
  348. outputs.push_back(p1);
  349. outputs.push_back(p2);
  350. func_graph->set_output(func_graph->NewCNode(outputs));
  351. return func_graph;
  352. }
  353. FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  354. auto prim = value_node->value()->cast<PrimitivePtr>();
  355. MS_EXCEPTION_IF_NULL(prim);
  356. auto &node_users = resources->manager()->node_users();
  357. auto &users = node_users[value_node];
  358. auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
  359. return IsPrimitiveCNode(user.first, prim);
  360. });
  361. if (cnode == users.end()) {
  362. MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
  363. }
  364. auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
  365. auto effect_info = GetPrimEffectInfo(prim);
  366. // Don't add U or IO monad parameters as it will be added later.
  367. size_t monad_params_size = 0;
  368. if (effect_info.memory) {
  369. monad_params_size++;
  370. }
  371. if (effect_info.io) {
  372. monad_params_size++;
  373. }
  374. if (inputs_num < monad_params_size) {
  375. MS_LOG(EXCEPTION) << "Arguments number should be greater than or equal to " << monad_params_size
  376. << ", but the CNode is: " << cnode->first->DebugString();
  377. }
  378. inputs_num -= monad_params_size;
  379. auto func_graph = std::make_shared<FuncGraph>();
  380. std::vector<AnfNodePtr> outputs;
  381. outputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  382. auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
  383. (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
  384. for (size_t i = 0; i < inputs_num; ++i) {
  385. // Mock params for inputs
  386. auto param = func_graph->add_parameter();
  387. // Mock derivatives for each inputs
  388. outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param}));
  389. }
  390. // mock params for out and dout
  391. (void)func_graph->add_parameter();
  392. (void)func_graph->add_parameter();
  393. func_graph->set_output(func_graph->NewCNode(outputs));
  394. return func_graph;
  395. }
  396. } // namespace ad
  397. } // namespace mindspore