You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

composite.cc 38 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "frontend/operator/composite/composite.h"
  19. #include <algorithm>
  20. #include <utility>
  21. #include <sstream>
  22. #include "ir/anf.h"
  23. #include "ir/func_graph.h"
  24. #include "abstract/abstract_value.h"
  25. #include "abstract/abstract_function.h"
  26. #include "abstract/dshape.h"
  27. #include "abstract/param_validator.h"
  28. #include "frontend/operator/cc_implementations.h"
  29. #include "frontend/optimizer/opt.h"
  30. #include "utils/symbolic.h"
  31. #include "pybind_api/api_register.h"
  32. #include "ir/signature.h"
  33. #include "debug/trace.h"
  34. namespace mindspore {
  35. // namespace to support composite operators definition
  36. namespace prim {
  37. using AbstractTensor = mindspore::abstract::AbstractTensor;
  38. using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
  39. using mindspore::abstract::AbstractAttribute;
  40. using mindspore::abstract::AbstractBase;
  41. using mindspore::abstract::AbstractClass;
  42. using mindspore::abstract::AbstractDictionary;
  43. using mindspore::abstract::AbstractDictionaryPtr;
  44. using mindspore::abstract::AbstractEllipsis;
  45. using mindspore::abstract::AbstractEllipsisPtr;
  46. using mindspore::abstract::AbstractFunction;
  47. using mindspore::abstract::AbstractFunctionPtr;
  48. using mindspore::abstract::AbstractList;
  49. using mindspore::abstract::AbstractNone;
  50. using mindspore::abstract::AbstractScalar;
  51. using mindspore::abstract::AbstractSlice;
  52. using mindspore::abstract::AbstractTuple;
// Maps Python magic-method names onto their scalar primitive implementations.
// Entries mapped to nullptr ("__truediv__", "__floordiv__") have no direct
// scalar primitive registered here and are resolved elsewhere.
ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul},
                            {"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod},
                            {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt},
                            {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe},
                            {"__ge__", kPrimScalarGe}};
// Shared Tail instance: generates graphs that drop the first element of a
// tuple/list (see Tail::GenerateTupleFuncGraph, which copies from index 1 on).
const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail");
  59. // copy from python API: reduce.
  60. // Apply a function of two arguments cumulatively to the items of a sequence,
  61. // from left to right, so as to reduce the sequence to a single value.For example,
  62. // reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5).
  63. AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) {
  64. std::shared_ptr<Any> ret;
  65. size_t size = list.size();
  66. if (size < 2) {
  67. MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2";
  68. }
  69. AnyPtrList input;
  70. input.push_back(list[0]);
  71. input.push_back(list[1]);
  72. ret = std::make_shared<Any>(func(input));
  73. for (size_t i = 2; i < size; ++i) {
  74. input.clear();
  75. input.push_back(ret);
  76. input.push_back(list[i]);
  77. ret = std::make_shared<Any>(func(input));
  78. }
  79. return ret;
  80. }
  81. AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector<AnfNodePtr> &list) {
  82. size_t size = list.size();
  83. if (size < 2) {
  84. MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2";
  85. }
  86. std::vector<AnfNodePtr> input;
  87. input.push_back(list[0]);
  88. input.push_back(list[1]);
  89. AnfNodePtr ret = func(input);
  90. for (size_t i = 2; i < size; ++i) {
  91. input.clear();
  92. input.push_back(ret);
  93. input.push_back(list[i]);
  94. ret = func(input);
  95. }
  96. return ret;
  97. }
  98. ValuePtr kCompositeHyperMap = std::make_shared<HyperMap>();
// Set up the HyperMap's display name and its call signatures.
void HyperMap::Init() {
  if (fn_leaf_) {
    // When bound to a fixed leaf function, reflect it in the name.
    name_ = "hyper_map[" + fn_leaf_->name() + "]";
  }
  signatures_ =
    // def hypermap(func:read, *args:ref):
    std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                            {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
// Construct a HyperMap, optionally bound to a fixed leaf function fn_leaf.
// The nonleaf (structurally mapped) aggregate types are List, Tuple and Class.
HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
    : MetaFuncGraph("hyper_map"),
      fn_leaf_(fn_leaf),
      broadcast_(false),
      nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
  Init();
}
// Copy constructor: duplicates configuration and re-derives name/signatures via Init().
HyperMap::HyperMap(const HyperMap &h)
    : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
  Init();
}
  119. AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  120. const ArgsPairList &arg_map) {
  121. MS_EXCEPTION_IF_NULL(func_graph);
  122. std::vector<AnfNodePtr> inputs;
  123. if (fn_arg != nullptr) {
  124. inputs.push_back(fn_arg);
  125. } else {
  126. inputs.push_back(NewValueNode(fn_leaf_));
  127. }
  128. (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
  129. [](const std::pair<AnfNodePtr, Any> &item) { return item.first; });
  130. return func_graph->NewCNode(inputs);
  131. }
  132. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
  133. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  134. MS_EXCEPTION_IF_NULL(func_graph);
  135. MS_EXCEPTION_IF_NULL(type);
  136. std::size_t size = type->elements().size();
  137. bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
  138. auto lhs = std::static_pointer_cast<List>(item.second);
  139. MS_EXCEPTION_IF_NULL(lhs);
  140. return lhs->elements().size() != size;
  141. });
  142. if (is_not_same) {
  143. MS_LOG(EXCEPTION) << "List in HyperMap should have same length";
  144. }
  145. // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  146. // hypermap and graph generated, it will cause memory leak.
  147. auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  148. std::vector<AnfNodePtr> inputs;
  149. inputs.push_back(NewValueNode(prim::kPrimMakeList));
  150. for (int i = 0; i < SizeToInt(size); ++i) {
  151. std::vector<AnfNodePtr> inputs2;
  152. inputs2.push_back(fn_rec);
  153. if (fn_arg != nullptr) {
  154. inputs2.push_back(fn_arg);
  155. }
  156. (void)std::transform(
  157. arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
  158. [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
  159. return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
  160. });
  161. inputs.push_back(func_graph->NewCNode(inputs2));
  162. }
  163. return func_graph->NewCNode(inputs);
  164. }
  165. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
  166. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  167. MS_EXCEPTION_IF_NULL(func_graph);
  168. MS_EXCEPTION_IF_NULL(type);
  169. std::size_t size = type->elements().size();
  170. bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
  171. auto lhs = std::static_pointer_cast<Tuple>(item.second);
  172. MS_EXCEPTION_IF_NULL(lhs);
  173. return lhs->elements().size() != size;
  174. });
  175. if (is_not_same) {
  176. MS_LOG(EXCEPTION) << "tuple in HyperMap should have same length";
  177. }
  178. // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  179. // hypermap and graph generated, it will cause memory leak.
  180. auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  181. std::vector<AnfNodePtr> inputs;
  182. inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  183. for (int i = 0; i < SizeToInt(size); ++i) {
  184. std::vector<AnfNodePtr> inputs2;
  185. inputs2.push_back(fn_rec);
  186. if (fn_arg != nullptr) {
  187. inputs2.push_back(fn_arg);
  188. }
  189. (void)std::transform(
  190. arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
  191. return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
  192. });
  193. inputs.push_back(func_graph->NewCNode(inputs2));
  194. }
  195. return func_graph->NewCNode(inputs);
  196. }
  197. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
  198. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  199. MS_EXCEPTION_IF_NULL(type);
  200. MS_EXCEPTION_IF_NULL(func_graph);
  201. std::vector<AnfNodePtr> inputs;
  202. inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  203. inputs.push_back(NewValueNode(type));
  204. // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  205. // hypermap and graph generated, it will cause memory leak.
  206. auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  207. std::size_t attrSize = type->GetAttributes().size();
  208. for (std::size_t i = 0; i < attrSize; ++i) {
  209. std::vector<AnfNodePtr> inputs2;
  210. inputs2.push_back(fn_rec);
  211. if (fn_arg) {
  212. inputs2.push_back(fn_arg);
  213. }
  214. int j = 0;
  215. for (auto item : arg_map) {
  216. inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
  217. j++;
  218. }
  219. inputs.push_back(func_graph->NewCNode(inputs2));
  220. }
  221. return func_graph->NewCNode(inputs);
  222. }
// Dispatch on the first nonleaf argument type: recurse structurally into
// List/Tuple/Class arguments, or apply the leaf function directly otherwise.
AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  bool found = false;
  TypeId id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> pair;
  // Find the first argument whose type is a nonleaf aggregate (List/Tuple/Class).
  for (auto &item : arg_map) {
    pair = item;
    id = item.second->type_id();
    if (nonleaf_.count(id)) {
      found = true;
      break;
    }
  }
  if (found) {
    // In a nonleaf situation, all arguments must have the same generic.
    bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
      if (item.first != pair.first) {
        return item.second->type_id() != pair.second->type_id();
      }
      return false;
    });
    if (is_not_same) {
      // Build a diagnostic listing every argument's type before failing.
      std::ostringstream oss;
      oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
      int idx = 0;
      for (auto &item : arg_map) {
        oss << ++idx << ": " << item.second->ToString() << "\n";
      }
      MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
    }
  }
  switch (id) {
    case kObjectTypeList: {
      auto type = std::static_pointer_cast<List>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeTuple: {
      auto type = std::static_pointer_cast<Tuple>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeClass: {
      auto type = std::static_pointer_cast<Class>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    default:
      // No nonleaf argument found: treat all arguments as leaves.
      return FullMake(pair.second, func_graph, fn_arg, arg_map);
  }
}
  271. ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
  272. TypePtr type_tensor = std::make_shared<TensorType>();
  273. bool flag = std::any_of(
  274. args_spec_list.begin(), args_spec_list.end(),
  275. [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
  276. if (flag && broadcast_) {
  277. ArgsPairList ret;
  278. for (auto &item : args_spec_list) {
  279. if (!IsSubType(item.second, type_tensor)) {
  280. TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
  281. ret.push_back(
  282. std::make_pair(func_graph->NewCNode({NewValueNode(prim::kPrimScalarToArray), item.first}), type_tensor_ele));
  283. } else {
  284. ret.push_back(std::make_pair(item.first, item.second));
  285. }
  286. }
  287. return ret;
  288. }
  289. return args_spec_list;
  290. }
// Build the hyper_map func graph for a concrete argument type list.
// When no leaf function is bound, the first parameter is the function to map.
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
  FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
  ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  ptr_graph->debug_info()->set_name("hyper_map");
  AnfNodePtr ptrFnArg = nullptr;
  std::size_t i = 0;
  ArgsPairList argmap;
  ArgsPairList argmap2;
  if (fn_leaf_ == nullptr) {
    // No bound leaf function: consume args_spec_list[0] as the mapped function.
    ptrFnArg = ptr_graph->add_parameter();
    i = 1;
  }
  std::size_t size = args_spec_list.size();
  // Add one graph parameter per data argument, paired with its type.
  for (; i < size; ++i) {
    argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
  }
  // Optionally broadcast scalars to tensors, then build the mapped body.
  argmap2 = Harmonize(ptr_graph, argmap);
  ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
  return ptr_graph;
}
  312. abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  313. if (fn_leaf_ == nullptr) {
  314. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  315. // Assert that hypermap's function param does not contain free variables
  316. if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
  317. auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
  318. auto func_graph = graph_func->func_graph();
  319. if (func_graph->parent() != nullptr) {
  320. MS_LOG(EXCEPTION) << "HyperMap don't support Closure with free variable yet.";
  321. }
  322. }
  323. }
  324. AbstractBasePtrList broadened;
  325. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
  326. [](const AbstractBasePtr &arg) -> AbstractBasePtr {
  327. MS_EXCEPTION_IF_NULL(arg);
  328. return arg->Broaden();
  329. });
  330. return broadened;
  331. }
// Expose HyperMapPy to Python as "HyperMap_", constructible with or without a leaf function.
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                         (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
                           .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
                           .def(py::init<>());
                       }));
  337. FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) {
  338. MS_EXCEPTION_IF_NULL(a_tuple);
  339. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  340. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  341. ret->debug_info()->set_name("tail");
  342. AnfNodePtr ptrTup = ret->add_parameter();
  343. std::vector<AnfNodePtr> elems;
  344. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  345. int tuple_size = SizeToInt(a_tuple->size());
  346. for (int i = 1; i < tuple_size; ++i) {
  347. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)}));
  348. }
  349. ret->set_output(ret->NewCNode(elems));
  350. return ret;
  351. }
  352. FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) {
  353. MS_EXCEPTION_IF_NULL(a_list);
  354. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  355. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  356. ret->debug_info()->set_name("tail");
  357. AnfNodePtr ptrList = ret->add_parameter();
  358. std::vector<AnfNodePtr> elems;
  359. elems.push_back(NewValueNode(prim::kPrimMakeList));
  360. int list_size = SizeToInt(a_list->size());
  361. for (int i = 1; i < list_size; ++i) {
  362. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)}));
  363. }
  364. ret->set_output(ret->NewCNode(elems));
  365. return ret;
  366. }
  367. FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  368. if (args_spec_list.size() != 1) {
  369. MS_LOG(EXCEPTION) << "tail requires a non-empty tuple.";
  370. }
  371. AbstractBasePtr a = args_spec_list[0];
  372. abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(a);
  373. if (a_tuple != nullptr) {
  374. return GenerateTupleFuncGraph(a_tuple);
  375. }
  376. abstract::AbstractListPtr a_list = dyn_cast<AbstractList>(a);
  377. if (a_list != nullptr) {
  378. return GenerateListFuncGraph(a_list);
  379. }
  380. MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString();
  381. }
// Expose Tail to Python as "Tail_", constructible from a name string.
REGISTER_PYBIND_DEFINE(
  Tail_, ([](const py::module *m) {
    (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
  }));
  386. FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  387. int tuple_size = SizeToInt(args_spec_list.size());
  388. std::ostringstream ss;
  389. ss << "▶make_tuple_" << tuple_size;
  390. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  391. fg->debug_info()->set_name(ss.str());
  392. std::vector<AnfNodePtr> params;
  393. params.push_back(NewValueNode(prim::kPrimMakeTuple));
  394. for (int i = 0; i < tuple_size; ++i) {
  395. params.push_back(fg->add_parameter());
  396. }
  397. // make fprob first result, maketuple's forward result.
  398. AnfNodePtr out = fg->NewCNode(params);
  399. // make fprob second result, maketuple's backward function.
  400. FuncGraphPtr b = std::make_shared<FuncGraph>();
  401. ss.clear();
  402. ss << "◀make_tuple_" << tuple_size;
  403. b->debug_info()->set_name(ss.str());
  404. AnfNodePtr dout = b->add_parameter();
  405. std::vector<AnfNodePtr> grads;
  406. grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  407. grads.push_back(NewValueNode(newenv));
  408. for (int i = 0; i < tuple_size; ++i) {
  409. grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
  410. }
  411. b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  412. b->set_output(b->NewCNode(grads));
  413. fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  414. fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  415. (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
  416. return fg;
  417. }
  418. FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  419. int list_size = SizeToInt(args_spec_list.size());
  420. std::ostringstream ss;
  421. ss << "▶make_list_" << list_size;
  422. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  423. fg->debug_info()->set_name(ss.str());
  424. std::vector<AnfNodePtr> params;
  425. params.push_back(NewValueNode(prim::kPrimMakeList));
  426. for (int i = 0; i < list_size; ++i) {
  427. params.push_back(fg->add_parameter());
  428. }
  429. // make fprob first result, maketuple's forward result.
  430. AnfNodePtr out = fg->NewCNode(params);
  431. // make fprob second result, maketuple's backward function.
  432. FuncGraphPtr b = std::make_shared<FuncGraph>();
  433. ss.clear();
  434. ss << "◀make_list_" << list_size;
  435. b->debug_info()->set_name(ss.str());
  436. AnfNodePtr dout = b->add_parameter();
  437. std::vector<AnfNodePtr> grads;
  438. grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  439. grads.push_back(NewValueNode(newenv));
  440. for (int i = 0; i < list_size; ++i) {
  441. grads.push_back(b->NewCNode({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
  442. }
  443. b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  444. b->set_output(b->NewCNode(grads));
  445. fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  446. fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  447. (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
  448. return fg;
  449. }
// Configure the gradient operation: get_all -> grads w.r.t. all inputs,
// get_by_list -> grads w.r.t. a weight list (adds the weight_list signature),
// sens_param -> caller supplies the output sensitivity (see doGetGrad).
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
    : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
  if (get_by_list) {
    signatures_ =
      // def grad(func:read, weight_list:ref):
      std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
  }
}
// Build a graph that calls `node` (the J-transformed function), splits its
// (forward_output, bprop_fn) result pair, and delegates to doGetGrad to wire
// the requested gradient outputs.
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
                                    const std::vector<AnfNodePtr> &params_list, const std::vector<AnfNodePtr> &args,
                                    bool applyJ) {
  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  auto weights_node = weights;
  if (weights == nullptr && !args.empty()) {
    // NOTE(review): builds a CNode directly from `args` as the weights node —
    // presumably args[0] is the node to apply; confirm against callers.
    weights_node = ret->NewCNode(args);
  }
  ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
  ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
  std::vector<AnfNodePtr> inputs;
  if (applyJ) {
    // Optionally wrap `node` in the J (gradient) transform here.
    inputs.push_back(opsJ);
    inputs.push_back(node);
    node = ret->NewCNode(inputs);
  }
  // One graph parameter per parameter of the original function.
  std::vector<AnfNodePtr> params;
  for (size_t i = 0; i < params_list.size(); ++i) {
    params.push_back(ret->add_parameter());
  }
  inputs.clear();
  inputs.push_back(node);
  (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs));
  // cnode evaluates to the (forward_output, bprop_fn) pair.
  AnfNodePtr cnode = ret->NewCNode(inputs);
  inputs.clear();
  inputs.push_back(opsTupleItem);
  inputs.push_back(cnode);
  inputs.push_back(NewValueNode(0));
  // Element 0: the forward output.
  auto out = ret->NewCNode(inputs);
  inputs.clear();
  inputs.push_back(opsTupleItem);
  inputs.push_back(cnode);
  inputs.push_back(NewValueNode(1));
  // Element 1: the backpropagation function.
  AnfNodePtr ptr_bprop = ret->NewCNode(inputs);
  doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem);
  return ret;
}
// Wire the gradient outputs of the bprop call into func_graph's output,
// according to the get_all_ / get_by_list_ / sens_param_ configuration.
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights,
                              ValueNodePtr opsTupleItem) {
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtr ptr_bprop_arg = nullptr;
  if (sens_param_) {
    // Caller supplies the output sensitivity as an extra graph parameter.
    ptr_bprop_arg = func_graph->add_parameter();
  } else {
    // Default sensitivity: ones_like(forward_output).
    auto ones_like = prim::GetPythonOps("ones_like");
    ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out});
  }
  // Invoke the backpropagation function with the sensitivity.
  AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg});
  CNodePtr fv_bprop = nullptr;
  if (get_by_list_) {
    // python code: grads = hyper_map(F.partial(env_get, env), weights)
    AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(0)});
    AnfNodePtr partial_env_get =
      func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
    MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
    fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights});
  }
  CNodePtr inputs_bprop = nullptr;
  if (get_all_) {
    // kTail drops the leading env entry, leaving grads w.r.t. every input.
    inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp});
  }
  // Gradients wrt inputs and parameters
  if (fv_bprop != nullptr && inputs_bprop != nullptr) {
    func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
    return;
  }
  // Gradients wrt parameters
  if (fv_bprop != nullptr) {
    func_graph->set_output(fv_bprop);
    return;
  }
  // Gradients wrt inputs
  if (inputs_bprop != nullptr) {
    func_graph->set_output(inputs_bprop);
    return;
  }
  // Gradients wrt first input.
  // ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
  func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(1)}));
}
// Generate the graph.
// Builds a wrapper graph df_builder(fn [, weights]) whose output is the
// gradient graph of fn produced via the J transform and GetGrad.
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is "
                      << args_spec_list.size() << ".";
  }
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "GradOperation arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
  }
  // Waiting for implementation.
  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);
  FuncGraphPtr ptr_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(ptr_graph);
  // The DebugTrace/EndTrace scope ties df_builder's debug info to the traced function.
  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
  FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
  TraceManager::EndTrace();
  auto nparam = ptr_graph->parameters().size();
  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  df_builder->debug_info()->set_name(ss.str());
  ParameterPtr param_graph = df_builder->add_parameter();
  AnfNodePtr weights = nullptr;
  if (get_by_list_) {
    // Second parameter carries the weight list when grads-by-list is requested.
    weights = df_builder->add_parameter();
  }
  // Apply the J transform to the function parameter.
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimJ));
  inputs.push_back(param_graph);
  auto jf = df_builder->NewCNode(inputs);
  // df is checked in GetGrad
  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
  auto df = GetGrad(jf, weights, ptr_graph->parameters());
  TraceManager::EndTrace();
  df_builder->set_output(NewValueNode(df));
  return df_builder;
}
// Expose GradOperation to Python as "GradOperation_" with both constructors.
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
                         (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
                           *m, "GradOperation_")
                           .def(py::init<std::string &>(), py::arg("fn"))
                           .def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
                                py::arg("get_by_list"), py::arg("sens_param"));
                       }));
  587. // Generate the ListMap func graph.
// Generates a graph implementing `list_map(fn, list1, list2, ...)`: applies
// `fn` elementwise over one or more lists via an iterator-driven loop built
// from the `cond`/`body` sub-graphs produced by MakeCond/MakeNext.
// NOTE(review): the inline comment below says this path is currently unused;
// several constructs look unfinished (see inline notes) -- re-audit before
// enabling.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  size_t args_num = args_spec_list.size();
  // args: fn, list1, list2, ...
  if (args_num < 2) {
    MS_LOG(EXCEPTION) << "list_map takes at least two arguments";
  }
  for (size_t i = 1; i < args_num; ++i) {
    // NOTE(review): typeid() here compares the static type of the shared_ptr
    // wrapper against AbstractBase, so the condition is fixed at compile time
    // (always true) and every call would throw. The intent was presumably a
    // runtime "is this a list?" check -- confirm before enabling.
    if (typeid(args_spec_list[i]) != typeid(AbstractBase)) {
      // The function currently not be use
      // NOTE(review): "{t}" is a leftover Python format placeholder; the
      // offending type is not actually reported.
      MS_LOG(EXCEPTION) << "list_map requires lists, not {t}'";
    }
  }
  FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
  fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg_ptr->debug_info()->set_name("list_map");
  AnfNodePtr fn = fg_ptr->add_parameter();  // the mapped function
  // One parameter per input list.
  std::vector<AnfNodePtr> lists;
  for (size_t i = 1; i < args_num; ++i) {
    lists.push_back(fg_ptr->add_parameter());
  }
  // iters[i] = list_iter(lists[i])
  std::vector<AnfNodePtr> iters;
  (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("list_iter")), item});
  });
  // nexts[i] = next(iters[i]); expected to produce a (value, iterator) pair.
  std::vector<AnfNodePtr> nexts;
  (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("next")), item});
  });
  // values[i] = current element of list i.
  // NOTE(review): this tuple_getitem has no index operand, while the
  // companion extraction below passes NewValueNode(1); NewValueNode(0) was
  // presumably intended -- confirm before enabling.
  std::vector<AnfNodePtr> values;
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item});
  });
  // Append the advanced iterators: tuple_getitem(nexts[i], 1).
  // NOTE(review): this appends to `iters` (which still holds the original
  // iterators) rather than replacing them, unlike MakeNext which clears
  // first -- verify that the doubled iterator list is intended.
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)});
  });
  // Build the call fn(value1, value2, ...) and seed the result list with it.
  (void)values.insert(values.begin(), fn);
  AnfNodePtr cnode_graph = fg_ptr->NewCNode(values);
  AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimMakeList), cnode_graph});
  FuncGraphPtr fgnext_ptr = std::make_shared<FuncGraph>();
  fgnext_ptr->debug_info()->set_name("body");
  FuncGraphPtr fgcond_ptr = std::make_shared<FuncGraph>();
  fgcond_ptr->debug_info()->set_name("cond");
  MakeCond(lists, fgnext_ptr, fgcond_ptr);
  MakeNext(lists, fgcond_ptr, fgnext_ptr);
  // Kick off the loop: cond(fn, resl, iter1, iter2, ...).
  CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl});
  auto inputs = output_cnode->inputs();
  (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  output_cnode->set_inputs(inputs);
  fg_ptr->set_output(output_cnode);
  return fg_ptr;
}
// Builds the loop-condition graph for list_map: checks hasnext() on every
// iterator and either recurses into the loop body (fgnext_ptr) or returns
// the accumulated result list.
// NOTE(review): part of the currently unused list_map path; see inline notes
// for constructs that look unfinished.
void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
                       const FuncGraphPtr &fg_ptr) {
  MS_EXCEPTION_IF_NULL(fg_ptr);
  AnfNodePtr fn = fg_ptr->add_parameter();    // mapped function
  AnfNodePtr resl = fg_ptr->add_parameter();  // result list accumulated so far
  // One iterator parameter per input list.
  std::vector<AnfNodePtr> iters;
  (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
                       [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
  // hasnexts[i] = hasnext(iters[i])
  std::vector<AnfNodePtr> hasnexts;
  (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("hasnext")), item});
  });
  // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
  // True branch: continue the loop by calling body(fn, resl, iter1, ...).
  FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
  fgtrue_ptr->debug_info()->set_name("ftrue");
  fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl});
  auto inputs = fgtrue_output_cnode->inputs();
  (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  fgtrue_output_cnode->set_inputs(inputs);
  fgtrue_ptr->set_output(fgtrue_output_cnode);
  // False branch: iteration exhausted, return the accumulated list.
  FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
  fgfalse_ptr->debug_info()->set_name("ffalse");
  fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fgfalse_ptr->set_output(resl);
  // NOTE(review): the switch condition is the placeholder string "cond", not
  // the reduced conjunction of hasnexts sketched in the comment above.
  AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
                                              NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)});
  // NOTE(review): this overwrites fgtrue_ptr's output (set just above) with a
  // node created in fg_ptr; fg_ptr->set_output(output_cnode) was presumably
  // intended -- confirm before enabling list_map.
  fgtrue_ptr->set_output(output_cnode);
}
// Builds the loop-body graph for list_map: advances every iterator, applies
// fn to the current elements, appends the result to the accumulated list and
// recurses into the condition graph (fgcond_ptr).
// NOTE(review): part of the currently unused list_map path; see inline notes.
void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
                       const FuncGraphPtr &fg_ptr) {
  MS_EXCEPTION_IF_NULL(fg_ptr);
  AnfNodePtr fn = fg_ptr->add_parameter();  // mapped function
  // One iterator parameter per input list.
  std::vector<AnfNodePtr> iters;
  (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
                       [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
  // nexts[i] = next(iters[i]); expected to produce a (value, iterator) pair.
  std::vector<AnfNodePtr> nexts;
  (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("next")), item});
  });
  // values[i] = current element of list i.
  // NOTE(review): the index operand is nullptr here, while the advanced-
  // iterator extraction below passes NewValueNode(1); NewValueNode(0) was
  // presumably intended -- confirm before enabling list_map.
  std::vector<AnfNodePtr> values;
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, nullptr});
  });
  // Replace the consumed iterators with the advanced ones:
  // iters[i] = tuple_getitem(nexts[i], 1)
  iters.clear();
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)});
  });
  // Build the call fn(value1, value2, ...).
  (void)values.insert(values.begin(), fn);
  AnfNodePtr cnode_graph = fg_ptr->NewCNode(values);
  // NOTE(review): list_append receives only the new element; the list being
  // appended to is not passed -- verify against ListAppend's expected arity.
  AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimListAppend), cnode_graph});
  // Loop back into the condition graph: cond(fn, resl, iter1, iter2, ...).
  CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl});
  auto inputs = output_cnode->inputs();
  (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  output_cnode->set_inputs(inputs);
  fg_ptr->set_output(output_cnode);
}
  696. FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  697. // args: tuple1, tuple2
  698. abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
  699. AbstractBasePtr abs_a = args_spec_list[0];
  700. AbstractBasePtr abs_b = args_spec_list[1];
  701. abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
  702. abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
  703. if (a_tuple == nullptr || b_tuple == nullptr) {
  704. TypePtrList types;
  705. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
  706. [](const AbstractBasePtr &arg) -> TypePtr {
  707. MS_EXCEPTION_IF_NULL(arg);
  708. return arg->BuildType();
  709. });
  710. auto stub = GenerateStubFunc(types);
  711. if (stub != nullptr) {
  712. MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
  713. << ", function: " << stub->ToString();
  714. return stub;
  715. }
  716. MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", "
  717. << args_spec_list[1]->ToString();
  718. }
  719. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  720. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  721. AnfNodePtr p_tup_a = ret->add_parameter();
  722. AnfNodePtr p_tup_b = ret->add_parameter();
  723. std::vector<AnfNodePtr> elems;
  724. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  725. int tuple_size = SizeToInt(a_tuple->size());
  726. for (int i = 0; i < tuple_size; ++i) {
  727. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)}));
  728. }
  729. tuple_size = SizeToInt(b_tuple->size());
  730. for (int i = 0; i < tuple_size; ++i) {
  731. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)}));
  732. }
  733. ret->set_output(ret->NewCNode(elems));
  734. return ret;
  735. }
  736. int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
  737. MS_EXCEPTION_IF_NULL(scalar);
  738. return GetValue<int>(scalar->BuildValue());
  739. }
  740. bool CheckIndexInRange(int index, int min, int max) { return (index >= min && index <= max); }
  741. int GetPositiveIndex(int index, int length) {
  742. if (index < 0) {
  743. index += length;
  744. }
  745. return index;
  746. }
  747. int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) {
  748. MS_EXCEPTION_IF_NULL(member);
  749. if (member->isa<AbstractScalar>()) {
  750. return GetArgScalarValue(dyn_cast<AbstractScalar>(member), member_name);
  751. }
  752. if (member->isa<AbstractNone>()) {
  753. return default_value;
  754. }
  755. MS_LOG(EXCEPTION) << member_name << " should be a AbstractScalar or AbstractNone, but got " << member->ToString();
  756. }
  757. void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index,
  758. int *stop_index, int *step_value) {
  759. MS_EXCEPTION_IF_NULL(tuple);
  760. MS_EXCEPTION_IF_NULL(slice);
  761. MS_EXCEPTION_IF_NULL(start_index);
  762. MS_EXCEPTION_IF_NULL(stop_index);
  763. MS_EXCEPTION_IF_NULL(step_value);
  764. const std::string start_name("Slice start index");
  765. const std::string stop_name("Slice stop index");
  766. const std::string step_name("Slice step value");
  767. int tuple_size = SizeToInt(tuple->size());
  768. int start_default = 0;
  769. int stop_default = tuple_size;
  770. int step_default = 1;
  771. *step_value = CheckSliceMember(slice->step(), step_default, step_name);
  772. if (*step_value == 0) {
  773. MS_EXCEPTION(ValueError) << "TupleSlice require the step value could not be 0, but got 0.";
  774. }
  775. if (*step_value < 0) {
  776. start_default = tuple_size - 1;
  777. stop_default = -1;
  778. }
  779. *start_index = CheckSliceMember(slice->start(), start_default, start_name);
  780. *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
  781. if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) ||
  782. !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) {
  783. MS_EXCEPTION(ValueError) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index
  784. << " out of range, tuple size " << tuple_size << ".";
  785. }
  786. *start_index = GetPositiveIndex(*start_index, tuple_size);
  787. if (!slice->stop()->isa<AbstractNone>()) {
  788. *stop_index = GetPositiveIndex(*stop_index, tuple_size);
  789. }
  790. }
  791. FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  792. // slice a tuple
  793. // args: tuple, start index, end index, step
  794. const std::string op_name("TupleSlice");
  795. abstract::CheckArgsSize(op_name, args_spec_list, 2);
  796. AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  797. AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);
  798. int start_index;
  799. int stop_index;
  800. int step_value;
  801. GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
  802. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  803. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  804. AnfNodePtr p_tuple = ret->add_parameter();
  805. (void)ret->add_parameter();
  806. std::vector<AnfNodePtr> elems;
  807. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  808. if (step_value > 0) {
  809. for (int index = start_index; index < stop_index; index = index + step_value) {
  810. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
  811. }
  812. } else {
  813. for (int index = start_index; index > stop_index; index = index + step_value) {
  814. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
  815. }
  816. }
  817. ret->set_output(ret->NewCNode(elems));
  818. return ret;
  819. }
  820. FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  821. // select indexed item
  822. // args: tuple of items, index
  823. const std::string op_name = std::string("TupleGetItemTensor");
  824. abstract::CheckArgsSize(op_name, args_spec_list, 2);
  825. auto ret_graph = std::make_shared<FuncGraph>();
  826. ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  827. auto functions = ret_graph->add_parameter();
  828. auto index = ret_graph->add_parameter();
  829. ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
  830. return ret_graph;
  831. }
// Expose the composite meta-graph operations above to Python as TupleAdd_,
// TupleSlice_ and TupleGetItemTensor_; each binding is a MetaFuncGraph
// subclass constructed from an operation-name string.
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
                         (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
                           .def(py::init<std::string &>());
                       }));
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
                         (void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
                           .def(py::init<std::string &>());
                       }));
REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
                           *m, "TupleGetItemTensor_")
                           .def(py::init<std::string &>());
                       }));
  845. } // namespace prim
  846. } // namespace mindspore