You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

composite.cc 36 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "operator/composite/composite.h"
  19. #include <algorithm>
  20. #include <utility>
  21. #include <sstream>
  22. #include "ir/anf.h"
  23. #include "ir/func_graph.h"
  24. #include "pipeline/static_analysis/abstract_value.h"
  25. #include "pipeline/static_analysis/abstract_function.h"
  26. #include "pipeline/static_analysis/dshape.h"
  27. #include "pipeline/static_analysis/param_validator.h"
  28. #include "operator/cc_implementations.h"
  29. #include "optimizer/opt.h"
  30. #include "utils/symbolic.h"
  31. #include "pybind_api/api_register.h"
  32. #include "./common.h"
  33. #include "ir/signature.h"
  34. #include "debug/trace.h"
  35. namespace mindspore {
  36. // namespace to support composite operators definition
  37. namespace prim {
  38. using AbstractTensor = mindspore::abstract::AbstractTensor;
  39. using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
  40. using mindspore::abstract::AbstractAttribute;
  41. using mindspore::abstract::AbstractBase;
  42. using mindspore::abstract::AbstractClass;
  43. using mindspore::abstract::AbstractDictionary;
  44. using mindspore::abstract::AbstractDictionaryPtr;
  45. using mindspore::abstract::AbstractEllipsis;
  46. using mindspore::abstract::AbstractEllipsisPtr;
  47. using mindspore::abstract::AbstractFunction;
  48. using mindspore::abstract::AbstractFunctionPtr;
  49. using mindspore::abstract::AbstractList;
  50. using mindspore::abstract::AbstractNone;
  51. using mindspore::abstract::AbstractScalar;
  52. using mindspore::abstract::AbstractSlice;
  53. using mindspore::abstract::AbstractTuple;
// Maps Python magic-method names to the corresponding scalar primitive, where one
// exists; __truediv__ and __floordiv__ have no direct scalar primitive here (nullptr).
ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul},
                            {"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod},
                            {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt},
                            {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe},
                            {"__ge__", kPrimScalarGe}};

// Shared Tail meta func graph instance (drops the first element of a tuple/list).
const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail");
  60. // copy from python API: reduce.
  61. // Apply a function of two arguments cumulatively to the items of a sequence,
  62. // from left to right, so as to reduce the sequence to a single value.For example,
  63. // reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5).
  64. AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) {
  65. std::shared_ptr<Any> ret;
  66. size_t size = list.size();
  67. if (size < 2) {
  68. MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2";
  69. }
  70. AnyPtrList input;
  71. input.push_back(list[0]);
  72. input.push_back(list[1]);
  73. ret = std::make_shared<Any>(func(input));
  74. for (size_t i = 2; i < size; ++i) {
  75. input.clear();
  76. input.push_back(ret);
  77. input.push_back(list[i]);
  78. ret = std::make_shared<Any>(func(input));
  79. }
  80. return ret;
  81. }
  82. AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector<AnfNodePtr> &list) {
  83. size_t size = list.size();
  84. if (size < 2) {
  85. MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2";
  86. }
  87. std::vector<AnfNodePtr> input;
  88. input.push_back(list[0]);
  89. input.push_back(list[1]);
  90. AnfNodePtr ret = func(input);
  91. for (size_t i = 2; i < size; ++i) {
  92. input.clear();
  93. input.push_back(ret);
  94. input.push_back(list[i]);
  95. ret = func(input);
  96. }
  97. return ret;
  98. }
// Default HyperMap instance exposed as a composite value.
ValuePtr kCompositeHyperMap = std::make_shared<HyperMap>();
// Shared constructor helper: derives the display name from the bound leaf function
// (if any) and declares the call signature `hypermap(func: read, *args: ref)`.
void HyperMap::Init() {
  if (fn_leaf_) {
    name_ = "hyper_map[" + fn_leaf_->name() + "]";
  }
  signatures_ =
    // def hypermap(func:read, *args:ref):
    std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                            {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
// Constructs a HyperMap bound to an optional leaf MultitypeFuncGraph. nonleaf_ lists
// the container types that are mapped element-wise rather than handed to the leaf.
HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
  : MetaFuncGraph("hyper_map"),
    fn_leaf_(fn_leaf),
    broadcast_(false),
    nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
  Init();
}
// Copy constructor: used by FullMake to build the recursion node without
// shared_from_this (avoids a reference cycle); preserves leaf fn and settings.
HyperMap::HyperMap(const HyperMap &h)
  : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
  Init();
}
  120. AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  121. const ArgsPairList &arg_map) {
  122. MS_EXCEPTION_IF_NULL(func_graph);
  123. std::vector<AnfNodePtr> inputs;
  124. if (fn_arg != nullptr) {
  125. inputs.push_back(fn_arg);
  126. } else {
  127. inputs.push_back(NewValueNode(fn_leaf_));
  128. }
  129. (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
  130. [](const std::pair<AnfNodePtr, Any> &item) { return item.first; });
  131. return func_graph->NewCNode(inputs);
  132. }
// Builds a new list whose i-th element is hyper_map(fn, arg1[i], arg2[i], ...),
// i.e. maps element-wise over List-typed arguments.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);
  std::size_t size = type->elements().size();
  // All list arguments must have the same number of elements.
  bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
    auto lhs = std::static_pointer_cast<List>(item.second);
    MS_EXCEPTION_IF_NULL(lhs);
    return lhs->elements().size() != size;
  });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "List in HyperMap should have same length";
  }
  // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  // hypermap and graph generated, it will cause memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeList));
  for (int i = 0; i < SizeToInt(size); ++i) {
    // Element i: recursive call hyper_map(fn, list_getitem(arg, i), ...).
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }
    (void)std::transform(
      arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
      [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
        return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
      });
    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}
  166. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
  167. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  168. MS_EXCEPTION_IF_NULL(func_graph);
  169. MS_EXCEPTION_IF_NULL(type);
  170. std::size_t size = type->elements().size();
  171. bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
  172. auto lhs = std::static_pointer_cast<Tuple>(item.second);
  173. MS_EXCEPTION_IF_NULL(lhs);
  174. return lhs->elements().size() != size;
  175. });
  176. if (is_not_same) {
  177. MS_LOG(EXCEPTION) << "tuple in HyperMap should have same length";
  178. }
  179. // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  180. // hypermap and graph generated, it will cause memory leak.
  181. auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  182. std::vector<AnfNodePtr> inputs;
  183. inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  184. for (int i = 0; i < SizeToInt(size); ++i) {
  185. std::vector<AnfNodePtr> inputs2;
  186. inputs2.push_back(fn_rec);
  187. if (fn_arg != nullptr) {
  188. inputs2.push_back(fn_arg);
  189. }
  190. (void)std::transform(
  191. arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
  192. return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
  193. });
  194. inputs.push_back(func_graph->NewCNode(inputs2));
  195. }
  196. return func_graph->NewCNode(inputs);
  197. }
// Builds a new record (class instance) whose attributes are produced by recursively
// hyper-mapping over the corresponding attributes of the arguments.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  inputs.push_back(NewValueNode(type));
  // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  // hypermap and graph generated, it will cause memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  std::size_t attrSize = type->GetAttributes().size();
  for (std::size_t i = 0; i < attrSize; ++i) {
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg) {
      inputs2.push_back(fn_arg);
    }
    // NOTE(review): getattr receives the argument position `j`, not the attribute
    // index `i` — unlike the List/Tuple overloads, which index by element position.
    // Looks suspicious; confirm against the intended semantics before relying on it.
    int j = 0;
    for (auto item : arg_map) {
      inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
      j++;
    }
    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}
// Dispatches on the argument types: if any argument is a nonleaf container type
// (List/Tuple/Class), recurse element-wise via the matching FullMake overload;
// otherwise apply the leaf function directly.
AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  bool found = false;
  TypeId id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> pair;
  // Find the first argument whose type is a nonleaf container.
  for (auto &item : arg_map) {
    pair = item;
    id = item.second->type_id();
    if (nonleaf_.count(id)) {
      found = true;
      break;
    }
  }
  if (found) {
    // In a nonleaf situation, all arguments must have the same generic.
    bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
      if (item.first != pair.first) {
        return item.second->type_id() != pair.second->type_id();
      }
      return false;
    });
    if (is_not_same) {
      // Build a per-argument type listing for the error message.
      std::ostringstream oss;
      oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
      int idx = 0;
      for (auto &item : arg_map) {
        oss << ++idx << ": " << item.second->ToString() << "\n";
      }
      MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
    }
  }
  switch (id) {
    case kObjectTypeList: {
      auto type = std::static_pointer_cast<List>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeTuple: {
      auto type = std::static_pointer_cast<Tuple>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeClass: {
      auto type = std::static_pointer_cast<Class>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    default:
      // No nonleaf argument: leaf application.
      return FullMake(pair.second, func_graph, fn_arg, arg_map);
  }
}
// If broadcasting is enabled and at least one argument is a Tensor, wraps every
// non-Tensor (scalar) argument in scalar_to_array so all arguments become tensors;
// otherwise returns the argument list unchanged.
ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
  TypePtr type_tensor = std::make_shared<TensorType>();
  bool flag = std::any_of(
    args_spec_list.begin(), args_spec_list.end(),
    [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
  if (flag && broadcast_) {
    ArgsPairList ret;
    for (auto &item : args_spec_list) {
      if (!IsSubType(item.second, type_tensor)) {
        // Promote the scalar to a tensor of its own element type.
        TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
        ret.push_back(
          std::make_pair(func_graph->NewCNode({NewValueNode(prim::kPrimScalarToArray), item.first}), type_tensor_ele));
      } else {
        ret.push_back(std::make_pair(item.first, item.second));
      }
    }
    return ret;
  }
  return args_spec_list;
}
// Specializes a hyper_map graph for the given argument types. When no leaf function
// was bound at construction, the first parameter of the graph is the mapped function.
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
  FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  ptrGraph->debug_info()->set_name("hyper_map");
  AnfNodePtr ptrFnArg = nullptr;
  std::size_t i = 0;
  ArgsPairList argmap;
  ArgsPairList argmap2;
  if (fn_leaf_ == nullptr) {
    // First argument is the function to map; skip it when pairing parameters/types.
    ptrFnArg = ptrGraph->add_parameter();
    i = 1;
  }
  std::size_t size = args_spec_list.size();
  for (; i < size; ++i) {
    argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
  }
  // Optionally broadcast scalar arguments to tensors before mapping.
  argmap2 = Harmonize(ptrGraph, argmap);
  ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2));
  return ptrGraph;
}
// Broadens every argument abstract (drops concrete values so one specialization
// covers all values of each type). Rejects a function argument that is a closure
// with free variables, which HyperMap does not support.
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  if (fn_leaf_ == nullptr) {
    MS_EXCEPTION_IF_NULL(args_spec_list[0]);
    // Assert that hypermap's function param does not contain free variables
    if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
      auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
      auto func_graph = graph_func->func_graph();
      if (func_graph->parent() != nullptr) {
        MS_LOG(EXCEPTION) << "HyperMap don't support Closure with free variable yet.";
      }
    }
  }
  AbstractBasePtrList broadened;
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         return arg->Broaden();
                       });
  return broadened;
}
// Expose HyperMap to Python as HyperMap_, constructible with or without a leaf function.
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
  (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
    .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
    .def(py::init<>());
}));
  338. FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) {
  339. MS_EXCEPTION_IF_NULL(a_tuple);
  340. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  341. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  342. ret->debug_info()->set_name("tail");
  343. AnfNodePtr ptrTup = ret->add_parameter();
  344. std::vector<AnfNodePtr> elems;
  345. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  346. int tuple_size = SizeToInt(a_tuple->size());
  347. for (int i = 1; i < tuple_size; ++i) {
  348. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)}));
  349. }
  350. ret->set_output(ret->NewCNode(elems));
  351. return ret;
  352. }
  353. FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) {
  354. MS_EXCEPTION_IF_NULL(a_list);
  355. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  356. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  357. ret->debug_info()->set_name("tail");
  358. AnfNodePtr ptrList = ret->add_parameter();
  359. std::vector<AnfNodePtr> elems;
  360. elems.push_back(NewValueNode(prim::kPrimMakeList));
  361. int list_size = SizeToInt(a_list->size());
  362. for (int i = 1; i < list_size; ++i) {
  363. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)}));
  364. }
  365. ret->set_output(ret->NewCNode(elems));
  366. return ret;
  367. }
  368. FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  369. if (args_spec_list.size() != 1) {
  370. MS_LOG(EXCEPTION) << "tail requires a non-empty tuple.";
  371. }
  372. AbstractBasePtr a = args_spec_list[0];
  373. abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(a);
  374. if (a_tuple != nullptr) {
  375. return GenerateTupleFuncGraph(a_tuple);
  376. }
  377. abstract::AbstractListPtr a_list = dyn_cast<AbstractList>(a);
  378. if (a_list != nullptr) {
  379. return GenerateListFuncGraph(a_list);
  380. }
  381. MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString();
  382. }
// Expose Tail to Python as Tail_, constructed with a name string.
REGISTER_PYBIND_DEFINE(
  Tail_, ([](const py::module *m) {
    (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
  }));
  387. FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  388. int tuple_size = SizeToInt(args_spec_list.size());
  389. std::ostringstream ss;
  390. ss << "▶make_tuple_" << tuple_size;
  391. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  392. fg->debug_info()->set_name(ss.str());
  393. std::vector<AnfNodePtr> params;
  394. params.push_back(NewValueNode(prim::kPrimMakeTuple));
  395. for (int i = 0; i < tuple_size; ++i) {
  396. params.push_back(fg->add_parameter());
  397. }
  398. // make fprob first result, maketuple's forward result.
  399. AnfNodePtr out = fg->NewCNode(params);
  400. // make fprob second result, maketuple's backward function.
  401. FuncGraphPtr b = std::make_shared<FuncGraph>();
  402. ss.clear();
  403. ss << "◀make_tuple_" << tuple_size;
  404. b->debug_info()->set_name(ss.str());
  405. AnfNodePtr dout = b->add_parameter();
  406. std::vector<AnfNodePtr> grads;
  407. grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  408. grads.push_back(NewValueNode(newenv));
  409. for (int i = 0; i < tuple_size; ++i) {
  410. grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
  411. }
  412. b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  413. b->set_output(b->NewCNode(grads));
  414. fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  415. fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  416. (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
  417. return fg;
  418. }
// get_all: return gradients w.r.t. all inputs; get_by_list: return gradients w.r.t.
// a weight list (adds a `weight_list` ref parameter to the signature); sens_param:
// the caller supplies the output sensitivity explicitly.
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
  : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
  if (get_by_list) {
    signatures_ =
      // def grad(func:read, weight_list:ref):
      std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
  }
}
// Builds a wrapper graph that calls `node` (the J-transformed function) with mirrored
// parameters, splits the resulting (forward_out, bprop) pair, and delegates output
// selection to doGetGrad.
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
                                    const std::vector<AnfNodePtr> &params_list, const std::vector<AnfNodePtr> &args,
                                    bool applyJ) {
  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  auto weights_node = weights;
  // NOTE(review): when no weights node is given but `args` is non-empty, the args are
  // combined into a CNode and used as the weights node — confirm this is intended.
  if (weights == nullptr && !args.empty()) {
    weights_node = ret->NewCNode(args);
  }
  ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
  ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
  std::vector<AnfNodePtr> inputs;
  if (applyJ) {
    // Optionally apply the J transform here (the caller may have applied it already).
    inputs.push_back(opsJ);
    inputs.push_back(node);
    node = ret->NewCNode(inputs);
  }
  // Mirror the original function's parameters on the wrapper graph.
  std::vector<AnfNodePtr> params;
  for (size_t i = 0; i < params_list.size(); ++i) {
    params.push_back(ret->add_parameter());
  }
  inputs.clear();
  inputs.push_back(node);
  (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs));
  AnfNodePtr cnode = ret->NewCNode(inputs);
  // cnode evaluates to (forward_result, bprop); extract both halves.
  inputs.clear();
  inputs.push_back(opsTupleItem);
  inputs.push_back(cnode);
  inputs.push_back(NewValueNode(0));
  auto out = ret->NewCNode(inputs);
  inputs.clear();
  inputs.push_back(opsTupleItem);
  inputs.push_back(cnode);
  inputs.push_back(NewValueNode(1));
  AnfNodePtr ptrBprop = ret->NewCNode(inputs);
  doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem);
  return ret;
}
// Wires the bprop call and selects which gradients form the graph's output, according
// to the get_all_ / get_by_list_ / sens_param_ configuration.
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
                              ValueNodePtr opsTupleItem) {
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtr ptrBPropArg = nullptr;
  if (sens_param_) {
    // Sensitivity is supplied by the caller as an extra graph parameter.
    ptrBPropArg = func_graph->add_parameter();
  } else {
    // Default sensitivity: ones_like(forward output).
    auto ones_like = prim::GetPythonOps("ones_like");
    ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out});
  }
  // Invoke the backward function with the sensitivity.
  AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg});
  CNodePtr fv_bprop = nullptr;
  if (get_by_list_) {
    // python code: grads = hyper_map(F.partial(env_get, env), weights)
    AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)});
    AnfNodePtr partial_env_get =
      func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
    MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
    fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights});
  }
  CNodePtr inputs_bprop = nullptr;
  if (get_all_) {
    // Tail drops the leading env entry, leaving gradients w.r.t. every input.
    inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp});
  }
  // Gradients wrt inputs and parameters
  if (fv_bprop != nullptr && inputs_bprop != nullptr) {
    func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
    return;
  }
  // Gradients wrt parameters
  if (fv_bprop != nullptr) {
    func_graph->set_output(fv_bprop);
    return;
  }
  // Gradients wrt inputs
  if (inputs_bprop != nullptr) {
    func_graph->set_output(inputs_bprop);
    return;
  }
  // Gradients wrt first input.
  // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
  func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)}));
}
// Generate the graph.
// Validates that arg0 is a FuncGraph closure, then builds the "grad{n}" wrapper graph:
// applies the J transform to the function and delegates gradient wiring to GetGrad.
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.size() < 1) {
    MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is "
                      << args_spec_list.size() << ".";
  }
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "GradOperation arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
  }
  // Waiting for implementation.
  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);
  FuncGraphPtr ptrGraph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(ptrGraph);
  // Attach trace info so errors in the generated graph point back to the primal graph.
  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
  FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>();
  TraceManager::EndTrace();
  auto nparam = ptrGraph->parameters().size();
  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  dfBuilder->debug_info()->set_name(ss.str());
  ParameterPtr param_graph = dfBuilder->add_parameter();
  AnfNodePtr weights = nullptr;
  if (get_by_list_) {
    weights = dfBuilder->add_parameter();
  }
  // jf = J(fn): transforms fn into a function that returns (out, bprop).
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimJ));
  inputs.push_back(param_graph);
  auto jf = dfBuilder->NewCNode(inputs);
  // df is checked in GetGrad
  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
  auto df = GetGrad(jf, weights, ptrGraph->parameters());
  TraceManager::EndTrace();
  dfBuilder->set_output(NewValueNode(df));
  return dfBuilder;
}
// Expose GradOperation to Python as GradOperation_, with name-only and full constructors.
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
  (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
    *m, "GradOperation_")
    .def(py::init<std::string &>(), py::arg("fn"))
    .def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
         py::arg("get_by_list"), py::arg("sens_param"));
}));
// Generate the ListMap func graph.
// Builds a graph implementing list_map(fn, list1, list2, ...): iterate all lists in
// lock-step, apply fn to the current elements, and accumulate results in a new list.
// The in-code comments mark this path as currently unused; see NOTE(review) below.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  size_t args_num = args_spec_list.size();
  // args: fn, list1, list2, ...
  if (args_num < 2) {
    MS_LOG(EXCEPTION) << "list_map takes at least two arguments";
  }
  for (size_t i = 1; i < args_num; ++i) {
    // NOTE(review): typeid on a shared_ptr compares the pointer type, not the pointee,
    // so this condition appears unable to ever fire as written — confirm.
    if (typeid(args_spec_list[i]) != typeid(AbstractBase)) {
      // The function currently not be use
      MS_LOG(EXCEPTION) << "list_map requires lists, not {t}'";
    }
  }
  FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
  fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg_ptr->debug_info()->set_name("list_map");
  AnfNodePtr fn = fg_ptr->add_parameter();
  std::vector<AnfNodePtr> lists;
  for (size_t i = 1; i < args_num; ++i) {
    lists.push_back(fg_ptr->add_parameter());
  }
  // Create an iterator for each input list.
  std::vector<AnfNodePtr> iters;
  (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("list_iter")), item});
  });
  // next(iter) yields a (value, iterator) pair for each list.
  std::vector<AnfNodePtr> nexts;
  (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("next")), item});
  });
  std::vector<AnfNodePtr> values;
  // NOTE(review): this tuple_getitem has no index operand, unlike the transform just
  // below that passes NewValueNode(1) — verify before enabling this code path.
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item});
  });
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)});
  });
  // First application: resl = [fn(v1, v2, ...)].
  (void)values.insert(values.begin(), fn);
  AnfNodePtr cnode_graph = fg_ptr->NewCNode(values);
  AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimMakeList), cnode_graph});
  // Build the loop body and condition graphs, then call the condition with
  // (fn, resl, iters...) as the loop entry.
  FuncGraphPtr fgnext_ptr = std::make_shared<FuncGraph>();
  fgnext_ptr->debug_info()->set_name("body");
  FuncGraphPtr fgcond_ptr = std::make_shared<FuncGraph>();
  fgcond_ptr->debug_info()->set_name("cond");
  MakeCond(lists, fgnext_ptr, fgcond_ptr);
  MakeNext(lists, fgcond_ptr, fgnext_ptr);
  CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl});
  auto inputs = output_cnode->inputs();
  (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  output_cnode->set_inputs(inputs);
  fg_ptr->set_output(output_cnode);
  return fg_ptr;
}
// Builds the loop-condition graph: parameters are (fn, resl, iters...). If every
// iterator still has a next element, tail-call the body graph; otherwise return resl.
void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
                       const FuncGraphPtr &fg_ptr) {
  MS_EXCEPTION_IF_NULL(fg_ptr);
  AnfNodePtr fn = fg_ptr->add_parameter();
  AnfNodePtr resl = fg_ptr->add_parameter();
  // One iterator parameter per input list.
  std::vector<AnfNodePtr> iters;
  (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
                       [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
  std::vector<AnfNodePtr> hasnexts;
  (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("hasnext")), item});
  });
  // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
  // True branch: call the body graph with (fn, resl, iters...).
  FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
  fgtrue_ptr->debug_info()->set_name("ftrue");
  fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl});
  auto inputs = fgtrue_output_cnode->inputs();
  (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  fgtrue_output_cnode->set_inputs(inputs);
  fgtrue_ptr->set_output(fgtrue_output_cnode);
  // False branch: loop is done, return the accumulated result.
  FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
  fgfalse_ptr->debug_info()->set_name("ffalse");
  fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fgfalse_ptr->set_output(resl);
  // NOTE(review): the switch condition is the literal string "cond" (a placeholder,
  // not the reduced hasnexts value), and this second set_output overwrites the output
  // assigned to fgtrue_ptr above with a node built on fg_ptr. The surrounding comments
  // say list_map is currently unused — confirm intent before enabling this path.
  AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
                                              NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)});
  fgtrue_ptr->set_output(output_cnode);
}
  637. void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
  638. const FuncGraphPtr &fg_ptr) {
  639. MS_EXCEPTION_IF_NULL(fg_ptr);
  640. AnfNodePtr fn = fg_ptr->add_parameter();
  641. std::vector<AnfNodePtr> iters;
  642. (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
  643. [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
  644. std::vector<AnfNodePtr> nexts;
  645. (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
  646. return fg_ptr->NewCNode({NewValueNode(std::string("next")), item});
  647. });
  648. std::vector<AnfNodePtr> values;
  649. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
  650. return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, nullptr});
  651. });
  652. iters.clear();
  653. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
  654. return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)});
  655. });
  656. (void)values.insert(values.begin(), fn);
  657. AnfNodePtr cnode_graph = fg_ptr->NewCNode(values);
  658. AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimListAppend), cnode_graph});
  659. CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl});
  660. auto inputs = output_cnode->inputs();
  661. (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  662. output_cnode->set_inputs(inputs);
  663. fg_ptr->set_output(output_cnode);
  664. }
  665. FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  666. // args: tuple1, tuple2
  667. abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
  668. AbstractBasePtr abs_a = args_spec_list[0];
  669. AbstractBasePtr abs_b = args_spec_list[1];
  670. abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
  671. abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
  672. if (a_tuple == nullptr || b_tuple == nullptr) {
  673. MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", "
  674. << args_spec_list[1]->ToString();
  675. }
  676. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  677. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  678. AnfNodePtr p_tup_a = ret->add_parameter();
  679. AnfNodePtr p_tup_b = ret->add_parameter();
  680. std::vector<AnfNodePtr> elems;
  681. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  682. int tuple_size = SizeToInt(a_tuple->size());
  683. for (int i = 0; i < tuple_size; ++i) {
  684. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)}));
  685. }
  686. tuple_size = SizeToInt(b_tuple->size());
  687. for (int i = 0; i < tuple_size; ++i) {
  688. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)}));
  689. }
  690. ret->set_output(ret->NewCNode(elems));
  691. return ret;
  692. }
  693. int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
  694. MS_EXCEPTION_IF_NULL(scalar);
  695. return GetValue<int>(scalar->BuildValue());
  696. }
  697. bool CheckIndexInRange(int index, int min, int max) { return (index >= min && index <= max); }
  698. int GetPositiveIndex(int index, int length) {
  699. if (index < 0) {
  700. index += length;
  701. }
  702. return index;
  703. }
  704. int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) {
  705. MS_EXCEPTION_IF_NULL(member);
  706. if (member->isa<AbstractScalar>()) {
  707. return GetArgScalarValue(dyn_cast<AbstractScalar>(member), member_name);
  708. }
  709. if (member->isa<AbstractNone>()) {
  710. return default_value;
  711. }
  712. MS_LOG(EXCEPTION) << member_name << " should be a AbstractScalar or AbstractNone, but got " << member->ToString();
  713. }
  714. void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index,
  715. int *stop_index, int *step_value) {
  716. MS_EXCEPTION_IF_NULL(tuple);
  717. MS_EXCEPTION_IF_NULL(slice);
  718. MS_EXCEPTION_IF_NULL(start_index);
  719. MS_EXCEPTION_IF_NULL(stop_index);
  720. MS_EXCEPTION_IF_NULL(step_value);
  721. const std::string start_name("Slice start index");
  722. const std::string stop_name("Slice stop index");
  723. const std::string step_name("Slice step value");
  724. int tuple_size = SizeToInt(tuple->size());
  725. int start_default = 0;
  726. int stop_default = tuple_size;
  727. int step_default = 1;
  728. *step_value = CheckSliceMember(slice->step(), step_default, step_name);
  729. if (*step_value == 0) {
  730. MS_LOG(EXCEPTION) << "TupleSlice require the step value could not be 0, but got 0.";
  731. }
  732. if (*step_value < 0) {
  733. start_default = tuple_size - 1;
  734. stop_default = -1;
  735. }
  736. *start_index = CheckSliceMember(slice->start(), start_default, start_name);
  737. *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
  738. if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) ||
  739. !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) {
  740. MS_LOG(EXCEPTION) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index
  741. << " out of range, tuple size " << tuple_size << ".";
  742. }
  743. *start_index = GetPositiveIndex(*start_index, tuple_size);
  744. if (!slice->stop()->isa<AbstractNone>()) {
  745. *stop_index = GetPositiveIndex(*stop_index, tuple_size);
  746. }
  747. }
  748. FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  749. // slice a tuple
  750. // args: tuple, start index, end index, step
  751. const std::string op_name("TupleSlice");
  752. abstract::CheckArgsSize(op_name, args_spec_list, 2);
  753. AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  754. AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);
  755. int start_index;
  756. int stop_index;
  757. int step_value;
  758. GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
  759. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  760. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  761. AnfNodePtr p_tuple = ret->add_parameter();
  762. (void)ret->add_parameter();
  763. std::vector<AnfNodePtr> elems;
  764. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  765. if (step_value > 0) {
  766. for (int index = start_index; index < stop_index; index = index + step_value) {
  767. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
  768. }
  769. } else {
  770. for (int index = start_index; index > stop_index; index = index + step_value) {
  771. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
  772. }
  773. }
  774. ret->set_output(ret->NewCNode(elems));
  775. return ret;
  776. }
  777. FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  778. // select indexed item
  779. // args: tuple of items, index
  780. const std::string op_name = std::string("TupleGetItemTensor");
  781. abstract::CheckArgsSize(op_name, args_spec_list, 2);
  782. AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  783. AbstractBasePtrList branches = branches_abs->elements();
  784. if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
  785. FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
  786. ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  787. AnfNodePtr functions = ret_graph->add_parameter();
  788. auto index = ret_graph->add_parameter();
  789. ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
  790. return ret_graph;
  791. }
  792. MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << ".";
  793. }
  794. REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
  795. (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
  796. .def(py::init<std::string &>());
  797. }));
  798. REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
  799. (void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
  800. .def(py::init<std::string &>());
  801. }));
  802. REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
  803. (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
  804. *m, "TupleGetItemTensor_")
  805. .def(py::init<std::string &>());
  806. }));
  807. } // namespace prim
  808. } // namespace mindspore