You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

composite.cc 38 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "frontend/operator/composite/composite.h"
  19. #include <algorithm>
  20. #include <utility>
  21. #include <sstream>
  22. #include "ir/anf.h"
  23. #include "ir/func_graph.h"
  24. #include "abstract/abstract_value.h"
  25. #include "abstract/abstract_function.h"
  26. #include "abstract/dshape.h"
  27. #include "abstract/param_validator.h"
  28. #include "frontend/operator/cc_implementations.h"
  29. #include "frontend/optimizer/opt.h"
  30. #include "utils/symbolic.h"
  31. #include "pybind_api/api_register.h"
  32. #include "ir/signature.h"
  33. #include "debug/trace.h"
  34. namespace mindspore {
  35. // namespace to support composite operators definition
  36. namespace prim {
  37. using AbstractTensor = mindspore::abstract::AbstractTensor;
  38. using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
  39. using mindspore::abstract::AbstractAttribute;
  40. using mindspore::abstract::AbstractBase;
  41. using mindspore::abstract::AbstractClass;
  42. using mindspore::abstract::AbstractDictionary;
  43. using mindspore::abstract::AbstractDictionaryPtr;
  44. using mindspore::abstract::AbstractEllipsis;
  45. using mindspore::abstract::AbstractEllipsisPtr;
  46. using mindspore::abstract::AbstractFunction;
  47. using mindspore::abstract::AbstractFunctionPtr;
  48. using mindspore::abstract::AbstractList;
  49. using mindspore::abstract::AbstractNone;
  50. using mindspore::abstract::AbstractScalar;
  51. using mindspore::abstract::AbstractSlice;
  52. using mindspore::abstract::AbstractTuple;
// Maps Python magic-method names onto the corresponding scalar primitive.
// NOTE(review): __truediv__ and __floordiv__ map to nullptr — presumably the
// lookup sites handle those specially; verify against the callers.
ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul},
                            {"__truediv__", nullptr},    {"__floordiv__", nullptr},   {"__mod__", kPrimScalarMod},
                            {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq},   {"__lt__", kPrimScalarLt},
                            {"__gt__", kPrimScalarGt},   {"__ne__", kPrimScalarNe},   {"__le__", kPrimScalarLe},
                            {"__ge__", kPrimScalarGe}};
  58. // copy from python API: reduce.
  59. // Apply a function of two arguments cumulatively to the items of a sequence,
  60. // from left to right, so as to reduce the sequence to a single value.For example,
  61. // reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5).
  62. AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) {
  63. std::shared_ptr<Any> ret;
  64. size_t size = list.size();
  65. if (size < 2) {
  66. MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2";
  67. }
  68. AnyPtrList input;
  69. input.push_back(list[0]);
  70. input.push_back(list[1]);
  71. ret = std::make_shared<Any>(func(input));
  72. for (size_t i = 2; i < size; ++i) {
  73. input.clear();
  74. input.push_back(ret);
  75. input.push_back(list[i]);
  76. ret = std::make_shared<Any>(func(input));
  77. }
  78. return ret;
  79. }
  80. AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector<AnfNodePtr> &list) {
  81. size_t size = list.size();
  82. if (size < 2) {
  83. MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2";
  84. }
  85. std::vector<AnfNodePtr> input;
  86. input.push_back(list[0]);
  87. input.push_back(list[1]);
  88. AnfNodePtr ret = func(input);
  89. for (size_t i = 2; i < size; ++i) {
  90. input.clear();
  91. input.push_back(ret);
  92. input.push_back(list[i]);
  93. ret = func(input);
  94. }
  95. return ret;
  96. }
// Shared default HyperMap instance exposed as the composite "hyper_map" value.
ValuePtr kCompositeHyperMap = std::make_shared<HyperMap>();
// Shared constructor tail: derives the display name from the leaf function
// (when one is configured) and declares the call signature
// `hypermap(func:read, *args:ref)`.
void HyperMap::Init() {
  if (fn_leaf_) {
    name_ = "hyper_map[" + fn_leaf_->name() + "]";
  }
  signatures_ =
    // def hypermap(func:read, *args:ref):
    std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                            {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
// Construct a HyperMap over an optional leaf multitype function. Broadcasting
// is off by default; List, Tuple and Class are the container type ids the map
// recurses into (everything else is treated as a leaf).
HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
    : MetaFuncGraph("hyper_map"),
      fn_leaf_(fn_leaf),
      broadcast_(false),
      nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
  Init();
}
// Copy constructor: copies configuration, then re-runs Init() so name_ and
// signatures_ are rebuilt for the new instance.
HyperMap::HyperMap(const HyperMap &h)
    : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
  Init();
}
  118. AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  119. const ArgsPairList &arg_map) {
  120. MS_EXCEPTION_IF_NULL(func_graph);
  121. std::vector<AnfNodePtr> inputs;
  122. if (fn_arg != nullptr) {
  123. inputs.push_back(fn_arg);
  124. } else {
  125. inputs.push_back(NewValueNode(fn_leaf_));
  126. }
  127. (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
  128. [](const std::pair<AnfNodePtr, Any> &item) { return item.first; });
  129. return func_graph->NewCNodeInOrder(inputs);
  130. }
// Element-wise map over List arguments: every argument list must have the
// same length; builds make_list(hypermap(fn, arg0[i], arg1[i], ...)) for each
// index i, recursing through a fresh HyperMap value node.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);
  std::size_t size = type->elements().size();
  // All argument lists must match the length of the dispatching type.
  bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
    auto lhs = std::static_pointer_cast<List>(item.second);
    MS_EXCEPTION_IF_NULL(lhs);
    return lhs->elements().size() != size;
  });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "List in HyperMap should have same length";
  }
  // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  // hypermap and graph generated, it will cause memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeList));
  for (int64_t i = 0; i < SizeToLong(size); ++i) {
    // Each output element is a recursive hypermap call over the i-th items.
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }
    (void)std::transform(
      arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
      [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
        return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
      });
    inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
  }
  return func_graph->NewCNodeInOrder(inputs);
}
  164. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
  165. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  166. MS_EXCEPTION_IF_NULL(func_graph);
  167. MS_EXCEPTION_IF_NULL(type);
  168. std::size_t size = type->elements().size();
  169. bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
  170. auto lhs = std::static_pointer_cast<Tuple>(item.second);
  171. MS_EXCEPTION_IF_NULL(lhs);
  172. return lhs->elements().size() != size;
  173. });
  174. if (is_not_same) {
  175. MS_LOG(EXCEPTION) << "tuple in HyperMap should have same length";
  176. }
  177. // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  178. // hypermap and graph generated, it will cause memory leak.
  179. auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  180. std::vector<AnfNodePtr> inputs;
  181. inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  182. for (int64_t i = 0; i < SizeToLong(size); ++i) {
  183. std::vector<AnfNodePtr> inputs2;
  184. inputs2.push_back(fn_rec);
  185. if (fn_arg != nullptr) {
  186. inputs2.push_back(fn_arg);
  187. }
  188. (void)std::transform(
  189. arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
  190. return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
  191. });
  192. inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
  193. }
  194. return func_graph->NewCNodeInOrder(inputs);
  195. }
  196. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
  197. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  198. MS_EXCEPTION_IF_NULL(type);
  199. MS_EXCEPTION_IF_NULL(func_graph);
  200. std::vector<AnfNodePtr> inputs;
  201. inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  202. inputs.push_back(NewValueNode(type));
  203. // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  204. // hypermap and graph generated, it will cause memory leak.
  205. auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  206. std::size_t attrSize = type->GetAttributes().size();
  207. for (std::size_t i = 0; i < attrSize; ++i) {
  208. std::vector<AnfNodePtr> inputs2;
  209. inputs2.push_back(fn_rec);
  210. if (fn_arg) {
  211. inputs2.push_back(fn_arg);
  212. }
  213. int64_t j = 0;
  214. for (auto item : arg_map) {
  215. inputs2.push_back(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
  216. j++;
  217. }
  218. inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
  219. }
  220. return func_graph->NewCNodeInOrder(inputs);
  221. }
// Dispatch on argument types: if any argument has a nonleaf container type id
// (List/Tuple/Class), recurse element-wise through the matching FullMake
// overload; otherwise apply the leaf function directly. All nonleaf arguments
// must share the same generic type id, else raise with a per-argument dump.
AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  bool found = false;
  TypeId id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> pair;
  // Find the first argument whose type is one of the nonleaf container kinds;
  // it becomes the dispatch representative.
  for (auto &item : arg_map) {
    pair = item;
    id = item.second->type_id();
    if (nonleaf_.count(id)) {
      found = true;
      break;
    }
  }
  if (found) {
    // In a nonleaf situation, all arguments must have the same generic.
    bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
      if (item.first != pair.first) {
        return item.second->type_id() != pair.second->type_id();
      }
      return false;
    });
    if (is_not_same) {
      // Build a numbered listing of every argument's type for the error text.
      std::ostringstream oss;
      oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
      int64_t idx = 0;
      for (auto &item : arg_map) {
        oss << ++idx << ": " << item.second->ToString() << "\n";
      }
      MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
    }
  }
  switch (id) {
    case kObjectTypeList: {
      auto type = std::static_pointer_cast<List>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeTuple: {
      auto type = std::static_pointer_cast<Tuple>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeClass: {
      auto type = std::static_pointer_cast<Class>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    default:
      // Leaf case: no container type found.
      return FullMake(pair.second, func_graph, fn_arg, arg_map);
  }
}
  270. ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
  271. TypePtr type_tensor = std::make_shared<TensorType>();
  272. bool flag = std::any_of(
  273. args_spec_list.begin(), args_spec_list.end(),
  274. [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
  275. if (flag && broadcast_) {
  276. ArgsPairList ret;
  277. for (auto &item : args_spec_list) {
  278. if (!IsSubType(item.second, type_tensor)) {
  279. TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
  280. ret.push_back(std::make_pair(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimScalarToArray), item.first}),
  281. type_tensor_ele));
  282. } else {
  283. ret.push_back(std::make_pair(item.first, item.second));
  284. }
  285. }
  286. return ret;
  287. }
  288. return args_spec_list;
  289. }
// Build the concrete hyper_map graph for a given list of argument types: when
// no leaf function is configured, the first parameter is the function to map;
// the remaining parameters are paired with their types, harmonized (scalar ->
// tensor promotion when broadcasting), and fed to Make() for dispatch.
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
  FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
  ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  ptr_graph->debug_info()->set_name("hyper_map");
  AnfNodePtr ptrFnArg = nullptr;
  std::size_t i = 0;
  ArgsPairList argmap;
  ArgsPairList argmap2;
  if (fn_leaf_ == nullptr) {
    // No configured leaf: the first graph parameter is the mapped function,
    // so data arguments start at index 1.
    ptrFnArg = ptr_graph->add_parameter();
    i = 1;
  }
  std::size_t size = args_spec_list.size();
  for (; i < size; ++i) {
    argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
  }
  argmap2 = Harmonize(ptr_graph, argmap);
  ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
  return ptr_graph;
}
  311. abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  312. if (fn_leaf_ == nullptr) {
  313. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  314. // Assert that hypermap's function param does not contain free variables
  315. if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
  316. auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
  317. auto func_graph = graph_func->func_graph();
  318. if (func_graph->parent() != nullptr) {
  319. MS_LOG(EXCEPTION) << "HyperMap don't support Closure with free variable yet.";
  320. }
  321. }
  322. }
  323. AbstractBasePtrList broadened;
  324. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
  325. [](const AbstractBasePtr &arg) -> AbstractBasePtr {
  326. MS_EXCEPTION_IF_NULL(arg);
  327. return arg->Broaden();
  328. });
  329. return broadened;
  330. }
// Expose HyperMapPy to Python as "HyperMap_", constructible either from a
// leaf MultitypeFuncGraph or with no arguments.
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                         (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
                           .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
                           .def(py::init<>());
                       }));
// Build a graph over one tuple/list parameter that drops element 0.
// kGradFirst: return only element 1 when it is gradient-bearing
// (AbstractUndetermined), else an empty tuple. kGradAll: keep only the
// gradient-bearing elements from index 1 on. Any other tail type: keep every
// element from index 1 on.
FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const {
  MS_EXCEPTION_IF_NULL(sequeue);
  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ret->debug_info()->set_name("tail");
  AnfNodePtr ptrTup = ret->add_parameter();
  std::vector<AnfNodePtr> elems;
  PrimitivePtr op = nullptr;
  // Choose the make/getitem primitive pair matching the sequence kind.
  if (sequeue->isa<AbstractTuple>()) {
    elems.push_back(NewValueNode(prim::kPrimMakeTuple));
    op = prim::kPrimTupleGetItem;
  } else {
    elems.push_back(NewValueNode(prim::kPrimMakeList));
    op = prim::kPrimListGetItem;
  }
  if (tail_type_ == kGradFirst) {
    if (sequeue->size() > 1 && (*sequeue)[1] != nullptr && (*sequeue)[1]->isa<abstract::AbstractUndetermined>()) {
      ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
    } else {
      // No usable first-input gradient: return an empty tuple constant.
      ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
    }
    return ret;
  }
  for (size_t i = 1; i < sequeue->size(); ++i) {
    if (tail_type_ == kGradAll) {
      MS_EXCEPTION_IF_NULL((*sequeue)[i]);
      // Keep only elements that can carry gradients.
      if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>()) {
        elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
      }
    } else {
      elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
    }
  }
  ret->set_output(ret->NewCNodeInOrder(elems));
  return ret;
}
// Entry point: accepts exactly one tuple/list argument and delegates to
// GenerateSequeueFuncGraph; raises for any other arity or argument kind.
FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.size() != 1) {
    MS_LOG(EXCEPTION) << "tail requires a non-empty tuple.";
  }
  AbstractBasePtr a = args_spec_list[0];
  if (a->isa<AbstractTuple>() || a->isa<AbstractList>()) {
    return GenerateSequeueFuncGraph(a->cast<abstract::AbstractSequeuePtr>());
  }
  MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString();
}
// Expose Tail to Python as "Tail_", constructed from its name string.
REGISTER_PYBIND_DEFINE(
  Tail_, ([](const py::module *m) {
    (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
  }));
  386. FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  387. int64_t tuple_size = SizeToLong(args_spec_list.size());
  388. std::ostringstream ss;
  389. ss << "▶make_tuple_" << tuple_size;
  390. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  391. fg->debug_info()->set_name(ss.str());
  392. std::vector<AnfNodePtr> params;
  393. params.push_back(NewValueNode(prim::kPrimMakeTuple));
  394. for (int64_t i = 0; i < tuple_size; ++i) {
  395. params.push_back(fg->add_parameter());
  396. }
  397. // make fprob first result, maketuple's forward result.
  398. AnfNodePtr out = fg->NewCNodeInOrder(params);
  399. // make fprob second result, maketuple's backward function.
  400. FuncGraphPtr b = std::make_shared<FuncGraph>();
  401. ss.clear();
  402. ss << "◀make_tuple_" << tuple_size;
  403. b->debug_info()->set_name(ss.str());
  404. AnfNodePtr dout = b->add_parameter();
  405. std::vector<AnfNodePtr> grads;
  406. grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  407. grads.push_back(NewValueNode(newenv));
  408. for (int64_t i = 0; i < tuple_size; ++i) {
  409. grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
  410. }
  411. b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  412. b->set_output(b->NewCNodeInOrder(grads));
  413. fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  414. fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  415. (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
  416. return fg;
  417. }
  418. FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  419. int64_t list_size = SizeToLong(args_spec_list.size());
  420. std::ostringstream ss;
  421. ss << "▶make_list_" << list_size;
  422. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  423. fg->debug_info()->set_name(ss.str());
  424. std::vector<AnfNodePtr> params;
  425. params.push_back(NewValueNode(prim::kPrimMakeList));
  426. for (int64_t i = 0; i < list_size; ++i) {
  427. params.push_back(fg->add_parameter());
  428. }
  429. // make fprob first result, maketuple's forward result.
  430. AnfNodePtr out = fg->NewCNodeInOrder(params);
  431. // make fprob second result, maketuple's backward function.
  432. FuncGraphPtr b = std::make_shared<FuncGraph>();
  433. ss.clear();
  434. ss << "◀make_list_" << list_size;
  435. b->debug_info()->set_name(ss.str());
  436. AnfNodePtr dout = b->add_parameter();
  437. std::vector<AnfNodePtr> grads;
  438. grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  439. grads.push_back(NewValueNode(newenv));
  440. for (int64_t i = 0; i < list_size; ++i) {
  441. grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
  442. }
  443. b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  444. b->set_output(b->NewCNodeInOrder(grads));
  445. fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  446. fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  447. (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
  448. return fg;
  449. }
// GradOperation flags: get_all -> gradients w.r.t. all inputs; get_by_list ->
// gradients w.r.t. a weight list (adds the weight_list signature slot);
// sens_param -> the caller supplies the sensitivity (dout) explicitly.
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
    : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
  if (get_by_list) {
    signatures_ =
      // def grad(func:read, weight_list:ref):
      std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
  }
}
// Build the child graph that applies the J-transformed function `k` to fresh
// parameters mirroring the forward graph's, splits the (result, bprop) pair,
// and hands off to GradByParameter to select which gradients to return.
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
                                    const std::vector<AnfNodePtr> &forward_graph_params,
                                    const std::vector<AnfNodePtr> &weight_args) {
  FuncGraphPtr k_child = std::make_shared<FuncGraph>();
  k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  AnfNodePtr weights_node = nullptr;
  if (weights != nullptr) {
    weights_node = weights;
  } else if (!weight_args.empty()) {
    weights_node = k_child->NewCNodeInOrder(weight_args);
  }
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(k);
  // Mirror the forward graph's arity with fresh parameters.
  for (size_t i = 0; i < forward_graph_params.size(); ++i) {
    inputs.push_back(k_child->add_parameter());
  }
  auto k_app = k_child->NewCNodeInOrder(inputs);
  auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem);
  // k_app yields (forward_result, bprop_fn): index 0 and index 1.
  auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
  auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});
  GradByParameter(k_child, f_app, bprop, weights_node);
  return k_child;
}
// Do grad by the parameter of GradOperation.
// Applies bprop to the sensitivity (an extra parameter when sens_param_, else
// ones_like of the forward result) and sets k_child's output according to the
// get_all_/get_by_list_ flags.
void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
                                    const AnfNodePtr &weights) {
  MS_EXCEPTION_IF_NULL(k_child);
  AnfNodePtr bprop_arg = nullptr;
  if (sens_param_) {
    // Caller provides the sensitivity as an extra trailing parameter.
    bprop_arg = k_child->add_parameter();
  } else {
    auto ones_like = prim::GetPythonOps("ones_like");
    bprop_arg = k_child->NewCNodeInOrder({NewValueNode(ones_like), f_app});
  }
  AnfNodePtr b_app = k_child->NewCNodeInOrder({bprop, bprop_arg});
  CNodePtr fv_bprop = nullptr;
  if (get_by_list_) {
    // python code: grads = hyper_map(F.partial(env_get, env), weights)
    AnfNodePtr env =
      k_child->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(0))});
    AnfNodePtr partial_env_get =
      k_child->NewCNodeInOrder({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
    MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
    fv_bprop = k_child->NewCNodeInOrder({NewValueNode(hyper_map), partial_env_get, weights});
  }
  CNodePtr inputs_bprop = nullptr;
  if (get_all_) {
    TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll);
    inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_all), b_app});
  }
  // Gradients wrt inputs and parameters
  if (fv_bprop != nullptr && inputs_bprop != nullptr) {
    k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
    return;
  }
  // Gradients wrt parameters
  if (fv_bprop != nullptr) {
    k_child->set_output(fv_bprop);
    return;
  }
  // Gradients wrt inputs
  if (inputs_bprop != nullptr) {
    k_child->set_output(inputs_bprop);
    return;
  }
  // Gradients wrt first input.
  // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
  // so obtain first input grad by setting tail_type of Tail to kGradFirst.
  TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
  k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
}
// Generate the graph.
// Wraps the forward graph with the J transform and returns a graph whose
// output is the k_child built by GetGrad. Only FuncGraphAbstractClosure
// arguments are accepted.
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is "
                      << args_spec_list.size() << ".";
  }
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "GradOperation arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
  }
  // Waiting for implementation.
  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);
  FuncGraphPtr forward_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(forward_graph);
  FuncGraphPtr grad_fg = nullptr;
  {
    // Scoped trace guard: the new graph inherits the forward graph's debug
    // location only while the guard is alive.
    TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    grad_fg = std::make_shared<FuncGraph>();
  }
  auto nparam = forward_graph->parameters().size();
  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  grad_fg->debug_info()->set_name(ss.str());
  ParameterPtr param_graph = grad_fg->add_parameter();
  AnfNodePtr weights = nullptr;
  if (get_by_list_) {
    weights = grad_fg->add_parameter();
  }
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimJ));
  inputs.push_back(param_graph);
  auto j = grad_fg->NewCNodeInOrder(inputs);
  // df is checked in GetGrad
  FuncGraphPtr k_child = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    k_child = GetGrad(j, weights, forward_graph->parameters());
  }
  grad_fg->set_output(NewValueNode(k_child));
  return grad_fg;
}
// Expose GradOperation to Python as "GradOperation_" with both the name-only
// and the full (fn, get_all, get_by_list, sens_param) constructors.
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
                         (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
                           *m, "GradOperation_")
                           .def(py::init<std::string &>(), py::arg("fn"))
                           .def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
                                py::arg("get_by_list"), py::arg("sens_param"));
                       }));
// Generate the ListMap func graph.
// Builds a graph mapping `fn` over one or more lists: creates an iterator per
// list, seeds the result list with the first mapped element, then loops via
// the cond/body graphs produced by MakeCond/MakeNext.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  size_t args_num = args_spec_list.size();
  // args: fn, list1, list2, ...
  if (args_num < 2) {
    MS_LOG(EXCEPTION) << "list_map takes at least two arguments";
  }
  for (size_t i = 1; i < args_num; ++i) {
    // NOTE(review): typeid of a shared_ptr (AbstractBasePtr) can never equal
    // typeid(AbstractBase), so this condition looks always-true; the comment
    // below says the function is unused — confirm intent before relying on it.
    if (typeid(args_spec_list[i]) != typeid(AbstractBase)) {
      // The function currently not be use
      MS_LOG(EXCEPTION) << "list_map requires lists, not {t}'";
    }
  }
  FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
  fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg_ptr->debug_info()->set_name("list_map");
  AnfNodePtr fn = fg_ptr->add_parameter();
  std::vector<AnfNodePtr> lists;
  for (size_t i = 1; i < args_num; ++i) {
    lists.push_back(fg_ptr->add_parameter());
  }
  // One iterator per input list.
  std::vector<AnfNodePtr> iters;
  (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("list_iter")), item});
  });
  // next(iter) yields (value, advanced_iter) pairs.
  std::vector<AnfNodePtr> nexts;
  (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("next")), item});
  });
  std::vector<AnfNodePtr> values;
  // NOTE(review): this tuple_getitem has no index operand, unlike the one
  // below that passes 1 — looks like a missing NewValueNode(0); verify.
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item});
  });
  // Append the advanced iterators (element 1 of each next pair) onto iters.
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNodeInOrder(
      {NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast<int64_t>(1))});
  });
  (void)values.insert(values.begin(), fn);
  // Seed the result list with fn applied to the first values.
  AnfNodePtr cnode_graph = fg_ptr->NewCNodeInOrder(values);
  AnfNodePtr resl = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimMakeList), cnode_graph});
  FuncGraphPtr fgnext_ptr = std::make_shared<FuncGraph>();
  fgnext_ptr->debug_info()->set_name("body");
  FuncGraphPtr fgcond_ptr = std::make_shared<FuncGraph>();
  fgcond_ptr->debug_info()->set_name("cond");
  MakeCond(lists, fgnext_ptr, fgcond_ptr);
  MakeNext(lists, fgcond_ptr, fgnext_ptr);
  // Kick off the loop: cond(fn, resl, iters...).
  CNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(fgcond_ptr), fn, resl});
  auto inputs = output_cnode->inputs();
  (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  output_cnode->set_inputs(inputs);
  fg_ptr->set_output(output_cnode);
  return fg_ptr;
}
  634. void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
  635. const FuncGraphPtr &fg_ptr) {
  636. MS_EXCEPTION_IF_NULL(fg_ptr);
  637. AnfNodePtr fn = fg_ptr->add_parameter();
  638. AnfNodePtr resl = fg_ptr->add_parameter();
  639. std::vector<AnfNodePtr> iters;
  640. (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
  641. [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
  642. std::vector<AnfNodePtr> hasnexts;
  643. (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) {
  644. return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("hasnext")), item});
  645. });
  646. // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
  647. FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
  648. fgtrue_ptr->debug_info()->set_name("ftrue");
  649. fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  650. CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNodeInOrder({NewValueNode(fgnext_ptr), fn, resl});
  651. auto inputs = fgtrue_output_cnode->inputs();
  652. (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  653. fgtrue_output_cnode->set_inputs(inputs);
  654. fgtrue_ptr->set_output(fgtrue_output_cnode);
  655. FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
  656. fgfalse_ptr->debug_info()->set_name("ffalse");
  657. fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  658. fgfalse_ptr->set_output(resl);
  659. AnfNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
  660. NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)});
  661. fgtrue_ptr->set_output(output_cnode);
  662. }
  663. void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
  664. const FuncGraphPtr &fg_ptr) {
  665. MS_EXCEPTION_IF_NULL(fg_ptr);
  666. AnfNodePtr fn = fg_ptr->add_parameter();
  667. std::vector<AnfNodePtr> iters;
  668. (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
  669. [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
  670. std::vector<AnfNodePtr> nexts;
  671. (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
  672. return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("next")), item});
  673. });
  674. std::vector<AnfNodePtr> values;
  675. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
  676. return fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item, nullptr});
  677. });
  678. iters.clear();
  679. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
  680. return fg_ptr->NewCNodeInOrder(
  681. {NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast<int64_t>(1))});
  682. });
  683. (void)values.insert(values.begin(), fn);
  684. AnfNodePtr cnode_graph = fg_ptr->NewCNodeInOrder(values);
  685. AnfNodePtr resl = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimListAppend), cnode_graph});
  686. CNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(fgcond_ptr), fn, resl});
  687. auto inputs = output_cnode->inputs();
  688. (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  689. output_cnode->set_inputs(inputs);
  690. fg_ptr->set_output(output_cnode);
  691. }
  692. FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  693. // args: tuple1, tuple2
  694. abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
  695. AbstractBasePtr abs_a = args_spec_list[0];
  696. AbstractBasePtr abs_b = args_spec_list[1];
  697. abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
  698. abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
  699. if (a_tuple == nullptr || b_tuple == nullptr) {
  700. TypePtrList types;
  701. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
  702. [](const AbstractBasePtr &arg) -> TypePtr {
  703. MS_EXCEPTION_IF_NULL(arg);
  704. return arg->BuildType();
  705. });
  706. auto stub = GenerateStubFunc(types);
  707. if (stub != nullptr) {
  708. MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
  709. << ", function: " << stub->ToString();
  710. return stub;
  711. }
  712. MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", "
  713. << args_spec_list[1]->ToString();
  714. }
  715. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  716. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  717. AnfNodePtr p_tup_a = ret->add_parameter();
  718. AnfNodePtr p_tup_b = ret->add_parameter();
  719. std::vector<AnfNodePtr> elems;
  720. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  721. int64_t tuple_size = SizeToLong(a_tuple->size());
  722. for (int64_t i = 0; i < tuple_size; ++i) {
  723. elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)}));
  724. }
  725. tuple_size = SizeToLong(b_tuple->size());
  726. for (int64_t i = 0; i < tuple_size; ++i) {
  727. elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)}));
  728. }
  729. ret->set_output(ret->NewCNodeInOrder(elems));
  730. return ret;
  731. }
  732. int64_t GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
  733. MS_EXCEPTION_IF_NULL(scalar);
  734. return GetValue<int64_t>(scalar->BuildValue());
  735. }
  736. bool CheckIndexInRange(int64_t index, int64_t min, int64_t max) { return (index >= min && index <= max); }
  737. int64_t GetPositiveIndex(int64_t index, int64_t length) {
  738. if (index < 0) {
  739. index += length;
  740. }
  741. return index;
  742. }
  743. int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, const std::string &member_name) {
  744. MS_EXCEPTION_IF_NULL(member);
  745. if (member->isa<AbstractScalar>()) {
  746. return GetArgScalarValue(dyn_cast<AbstractScalar>(member), member_name);
  747. }
  748. if (member->isa<AbstractNone>()) {
  749. return default_value;
  750. }
  751. MS_LOG(EXCEPTION) << member_name << " should be a AbstractScalar or AbstractNone, but got " << member->ToString();
  752. }
  753. void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int64_t *start_index,
  754. int64_t *stop_index, int64_t *step_value) {
  755. MS_EXCEPTION_IF_NULL(tuple);
  756. MS_EXCEPTION_IF_NULL(slice);
  757. MS_EXCEPTION_IF_NULL(start_index);
  758. MS_EXCEPTION_IF_NULL(stop_index);
  759. MS_EXCEPTION_IF_NULL(step_value);
  760. const std::string start_name("Slice start index");
  761. const std::string stop_name("Slice stop index");
  762. const std::string step_name("Slice step value");
  763. int64_t tuple_size = SizeToLong(tuple->size());
  764. int64_t start_default = 0;
  765. int64_t stop_default = tuple_size;
  766. int64_t step_default = 1;
  767. *step_value = CheckSliceMember(slice->step(), step_default, step_name);
  768. if (*step_value == 0) {
  769. MS_EXCEPTION(ValueError) << "TupleSlice require the step value could not be 0, but got 0.";
  770. }
  771. if (*step_value < 0) {
  772. start_default = tuple_size - 1;
  773. stop_default = -1;
  774. }
  775. *start_index = CheckSliceMember(slice->start(), start_default, start_name);
  776. *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
  777. if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) ||
  778. !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) {
  779. MS_EXCEPTION(ValueError) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index
  780. << " out of range, tuple size " << tuple_size << ".";
  781. }
  782. *start_index = GetPositiveIndex(*start_index, tuple_size);
  783. if (!slice->stop()->isa<AbstractNone>()) {
  784. *stop_index = GetPositiveIndex(*stop_index, tuple_size);
  785. }
  786. }
  787. FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  788. // slice a tuple
  789. // args: tuple, start index, end index, step
  790. const std::string op_name("TupleSlice");
  791. abstract::CheckArgsSize(op_name, args_spec_list, 2);
  792. AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  793. AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);
  794. int64_t start_index;
  795. int64_t stop_index;
  796. int64_t step_value;
  797. GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
  798. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  799. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  800. AnfNodePtr p_tuple = ret->add_parameter();
  801. (void)ret->add_parameter();
  802. std::vector<AnfNodePtr> elems;
  803. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  804. if (step_value > 0) {
  805. for (int64_t index = start_index; index < stop_index; index = index + step_value) {
  806. elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
  807. }
  808. } else {
  809. for (int64_t index = start_index; index > stop_index; index = index + step_value) {
  810. elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
  811. }
  812. }
  813. ret->set_output(ret->NewCNodeInOrder(elems));
  814. return ret;
  815. }
  816. FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  817. // select indexed item
  818. // args: tuple of items, index
  819. const std::string op_name = std::string("TupleGetItemTensor");
  820. abstract::CheckArgsSize(op_name, args_spec_list, 2);
  821. auto ret_graph = std::make_shared<FuncGraph>();
  822. ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  823. auto functions = ret_graph->add_parameter();
  824. auto index = ret_graph->add_parameter();
  825. ret_graph->set_output(ret_graph->NewCNodeInOrder({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
  826. return ret_graph;
  827. }
// Expose TupleAdd to Python as class TupleAdd_ (constructed with an op name).
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
.def(py::init<std::string &>());
}));
// Expose TupleSlice to Python as class TupleSlice_ (constructed with an op name).
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
(void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
.def(py::init<std::string &>());
}));
// Expose TupleGetItemTensor to Python as class TupleGetItemTensor_
// (constructed with an op name).
REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
(void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
*m, "TupleGetItemTensor_")
.def(py::init<std::string &>());
}));
  841. } // namespace prim
  842. } // namespace mindspore