You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

kprim.cc 33 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef _WIN32
  19. #include <dirent.h>
  20. #endif
  21. #include <memory>
  22. #include <string>
  23. #include <utility>
  24. #include "ir/anf.h"
  25. #include "pybind_api/ir/primitive_py.h"
  26. #include "ir/meta_func_graph.h"
  27. #include "ir/func_graph_cloner.h"
  28. #include "ir/manager.h"
  29. #include "pipeline/jit/resource.h"
  30. #include "pipeline/jit/parse/parse.h"
  31. #include "pipeline/jit/parse/resolve.h"
  32. #include "frontend/optimizer/ad/dfunctor.h"
  33. #include "frontend/operator/ops.h"
  34. #include "frontend/operator/composite/composite.h"
  35. #include "utils/utils.h"
  36. #include "utils/symbolic.h"
  37. #include "utils/primitive_utils.h"
  38. #include "utils/ms_context.h"
  39. #include "utils/info.h"
  40. #include "debug/trace.h"
  41. #include "debug/common.h"
  42. #include "debug/dump_proto.h"
  43. #include "mindspore/core/load_mindir/load_model.h"
  44. #include "utils/system/sha256.h"
  45. #include "utils/file_utils.h"
  46. namespace mindspore {
  47. namespace ad {
  48. KPrim g_k_prims;
  49. namespace {
  50. constexpr char kBpropMindIRSuffix[] = "_bprop.mindir";
  51. constexpr char kBpropMindIRDir[] = "/../bprop_mindir/";
  52. constexpr char serializable_bprop_ops[] = "serializable_bprop_ops";
  53. constexpr char bprop_mindir_module[] = "mindspore.ops.bprop_mindir";
  54. #ifndef _WIN32
  55. std::string GetBpropDir() {
  56. static std::string bprop_dir;
  57. if (bprop_dir.empty()) {
  58. py::module mod = py::module::import("mindspore.ops._grad");
  59. auto grad_file_path = mod.attr("__file__").cast<std::string>();
  60. bprop_dir = grad_file_path.substr(0, grad_file_path.find_last_of('/'));
  61. }
  62. return bprop_dir;
  63. }
  64. bool BpropMindirDirExists() {
  65. auto bprop_mindir_dir = GetBpropDir() + kBpropMindIRDir;
  66. DIR *dir = opendir(bprop_mindir_dir.c_str());
  67. if (dir != nullptr) {
  68. if (closedir(dir) == -1) {
  69. MS_LOG(WARNING) << "The bprop mindir dir \"" << bprop_mindir_dir << "\" close failed!";
  70. }
  71. return true;
  72. }
  73. MS_LOG(ERROR) << "Open bprop mindir dir \"" << bprop_mindir_dir << "\" failed." << ErrnoToString(errno);
  74. return false;
  75. }
  76. // Get the serializable bprop list from the module mindspore.ops.bprop_mindir in python.
  77. mindspore::HashSet<std::string> GetSerializableBpropList() {
  78. mindspore::HashSet<std::string> serializable_bprop_list;
  79. if (!BpropMindirDirExists()) {
  80. return serializable_bprop_list;
  81. }
  82. py::module mod = py::module::import(bprop_mindir_module);
  83. py::object serializable_bprop_ops_attr = mod.attr(serializable_bprop_ops);
  84. if (!py::isinstance<py::list>(serializable_bprop_ops_attr)) {
  85. MS_LOG(WARNING) << "Can not get the the serializable bprop ops list from python, it is not a python list.";
  86. return serializable_bprop_list;
  87. }
  88. auto ops_list = serializable_bprop_ops_attr.cast<py::list>();
  89. for (size_t i = 0; i < ops_list.size(); ++i) {
  90. auto prim_adapter = ops_list[i].cast<PrimitivePyAdapterPtr>();
  91. if (prim_adapter == nullptr) {
  92. MS_LOG(EXCEPTION) << "The python obj in serializable bprop list should be a Primitive, but it is "
  93. << py::str(ops_list[i]);
  94. }
  95. serializable_bprop_list.insert(prim_adapter->name());
  96. }
  97. return serializable_bprop_list;
  98. }
  99. bool IsSerializableBprop(const std::string &prim_name) {
  100. static mindspore::HashSet<std::string> serializable_bprop_list = GetSerializableBpropList();
  101. return std::any_of(serializable_bprop_list.begin(), serializable_bprop_list.end(),
  102. [&prim_name](const std::string &serializable_bprop_prim_name) {
  103. return prim_name == serializable_bprop_prim_name;
  104. });
  105. }
  106. void GetFilesHash(const std::string &dir, mindspore::HashMap<std::string, std::string> *bprop_hash_to_file) {
  107. if (dir.empty()) {
  108. MS_LOG(ERROR) << "The directory path is empty.";
  109. return;
  110. }
  111. struct stat s {};
  112. int ret = stat(dir.c_str(), &s);
  113. if (ret != 0) {
  114. MS_LOG(ERROR) << "stat dir \"" << dir << "\" failed, ret is : " << ret;
  115. return;
  116. }
  117. if (!S_ISDIR(s.st_mode)) {
  118. MS_LOG(ERROR) << "The path \"" << dir << "\" is not a directory.";
  119. return;
  120. }
  121. DIR *open_dir = opendir(dir.c_str());
  122. if (open_dir == nullptr) {
  123. MS_LOG(ERROR) << "open dir " << dir.c_str() << " failed";
  124. return;
  125. }
  126. struct dirent *filename;
  127. while ((filename = readdir(open_dir)) != nullptr) {
  128. std::string d_name = std::string(filename->d_name);
  129. if (d_name == "." || d_name == ".." || filename->d_type != DT_REG) {
  130. continue;
  131. }
  132. auto real_path = std::string(dir) + "/" + filename->d_name;
  133. (void)bprop_hash_to_file->emplace(system::sha256::GetHashFromFile(real_path), real_path);
  134. }
  135. closedir(open_dir);
  136. }
  137. mindspore::HashMap<std::string, std::string> GetAllBpropFileHash() {
  138. mindspore::HashMap<std::string, std::string> bprop_hash_to_file;
  139. auto bprop_dir = GetBpropDir();
  140. auto realpath = FileUtils::GetRealPath(common::SafeCStr(bprop_dir));
  141. if (!realpath.has_value()) {
  142. MS_LOG(EXCEPTION) << "Get real path of bprop dir failed. path=" << bprop_dir;
  143. }
  144. GetFilesHash(realpath.value(), &bprop_hash_to_file);
  145. return bprop_hash_to_file;
  146. }
  147. bool CheckBpropHash(const std::string &hash) {
  148. // Get every hash of all the bprop files.
  149. static auto bprop_hash_to_file = GetAllBpropFileHash();
  150. if (bprop_hash_to_file.find(hash) != bprop_hash_to_file.end()) {
  151. return true;
  152. }
  153. std::string bprop_dir = GetBpropDir();
  154. auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
  155. MS_LOG(ERROR) << "The bprop mindir files are not up to date. Please run the " << bprop_mindir_path
  156. << "generate_mindir.py to generate new mindir files.\n"
  157. << "bprop_fg hash: " << hash << "\n"
  158. << "bprop hash list: \n";
  159. for (const auto &iter : bprop_hash_to_file) {
  160. MS_LOG(ERROR) << iter.first;
  161. }
  162. return false;
  163. }
  164. FuncGraphPtr ImportBpropFromMindIR(const PrimitivePtr &prim) {
  165. MS_EXCEPTION_IF_NULL(prim);
  166. std::string bprop_dir = GetBpropDir();
  167. auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
  168. std::optional<std::string> bprop_mindir_realpath =
  169. FileUtils::GetRealPath(common::SafeCStr(bprop_mindir_path + prim->name() + kBpropMindIRSuffix));
  170. bool bprop_cache_file_exists = bprop_mindir_realpath.has_value() && Common::FileExists(bprop_mindir_realpath.value());
  171. if (!bprop_cache_file_exists) {
  172. return nullptr;
  173. }
  174. MindIRLoader mindir_loader;
  175. auto bprop_fg = mindir_loader.LoadMindIR(bprop_mindir_realpath.value());
  176. if (!CheckBpropHash(bprop_fg->bprop_hash())) {
  177. MS_LOG(EXCEPTION) << "The bprop mindir files are not up to date.";
  178. }
  179. return bprop_fg;
  180. }
  181. void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_graph) {
  182. MS_EXCEPTION_IF_NULL(prim);
  183. std::string bprop_dir = GetBpropDir();
  184. auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
  185. std::optional<std::string> bprop_mindir_realpath =
  186. Common::CreatePrefixPath(bprop_mindir_path + prim->name() + kBpropMindIRSuffix, true);
  187. if (!bprop_mindir_realpath.has_value()) {
  188. MS_LOG(ERROR) << "Failed to get the realpath of bprop mindir: " << bprop_mindir_path << prim->name()
  189. << kBpropMindIRSuffix;
  190. return;
  191. }
  192. std::ofstream fout(bprop_mindir_realpath.value());
  193. if (!fout.is_open()) {
  194. MS_LOG(ERROR) << "Open cache file '" << bprop_mindir_realpath.value() << "' failed!" << ErrnoToString(errno);
  195. return;
  196. }
  197. ModelProtoPtr fg_model = GetBinaryProto(func_graph);
  198. if (fg_model == nullptr) {
  199. MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed.";
  200. fout.close();
  201. return;
  202. }
  203. if (!fg_model->SerializeToOstream(&fout)) {
  204. MS_LOG(ERROR) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \""
  205. << bprop_mindir_realpath.value() << "\".";
  206. fout.close();
  207. return;
  208. }
  209. fout.close();
  210. ChangeFileMode(bprop_mindir_realpath.value(), S_IRUSR | S_IWUSR);
  211. }
  212. AnfNodePtr GetPythonOps(const FuncGraphPtr &fg, const AnfNodePtr &origin_node, const PrimitivePtr &prim) {
  213. MS_EXCEPTION_IF_NULL(fg);
  214. MS_EXCEPTION_IF_NULL(origin_node);
  215. MS_EXCEPTION_IF_NULL(prim);
  216. // DoSignaturePrimitive to the pair of primitive name and module name.
  217. static mindspore::HashMap<std::string, std::pair<std::string, std::string>> python_ops{
  218. {"S-Prim-zeros_like_leaf", {"zeros_like", ""}},
  219. {"S-Prim-getitem", {"getitem", "mindspore.ops.composite.multitype_ops.getitem_impl"}}};
  220. auto iter = python_ops.find(prim->name());
  221. if (iter == python_ops.end()) {
  222. return nullptr;
  223. }
  224. ValuePtr python_ops_value;
  225. if (!iter->second.second.empty()) {
  226. python_ops_value = prim::GetPythonOps(iter->second.first, iter->second.second);
  227. } else {
  228. python_ops_value = prim::GetPythonOps(iter->second.first);
  229. }
  230. auto origin_cnode = origin_node->cast<CNodePtr>();
  231. MS_EXCEPTION_IF_NULL(origin_cnode);
  232. auto &origin_inputs = origin_cnode->inputs();
  233. std::vector<AnfNodePtr> new_inputs{NewValueNode(python_ops_value)};
  234. (void)std::copy(origin_inputs.begin() + 1, origin_inputs.end(), std::back_inserter(new_inputs));
  235. return fg->NewCNode(new_inputs);
  236. }
  237. // Replace the nodes whose python obj of primitive is needed in the renormalize process,
  238. // with the new created python ops, such as zeros_like.
  239. void ReplacePythonOps(const FuncGraphPtr &fg) {
  240. MS_EXCEPTION_IF_NULL(fg);
  241. std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(fg->get_return());
  242. for (const auto &node : all_nodes) {
  243. MS_EXCEPTION_IF_NULL(node);
  244. if (!node->isa<CNode>()) {
  245. continue;
  246. }
  247. auto cnode = node->cast<CNodePtr>();
  248. for (size_t i = 0; i < cnode->size(); ++i) {
  249. auto prim = GetCNodePrimitive(cnode->input(i));
  250. if (prim == nullptr) {
  251. continue;
  252. }
  253. auto new_input = GetPythonOps(fg, cnode->input(i), prim);
  254. if (new_input == nullptr) {
  255. continue;
  256. }
  257. cnode->set_input(i, new_input);
  258. }
  259. }
  260. }
  261. std::string GetBpropFileHash(const py::function &fn) {
  262. static auto bprop_hash_to_file = GetAllBpropFileHash();
  263. // Get the file where the bprop function is defined.
  264. auto filename = fn.attr("__code__").attr("co_filename").cast<std::string>();
  265. // Get the hash of the file.
  266. auto it = std::find_if(bprop_hash_to_file.begin(), bprop_hash_to_file.end(),
  267. [&filename](const auto &item) { return item.second == filename; });
  268. if (it != bprop_hash_to_file.end()) {
  269. return it->first;
  270. }
  271. return "";
  272. }
  273. #endif
  274. } // namespace
  275. #ifndef _WIN32
  276. // Given a python primitive, export a mindir file from the bprop defined in python.
  277. void KPrim::ExportBpropMindir(const py::object &obj) {
  278. auto prim_adapter = obj.cast<PrimitivePyAdapterPtr>();
  279. if (prim_adapter == nullptr) {
  280. MS_LOG(EXCEPTION) << "The python obj to be exported to bprop mindir should be a Primitive, but it is "
  281. << py::str(obj);
  282. }
  283. auto prim = prim_adapter->attached_primitive();
  284. if (prim == nullptr) {
  285. prim = std::make_shared<PrimitivePy>(obj, prim_adapter);
  286. prim_adapter->set_attached_primitive(prim);
  287. }
  288. // Get the bprop function from python.
  289. py::function fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
  290. if (py::isinstance<py::none>(fn)) {
  291. fn = GetBpropFunction(prim->name());
  292. }
  293. if (!fn || py::isinstance<py::none>(fn)) {
  294. MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
  295. }
  296. std::string bprop_hash = GetBpropFileHash(fn);
  297. if (bprop_hash.empty()) {
  298. MS_LOG(EXCEPTION) << "Fail to get the file hash for " << prim->name();
  299. }
  300. // Parse and resolve.
  301. auto func_graph = parse::ParsePythonCode(fn);
  302. if (func_graph == nullptr) {
  303. MS_LOG(EXCEPTION) << "Fail to parse bprop function for " << prim->name() << ".";
  304. }
  305. auto res = std::make_shared<pipeline::Resource>();
  306. (void)parse::ResolveFuncGraph(func_graph, res);
  307. func_graph->set_bprop_hash(bprop_hash);
  308. ExportBpropToMindIR(prim, func_graph);
  309. }
  310. #endif
  311. FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
  312. // Set a child scope named "grad'PrimitiveName'" for the bprop function,
  313. // and add "Gradients" to the front.
  314. static const std::string gradients_scope = "Gradients/";
  315. static const std::string grad_op_child_scope_prefix = "/grad";
  316. MS_EXCEPTION_IF_NULL(prim);
  317. auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
  318. grad_op_child_scope_prefix + prim->name());
  319. ScopeGuard scope_guard(scope);
  320. // Firstly we get bprop from mindir. If failed, parse the python function registered.
  321. FuncGraphPtr func_graph = nullptr;
  322. #ifndef _WIN32
  323. bool serializable = IsSerializableBprop(prim->name());
  324. if (serializable) {
  325. func_graph = ImportBpropFromMindIR(prim);
  326. if (func_graph != nullptr) {
  327. ReplacePythonOps(func_graph);
  328. if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP)) {
  329. func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  330. }
  331. return func_graph;
  332. }
  333. }
  334. #endif
  335. py::function fn;
  336. if (prim->is_base()) {
  337. fn = GetBpropFunction(prim->name());
  338. } else {
  339. fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
  340. if (py::isinstance<py::none>(fn)) {
  341. fn = GetBpropFunction(prim->name());
  342. }
  343. }
  344. if (!fn || py::isinstance<py::none>(fn)) {
  345. MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
  346. return nullptr;
  347. }
  348. func_graph = parse::ParsePythonCode(fn);
  349. if (func_graph == nullptr) {
  350. MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
  351. return nullptr;
  352. }
  353. auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
  354. if (bprop_flag) {
  355. func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  356. }
  357. pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
  358. (void)parse::ResolveFuncGraph(func_graph, res);
  359. #ifndef _WIN32
  360. // Check whether the bprop needs to be exported.
  361. if (serializable) {
  362. std::string bprop_hash = GetBpropFileHash(fn);
  363. if (!bprop_hash.empty()) {
  364. func_graph->set_bprop_hash(bprop_hash);
  365. ExportBpropToMindIR(prim, func_graph);
  366. }
  367. }
  368. #endif
  369. return func_graph;
  370. }
  371. FuncGraphPtr KPrim::GetPossibleBprop(const PrimitivePtr &prim) {
  372. FuncGraphPtr bprop_fg = nullptr;
  373. auto iter = bprop_registry_.find(prim);
  374. if (iter != bprop_registry_.end()) {
  375. bprop_fg = iter->second;
  376. }
  377. if (bprop_fg == nullptr) {
  378. bprop_fg = GetBprop(prim);
  379. if (bprop_fg != nullptr) {
  380. // Set bprop_g graph cache
  381. bprop_registry_[prim] = bprop_fg;
  382. }
  383. }
  384. return bprop_fg;
  385. }
  386. FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
  387. static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
  388. std::string func_name = "_fprop_" + prim->name();
  389. py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
  390. auto func_graph = parse::ParsePythonCode(fn);
  391. MS_EXCEPTION_IF_NULL(func_graph);
  392. return BasicClone(func_graph);
  393. }
  394. MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
  395. MS_EXCEPTION_IF_NULL(prim);
  396. auto iter = bprop_registry_meta_.find(prim);
  397. if (iter != bprop_registry_meta_.end()) {
  398. return iter->second;
  399. }
  400. if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
  401. MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
  402. bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
  403. return meta;
  404. }
  405. if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
  406. MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
  407. bprop_registry_meta_[prim::kPrimMakeList] = meta;
  408. return meta;
  409. }
  410. MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
  411. }
  412. static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &monad) {
  413. const auto &output = bprop_fg->output();
  414. MS_EXCEPTION_IF_NULL(output);
  415. auto output_cnode = output->cast<CNodePtr>();
  416. if (output_cnode != nullptr) {
  417. // If output_cnode has the form like (make_tuple, x, y).
  418. output_cnode->add_input(monad);
  419. return;
  420. }
  421. // If output is an empty tuple, create a (make_tuple, monad) as the new output.
  422. auto make_tuple = NewValueNode(prim::kPrimMakeTuple);
  423. output_cnode = bprop_fg->NewCNode({make_tuple, monad});
  424. bprop_fg->set_output(output_cnode);
  425. }
  426. // Append U or/and IO monad to output of Bprop funcgraph.
  427. static void AdjustForAutoMonad(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
  428. auto effect_info = GetPrimEffectInfo(prim);
  429. if (effect_info.memory) {
  430. MS_LOG(DEBUG) << "Append U monad for Bprop FuncGraph of Primitive " << prim->ToString();
  431. auto u = NewValueNode(kUMonad);
  432. u->set_abstract(kUMonad->ToAbstract());
  433. AppendMonadOutput(bprop_fg, u);
  434. }
  435. if (effect_info.io) {
  436. MS_LOG(DEBUG) << "Append IO monad for Bprop FuncGraph of Primitive " << prim->ToString();
  437. auto io = NewValueNode(kIOMonad);
  438. io->set_abstract(kIOMonad->ToAbstract());
  439. AppendMonadOutput(bprop_fg, io);
  440. }
  441. }
  442. std::vector<NodeDebugInfoPtr> GeneratePrimalDebugInfo(const ValueNodePtr &value_node,
  443. const pipeline::ResourceBasePtr &resources) {
  444. std::vector<NodeDebugInfoPtr> primal_debug_infos;
  445. if (resources != nullptr) {
  446. auto manager = resources->manager();
  447. auto &users = manager->node_users()[value_node];
  448. for (auto user_iter = users.begin(); user_iter != users.end(); ++user_iter) {
  449. primal_debug_infos.push_back(user_iter->first->debug_info());
  450. }
  451. }
  452. return primal_debug_infos;
  453. }
  454. FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
  455. const pipeline::ResourceBasePtr &resources) {
  456. if (!IsValueNode<Primitive>(value_node)) {
  457. MS_LOG(EXCEPTION) << "Primitive node is not valid.";
  458. }
  459. auto prim = GetValueNode<PrimitivePtr>(value_node);
  460. if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
  461. auto fprop = GetFprop(prim);
  462. fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
  463. return fprop;
  464. } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
  465. return nullptr;
  466. } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
  467. return nullptr;
  468. }
  469. FuncGraphPtr bprop_fg = nullptr;
  470. if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
  471. if (MsContext::GetInstance()->get_param<int>(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) {
  472. MS_LOG(EXCEPTION)
  473. << "The Primitive 'HookBackward' is not supported in graph mode, which is only supported in pynative mode.\n"
  474. << trace::GetDebugInfo(cnode->debug_info());
  475. }
  476. bprop_fg = BpropCut(value_node, resources);
  477. } else {
  478. auto iter = bprop_registry_.find(prim);
  479. if (iter != bprop_registry_.end()) {
  480. bprop_fg = iter->second;
  481. }
  482. if (bprop_fg == nullptr) {
  483. bprop_fg = GetBprop(prim, resources);
  484. if (bprop_fg != nullptr) {
  485. // Set bprop_g graph cache
  486. bprop_registry_[prim] = bprop_fg;
  487. } else {
  488. bprop_fg = FakeBprop(value_node, resources);
  489. }
  490. }
  491. }
  492. AdjustForAutoMonad(prim, bprop_fg);
  493. mindspore::HashMap<std::string, ValuePtr> primal_attrs;
  494. std::vector<NodeDebugInfoPtr> primal_debug_infos = GeneratePrimalDebugInfo(value_node, resources);
  495. if (cnode != nullptr) {
  496. primal_attrs = cnode->primal_attrs();
  497. const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
  498. primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
  499. }
  500. auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos);
  501. if (expanded_fg == nullptr) {
  502. MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
  503. << " prim bprop function to J expanded func graph. NodeInfo: "
  504. << trace::GetDebugInfo(bprop_fg->debug_info());
  505. }
  506. if (lift_fv_before_grad && IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
  507. // Inline fprop_switch before renormalize;
  508. expanded_fg->set_flag(FUNC_GRAPH_FLAG_FORCE_INLINE, true);
  509. MS_LOG(DEBUG) << "set force_inline for fg: " << expanded_fg->ToString();
  510. }
  511. return expanded_fg;
  512. }
  513. AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
  514. // current_primal_fg may have extra parameters like u_monad, io_monad
  515. std::vector<AnfNodePtr> extra_args;
  516. // caller had checked size() - 2 is greater than 0.
  517. auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
  518. if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) {
  519. auto current_primal_fg_param_size = current_primal_fg->parameters().size();
  520. MS_LOG(DEBUG) << "Current Primal FuncGraph may have extra parameters(U or IO monad) which bprop don't define, so "
  521. "Insert it. Extra parameters size: "
  522. << current_primal_fg_param_size - bprop_fg_param_size;
  523. for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) {
  524. const auto &primal_node = current_primal_fg->parameters()[i];
  525. AnfNodePtr extra_node;
  526. // Simplify zeros_like(primal_node) to U or IO, so extra_node in bprop_fg will not refer to primal_node
  527. // as a free variable of primal_graph.
  528. // Notes: if the implementation of zeros_like changes, here too.
  529. if (HasAbstractUMonad(primal_node)) {
  530. extra_node = NewValueNode(kUMonad);
  531. } else if (HasAbstractIOMonad(primal_node)) {
  532. extra_node = NewValueNode(kIOMonad);
  533. } else {
  534. MS_EXCEPTION(TypeError)
  535. << "The params of function 'bprop' of Primitive or Cell requires the forward inputs as well "
  536. "as the 'out' and 'dout'.\n"
  537. << trace::GetDebugInfo(bprop_fg->debug_info());
  538. }
  539. extra_args.push_back(extra_node);
  540. MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString();
  541. }
  542. }
  543. // bprop_fg has been checked in caller
  544. if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
  545. // Set bprop output as (env, dx, dy, dz, ...)
  546. auto cbprop = bprop_fg->output()->cast<CNodePtr>();
  547. auto &inputs = cbprop->inputs();
  548. std::vector<AnfNodePtr> args;
  549. args.push_back(NewValueNode(prim::kPrimMakeTuple));
  550. args.push_back(NewValueNode(newenv));
  551. (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
  552. if (!extra_args.empty()) {
  553. args.insert(args.end(), extra_args.cbegin(), extra_args.cend());
  554. }
  555. return NewCNode(args, bprop_fg);
  556. }
  557. // Set bprop output as (env, dx)
  558. std::string model_name("mindspore.ops.composite.multitype_ops.add_impl");
  559. std::string python_ops("_tuple_add");
  560. auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg);
  561. auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
  562. if (!extra_args.empty()) {
  563. extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple));
  564. auto extra_tuple = NewCNode(extra_args, bprop_fg);
  565. auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg);
  566. return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg);
  567. }
  568. return NewCNode({tuple_add_ops, tuple_env, bprop_fg->output()}, bprop_fg);
  569. }
  570. static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
  571. std::vector<AnfNodePtr> *const transf_args) {
  572. // bprop_fg has been checked in caller
  573. // transform except the last 2 parameters: out, dout.
  574. const size_t last_parameter_sizes = 2;
  575. auto bprop_fg_param_size = bprop_fg->parameters().size() - last_parameter_sizes;
  576. for (size_t i = 0; i < bprop_fg_param_size; ++i) {
  577. auto p = bprop_fg->parameters()[i];
  578. MS_EXCEPTION_IF_NULL(p);
  579. TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
  580. auto transf_p = outer->add_parameter();
  581. (void)mng->Replace(p, transf_p);
  582. transf_args->push_back(transf_p);
  583. }
  584. }
  585. void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
  586. const PrimitivePtr &primitive, const FuncGraphPtr &outer,
  587. std::vector<AnfNodePtr> *const transf_args) {
  588. MS_EXCEPTION_IF_NULL(mng);
  589. TransformNormalArgs(mng, bprop_fg, outer, transf_args);
  590. // Fprop_fg for Primitive with side effect should append extra U or IO monad parameter.
  591. auto effect_info = GetPrimEffectInfo(primitive);
  592. if (effect_info.memory) {
  593. MS_LOG(DEBUG) << "Append U monad to Fprop FuncGraph for Primitive " << primitive->ToString();
  594. auto transf_p = outer->add_parameter();
  595. transf_args->push_back(transf_p);
  596. }
  597. if (effect_info.io) {
  598. MS_LOG(DEBUG) << "Append IO monad to Fprop FuncGraph for Primitive " << primitive->ToString();
  599. auto transf_p = outer->add_parameter();
  600. transf_args->push_back(transf_p);
  601. }
  602. }
  603. template <typename T>
  604. void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
  605. const T &current_primal_fg, const FuncGraphPtr &outer,
  606. std::vector<AnfNodePtr> *const transf_args) {
  607. MS_EXCEPTION_IF_NULL(mng);
  608. TransformNormalArgs(mng, bprop_fg, outer, transf_args);
  609. constexpr size_t need_filter_size = 2;
  610. auto bprop_fg_param_size = bprop_fg->parameters().size() - need_filter_size;
  611. // current_primal_fg may have extra parameters after AutoMonad
  612. const auto &current_primal_fg_params = current_primal_fg->parameters();
  613. if (bprop_fg_param_size < current_primal_fg_params.size()) {
  614. for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) {
  615. auto p = current_primal_fg_params[i];
  616. MS_EXCEPTION_IF_NULL(p);
  617. // extra parameters should be Monad.
  618. if (!HasAbstractMonad(p)) {
  619. continue;
  620. }
  621. MS_LOG(DEBUG) << "Function " << current_primal_fg->ToString()
  622. << ", has extra monad parameter: " << p->DebugString()
  623. << ", abstract: " << p->abstract()->ToString();
  624. TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
  625. auto transf_p = outer->add_parameter();
  626. // See also Notes on extra_node of BuildOutput.
  627. // Notes: No need to replace p with transf_p as the only use of p is here.
  628. // If extra_node in bprop_fg use p as free variable, a replacement of p is required here.
  629. // This replacement will make the usage of p in current_primal_fg got replaced with transf_p
  630. // of outer. outer will be released after it is being cloned to fprop_fg, so the func_graph_
  631. // in transf_p will be nullptr.
  632. // So the RULE is DONT tamper the current_primal_fg;
  633. transf_args->push_back(transf_p);
  634. }
  635. }
  636. if (transf_args->size() != current_primal_fg_params.size()) {
  637. MS_EXCEPTION(TypeError) << "Function " << current_primal_fg->ToString()
  638. << ", The number of parameter of this primal function is "
  639. << current_primal_fg_params.size() << ", but the number of parameters of bprop is "
  640. << bprop_fg_param_size;
  641. }
  642. }
  643. void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
  644. auto context = MsContext::GetInstance();
  645. MS_EXCEPTION_IF_NULL(context);
  646. bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
  647. // Skip checking if check_bprop not set
  648. if (!check_bprop_flag) {
  649. return;
  650. }
  651. // bprop_fg has been checked in caller
  652. auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
  653. MS_EXCEPTION_IF_NULL(check_bprop_class);
  654. auto check_bprop =
  655. bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
  656. std::vector<AnfNodePtr> inputs;
  657. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  658. constexpr int primitive_size = 1;
  659. constexpr int brprop_offset_size = 2;
  660. (void)inputs.insert(inputs.begin() + primitive_size, bprop_fg->parameters().begin(),
  661. bprop_fg->parameters().end() - brprop_offset_size);
  662. AnfNodePtr params = bprop_fg->NewCNode(inputs);
  663. inputs.clear();
  664. inputs.push_back(check_bprop);
  665. inputs.push_back(bprop_fg->output());
  666. inputs.push_back(params);
  667. AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
  668. bprop_fg->set_output(bprop_out);
  669. }
  670. FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
  671. MS_EXCEPTION_IF_NULL(bprop_fg);
  672. // primal_fg is FuncGraph just after convert. Refer ConvertCellObjToFuncGraph.
  673. // current_primal_fg is specalized and AutoMoaded primal_fg;
  674. auto primal_fg = bprop_fg->transforms().find("primal")->second.func_graph();
  675. auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr, {}, {});
  676. if (expanded_fg == nullptr) {
  677. MS_LOG(EXCEPTION) << "Failed convert " << primal_fg->ToString()
  678. << " Cell bprop function to K expanded func graph. NodeInfo: "
  679. << trace::GetDebugInfo(primal_fg->debug_info());
  680. }
  681. return expanded_fg;
  682. }
  683. FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  684. auto prim = GetValueNode<PrimitivePtr>(value_node);
  685. MS_EXCEPTION_IF_NULL(prim);
  686. auto &node_users = resources->manager()->node_users();
  687. auto &users = node_users[value_node];
  688. auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
  689. return IsPrimitiveCNode(user.first, prim);
  690. });
  691. if (cnode == users.end()) {
  692. MS_LOG(EXCEPTION) << "Fail to find cnode.";
  693. }
  694. auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
  695. auto func_graph = std::make_shared<FuncGraph>();
  696. std::vector<AnfNodePtr> outputs;
  697. auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
  698. bprop_cut->CopyHookFunction(prim);
  699. auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
  700. if (cell_id != "") {
  701. (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
  702. (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
  703. }
  704. outputs.push_back(NewValueNode(bprop_cut));
  705. for (size_t i = 0; i < inputs_num; ++i) {
  706. auto param = func_graph->add_parameter();
  707. outputs.push_back(param);
  708. }
  709. auto p1 = func_graph->add_parameter();
  710. auto p2 = func_graph->add_parameter();
  711. outputs.push_back(p1);
  712. outputs.push_back(p2);
  713. func_graph->set_output(func_graph->NewCNode(outputs));
  714. return func_graph;
  715. }
  716. FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  717. auto prim = value_node->value()->cast<PrimitivePtr>();
  718. MS_EXCEPTION_IF_NULL(prim);
  719. auto &node_users = resources->manager()->node_users();
  720. auto &users = node_users[value_node];
  721. auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
  722. return IsPrimitiveCNode(user.first, prim);
  723. });
  724. if (cnode == users.end()) {
  725. MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
  726. }
  727. auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
  728. auto effect_info = GetPrimEffectInfo(prim);
  729. // Don't add U or IO monad parameters as it will be added later.
  730. size_t monad_params_size = 0;
  731. if (effect_info.memory) {
  732. monad_params_size++;
  733. }
  734. if (effect_info.io) {
  735. monad_params_size++;
  736. }
  737. if (inputs_num < monad_params_size) {
  738. MS_LOG(EXCEPTION) << "Arguments number should be greater than or equal to " << monad_params_size
  739. << ", but the CNode is: " << cnode->first->DebugString();
  740. }
  741. inputs_num -= monad_params_size;
  742. auto func_graph = std::make_shared<FuncGraph>();
  743. std::vector<AnfNodePtr> outputs;
  744. outputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  745. auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
  746. (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
  747. for (size_t i = 0; i < inputs_num; ++i) {
  748. // Mock params for inputs
  749. auto param = func_graph->add_parameter();
  750. // Mock derivatives for each inputs
  751. outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param}));
  752. }
  753. // mock params for out and dout
  754. (void)func_graph->add_parameter();
  755. (void)func_graph->add_parameter();
  756. func_graph->set_output(func_graph->NewCNode(outputs));
  757. return func_graph;
  758. }
  759. } // namespace ad
  760. } // namespace mindspore