You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

composite.cc 44 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "frontend/operator/composite/composite.h"
  19. #include <algorithm>
  20. #include <utility>
  21. #include <sstream>
  22. #include "ir/anf.h"
  23. #include "ir/func_graph.h"
  24. #include "abstract/abstract_value.h"
  25. #include "abstract/abstract_function.h"
  26. #include "abstract/dshape.h"
  27. #include "abstract/param_validator.h"
  28. #include "frontend/operator/cc_implementations.h"
  29. #include "frontend/optimizer/opt.h"
  30. #include "utils/symbolic.h"
  31. #include "pybind_api/api_register.h"
  32. #include "ir/signature.h"
  33. #include "debug/trace.h"
  34. #include "utils/ms_context.h"
  35. #include "utils/utils.h"
  36. namespace mindspore {
  37. // namespace to support composite operators definition
  38. namespace prim {
  39. using AbstractTensor = mindspore::abstract::AbstractTensor;
  40. using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
  41. using mindspore::abstract::AbstractAttribute;
  42. using mindspore::abstract::AbstractBase;
  43. using mindspore::abstract::AbstractClass;
  44. using mindspore::abstract::AbstractDictionary;
  45. using mindspore::abstract::AbstractDictionaryPtr;
  46. using mindspore::abstract::AbstractEllipsis;
  47. using mindspore::abstract::AbstractEllipsisPtr;
  48. using mindspore::abstract::AbstractFunction;
  49. using mindspore::abstract::AbstractFunctionPtr;
  50. using mindspore::abstract::AbstractList;
  51. using mindspore::abstract::AbstractNone;
  52. using mindspore::abstract::AbstractScalar;
  53. using mindspore::abstract::AbstractSlice;
  54. using mindspore::abstract::AbstractTuple;
// Maps Python magic-method names to their scalar primitive implementations.
// Entries mapped to nullptr (__truediv__, __floordiv__) have no direct scalar
// primitive here and must be resolved elsewhere.
ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul},
{"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod},
{"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt},
{"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe},
{"__ge__", kPrimScalarGe}};
// Shared default HyperMap instance exposed as a composite operator value.
ValuePtr kCompositeHyperMap = std::make_shared<HyperMap>();
// Initialize shared HyperMap state: derive the display name from the leaf
// function (when one was supplied) and declare the call signature
//   def hypermap(func:read, *args:ref)
void HyperMap::Init() {
if (fn_leaf_) {
name_ = "hyper_map[" + fn_leaf_->name() + "]";
}
signatures_ =
// def hypermap(func:read, *args:ref):
std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
// Construct a HyperMap.  `fn_leaf` is the function applied at the leaves of
// nested structures (nullptr means the function is instead passed as the first
// call argument); `reverse` processes container elements in reverse order.
// Nonleaf (recursed-into) container type ids are List, Tuple and Class.
HyperMap::HyperMap(bool reverse, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
: MetaFuncGraph("hyper_map"),
fn_leaf_(fn_leaf),
reverse_(reverse),
broadcast_(false),
nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
Init();
}
// Copy constructor: copies the configuration and re-runs Init() so the derived
// name and signatures stay consistent with the copied leaf function.
HyperMap::HyperMap(const HyperMap &h)
: MetaFuncGraph("hyper_map"),
fn_leaf_(h.fn_leaf_),
reverse_(h.reverse_),
broadcast_(h.broadcast_),
nonleaf_(h.nonleaf_) {
Init();
}
  86. AnfNodePtr HyperMap::FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  87. MS_EXCEPTION_IF_NULL(func_graph);
  88. std::vector<AnfNodePtr> inputs;
  89. if (fn_arg != nullptr) {
  90. inputs.push_back(fn_arg);
  91. } else {
  92. inputs.push_back(NewValueNode(fn_leaf_));
  93. }
  94. (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
  95. [](const std::pair<AnfNodePtr, Any> &item) { return item.first; });
  96. return func_graph->NewCNodeInOrder(inputs);
  97. }
  98. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
  99. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  100. MS_EXCEPTION_IF_NULL(func_graph);
  101. MS_EXCEPTION_IF_NULL(type);
  102. size_t size = type->elements().size();
  103. size_t num = 0;
  104. bool is_not_same =
  105. std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
  106. num++;
  107. auto lhs = std::static_pointer_cast<List>(item.second);
  108. if (lhs == nullptr) {
  109. MS_LOG(EXCEPTION) << "The elements[" << (num - 1) << "] has wrong type, expected a List, but got "
  110. << item.second->ToString();
  111. }
  112. if (lhs->elements().size() != size) {
  113. MS_LOG(ERROR) << "The elements[" << (num - 1) << "] has different length, expected " << size << ", but got "
  114. << lhs->elements().size();
  115. return true;
  116. }
  117. return false;
  118. });
  119. if (is_not_same) {
  120. MS_LOG(EXCEPTION) << "List in HyperMap should have same length.";
  121. }
  122. // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  123. // hypermap and graph generated, it will cause memory leak.
  124. auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  125. constexpr size_t kPrimHoldLen = 1;
  126. std::vector<AnfNodePtr> inputs;
  127. inputs.reserve(size + kPrimHoldLen);
  128. inputs.push_back(NewValueNode(prim::kPrimMakeList));
  129. for (size_t i = 0; i < size; i++) {
  130. MS_LOG(DEBUG) << "FullMakeList for the " << i << "th element of the target, reverse_: " << reverse_;
  131. std::vector<AnfNodePtr> inputs2;
  132. inputs2.push_back(fn_rec);
  133. if (fn_arg != nullptr) {
  134. inputs2.push_back(fn_arg);
  135. }
  136. size_t pos = (reverse_ ? (size - 1 - i) : i);
  137. (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
  138. [&func_graph, pos](const std::pair<AnfNodePtr, Any> &item) {
  139. return func_graph->NewCNodeInOrder(
  140. {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
  141. });
  142. auto call_node = func_graph->NewCNodeInOrder(inputs2);
  143. if (reverse_) {
  144. inputs.insert(inputs.begin() + 1, call_node);
  145. } else {
  146. inputs.emplace_back(call_node);
  147. }
  148. }
  149. return func_graph->NewCNodeInOrder(inputs);
  150. }
  151. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
  152. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  153. MS_EXCEPTION_IF_NULL(func_graph);
  154. MS_EXCEPTION_IF_NULL(type);
  155. size_t size = type->elements().size();
  156. size_t num = 0;
  157. bool is_not_same =
  158. std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
  159. num++;
  160. auto lhs = std::static_pointer_cast<Tuple>(item.second);
  161. if (lhs == nullptr) {
  162. MS_LOG(EXCEPTION) << "The elements[" << (num - 1) << "] has wrong type, expected a Tuple, but got "
  163. << item.second->ToString();
  164. }
  165. if (lhs->elements().size() != size) {
  166. MS_LOG(ERROR) << "The elements[" << (num - 1) << "] has different length, expected " << size << ", but got "
  167. << lhs->elements().size();
  168. return true;
  169. }
  170. return false;
  171. });
  172. if (is_not_same) {
  173. MS_LOG(EXCEPTION) << "Tuple in HyperMap should have same length.";
  174. }
  175. // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  176. // hypermap and graph generated, it will cause memory leak.
  177. auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  178. constexpr size_t kPrimHoldLen = 1;
  179. std::vector<AnfNodePtr> inputs;
  180. inputs.reserve(size + kPrimHoldLen);
  181. inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  182. for (size_t i = 0; i < size; i++) {
  183. MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th element of the target, reverse_: " << reverse_;
  184. std::vector<AnfNodePtr> inputs2;
  185. inputs2.push_back(fn_rec);
  186. if (fn_arg != nullptr) {
  187. inputs2.push_back(fn_arg);
  188. }
  189. size_t pos = (reverse_ ? (size - 1 - i) : i);
  190. (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
  191. [&func_graph, &pos](std::pair<AnfNodePtr, Any> item) {
  192. return func_graph->NewCNodeInOrder(
  193. {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
  194. });
  195. auto call_node = func_graph->NewCNodeInOrder(inputs2);
  196. if (reverse_) {
  197. inputs.insert(inputs.begin() + 1, call_node);
  198. } else {
  199. inputs.emplace_back(call_node);
  200. }
  201. }
  202. return func_graph->NewCNodeInOrder(inputs);
  203. }
// Expand HyperMap over Class (record) arguments: build a MakeRecord of the
// class type whose elements are recursive hyper_map calls over the arguments'
// attributes.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(func_graph);
std::size_t attrSize = type->GetAttributes().size();
// Two leading slots: the MakeRecord primitive and the class type node.
constexpr size_t kPrimAndTypeLen = 2;
std::vector<AnfNodePtr> inputs;
inputs.reserve(attrSize + kPrimAndTypeLen);
inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
inputs.push_back(NewValueNode(type));
// cannot use shared_from_base() also known as this, as it will make a reference cycle on
// hypermap and graph generated, it will cause memory leak.
auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
for (std::size_t i = 0; i < attrSize; i++) {
MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the target, reverse_: " << reverse_;
std::vector<AnfNodePtr> inputs2;
inputs2.push_back(fn_rec);
if (fn_arg) {
inputs2.push_back(fn_arg);
}
size_t size = arg_map.size();
for (size_t j = 0; j < size; j++) {
size_t pos = (reverse_ ? (size - 1 - j) : j);
auto &item = arg_map[pos];
// NOTE(review): GetAttr is indexed with the argument position `pos`, not the
// attribute loop index `i`; the List/Tuple overloads index elements by the
// outer loop position instead — confirm this asymmetry is intended.
inputs2.push_back(
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))}));
}
auto call_node = func_graph->NewCNodeInOrder(inputs2);
if (reverse_) {
// Insert after the MakeRecord primitive and type node to reverse order.
inputs.insert(inputs.begin() + kPrimAndTypeLen, call_node);
} else {
inputs.emplace_back(call_node);
}
}
return func_graph->NewCNodeInOrder(inputs);
}
  240. AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  241. bool found = false;
  242. TypeId id = kObjectTypeEnd;
  243. std::pair<AnfNodePtr, TypePtr> pair;
  244. for (auto &item : arg_map) {
  245. pair = item;
  246. id = item.second->type_id();
  247. if (nonleaf_.count(id)) {
  248. found = true;
  249. break;
  250. }
  251. }
  252. if (found) {
  253. // In a nonleaf situation, all arguments must have the same generic.
  254. bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
  255. if (item.first != pair.first) {
  256. return item.second->type_id() != pair.second->type_id();
  257. }
  258. return false;
  259. });
  260. if (is_not_same) {
  261. std::ostringstream oss;
  262. oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
  263. << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
  264. int64_t idx = 0;
  265. for (auto &item : arg_map) {
  266. oss << ++idx << ": " << item.second->ToString() << "\n";
  267. }
  268. MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
  269. }
  270. }
  271. switch (id) {
  272. case kObjectTypeList: {
  273. auto type = std::static_pointer_cast<List>(pair.second);
  274. return FullMake(type, func_graph, fn_arg, arg_map);
  275. }
  276. case kObjectTypeTuple: {
  277. auto type = std::static_pointer_cast<Tuple>(pair.second);
  278. return FullMake(type, func_graph, fn_arg, arg_map);
  279. }
  280. case kObjectTypeClass: {
  281. auto type = std::static_pointer_cast<Class>(pair.second);
  282. return FullMake(type, func_graph, fn_arg, arg_map);
  283. }
  284. default:
  285. return FullMake(func_graph, fn_arg, arg_map);
  286. }
  287. }
  288. ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
  289. TypePtr type_tensor = std::make_shared<TensorType>();
  290. bool flag = std::any_of(
  291. args_spec_list.begin(), args_spec_list.end(),
  292. [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
  293. if (flag && broadcast_) {
  294. ArgsPairList ret;
  295. for (auto &item : args_spec_list) {
  296. if (!IsSubType(item.second, type_tensor)) {
  297. TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
  298. ret.push_back(std::make_pair(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimScalarToArray), item.first}),
  299. type_tensor_ele));
  300. } else {
  301. ret.push_back(std::make_pair(item.first, item.second));
  302. }
  303. }
  304. return ret;
  305. }
  306. return args_spec_list;
  307. }
  308. FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
  309. FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
  310. ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  311. ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  312. ptr_graph->debug_info()->set_name("hyper_map");
  313. AnfNodePtr ptrFnArg = nullptr;
  314. std::size_t i = 0;
  315. ArgsPairList argmap;
  316. ArgsPairList argmap2;
  317. if (fn_leaf_ == nullptr) {
  318. ptrFnArg = ptr_graph->add_parameter();
  319. i = 1;
  320. }
  321. std::size_t size = args_spec_list.size();
  322. for (; i < size; ++i) {
  323. argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
  324. }
  325. argmap2 = Harmonize(ptr_graph, argmap);
  326. ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
  327. return ptr_graph;
  328. }
  329. abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  330. if (fn_leaf_ == nullptr) {
  331. if (args_spec_list.empty()) {
  332. MS_LOG(EXCEPTION) << "The args spec list is empty.";
  333. }
  334. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  335. // Assert that hypermap's function param does not contain free variables
  336. if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
  337. auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
  338. auto func_graph = graph_func->func_graph();
  339. if (func_graph->parent() != nullptr) {
  340. MS_LOG(EXCEPTION) << "HyperMap don't support Closure with free variable yet.";
  341. }
  342. }
  343. }
  344. AbstractBasePtrList broadened;
  345. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
  346. [](const AbstractBasePtr &arg) -> AbstractBasePtr {
  347. MS_EXCEPTION_IF_NULL(arg);
  348. return arg->Broaden();
  349. });
  350. return broadened;
  351. }
// Expose HyperMap to Python as `HyperMap_` with two constructors:
// HyperMap_(reverse, ops) and HyperMap_(reverse).
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
(void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
.def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"),
py::arg("ops"))
.def(py::init<bool>(), py::arg("reverse"));
}));
  358. bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
  359. for (size_t i = 0; i < tuple->size(); ++i) {
  360. if (!(*tuple)[i]->isa<abstract::AbstractUndetermined>() &&
  361. !((*tuple)[i]->isa<abstract::AbstractTuple>() &&
  362. CheckSequenceAllTensor((*tuple)[i]->cast<abstract::AbstractTuplePtr>()))) {
  363. return false;
  364. }
  365. }
  366. return true;
  367. }
  368. bool CheckTailGradFristSequence(const abstract::AbstractSequeuePtr &sequeue, bool enable_tuple_grad) {
  369. return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
  370. ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
  371. (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
  372. (*sequeue)[1]->BuildType()->isa<Number>()) ||
  373. ((*sequeue)[1]->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
  374. CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
  375. }
// Build a graph that extracts gradients from the bprop result sequence
// (EnvInstance, d_input0, d_input1, ...) according to tail_type_:
//   kGradFirst      - return the gradient of the first input only
//   kGradByPosition - return the gradients at the indices listed in `pos`
//   kGradAll        - return the gradients of all tensor (and, when enabled,
//                     scalar) inputs
//   otherwise       - drop element 0 (the env) and return the rest
FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue,
const abstract::AbstractSequeuePtr &pos) const {
MS_EXCEPTION_IF_NULL(sequeue);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("tail");
AnfNodePtr ptrTup = ret->add_parameter();
std::vector<AnfNodePtr> elems;
PrimitivePtr op = nullptr;
// Mirror the input container kind (tuple vs list) in the output.
if (sequeue->isa<AbstractTuple>()) {
elems.push_back(NewValueNode(prim::kPrimMakeTuple));
op = prim::kPrimTupleGetItem;
} else {
elems.push_back(NewValueNode(prim::kPrimMakeList));
op = prim::kPrimListGetItem;
}
if (tail_type_ == kGradFirst) {
// Element 1 holds the gradient of the first input (element 0 is the env).
if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
} else {
// Nothing differentiable to return: output an empty tuple.
ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
}
return ret;
}
if (tail_type_ == kGradByPosition) {
if (pos == nullptr) {
MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!";
}
std::vector<AnfNodePtr> pos_elems;
PrimitivePtr pos_op = nullptr;
if (pos->isa<AbstractTuple>()) {
pos_elems.push_back(NewValueNode(prim::kPrimMakeTuple));
pos_op = prim::kPrimTupleGetItem;
} else {
pos_elems.push_back(NewValueNode(prim::kPrimMakeList));
pos_op = prim::kPrimListGetItem;
}
AnfNodePtr pos_value = nullptr;
AnfNodePtr pos_value_adjust = nullptr;
auto ptrpos = ret->add_parameter();
// Positions are 0-based w.r.t. the forward inputs; +1 skips the env stored
// at index 0 of the bprop result.
if (pos->size() == 1) {
// Single position: return the gradient itself, not a one-element sequence.
pos_value = ret->NewCNode({NewValueNode(pos_op), ptrpos, NewValueNode(SizeToLong(0))});
pos_value_adjust = ret->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))});
if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, pos_value_adjust}));
} else {
ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
}
return ret;
} else {
for (size_t i = 0; i < pos->size(); ++i) {
pos_value = ret->NewCNode({NewValueNode(pos_op), ptrpos, NewValueNode(SizeToLong(i))});
pos_value_adjust = ret->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))});
pos_elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, pos_value_adjust}));
}
}
ret->set_output(ret->NewCNodeInOrder(pos_elems));
return ret;
}
// kGradAll / default: start from element 1 to skip the env.
for (size_t i = 1; i < sequeue->size(); ++i) {
if (tail_type_ == kGradAll) {
MS_EXCEPTION_IF_NULL((*sequeue)[i]);
// Keep only tensor-like gradients, plus scalar gradients when the
// MS_CTX_GRAD_FOR_SCALAR context flag is enabled.
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
(*sequeue)[i]->BuildType()->isa<Number>())) {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
}
} else {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
}
}
ret->set_output(ret->NewCNodeInOrder(elems));
return ret;
}
  450. FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  451. if (args_spec_list.size() < 1) {
  452. MS_LOG(EXCEPTION) << "Tail requires a non-empty tuple.";
  453. }
  454. AbstractBasePtr a = args_spec_list[0];
  455. if (a->isa<AbstractTuple>() || a->isa<AbstractList>()) {
  456. if (args_spec_list.size() > 1) {
  457. AbstractBasePtr pos = args_spec_list[1];
  458. if (pos->isa<AbstractTuple>() || pos->isa<AbstractList>()) {
  459. return GenerateSequeueFuncGraph(a->cast<abstract::AbstractSequeuePtr>(),
  460. pos->cast<abstract::AbstractSequeuePtr>());
  461. }
  462. MS_LOG(EXCEPTION) << "'Tail' arg1 must be AbstractTuple or AbstractList, but: " << pos->ToString();
  463. }
  464. return GenerateSequeueFuncGraph(a->cast<abstract::AbstractSequeuePtr>());
  465. }
  466. MS_LOG(EXCEPTION) << "'Tail' arg0 must be AbstractTuple or AbstractList, but: " << a->ToString();
  467. }
// Expose Tail to Python as `Tail_`, constructed from its name string.
REGISTER_PYBIND_DEFINE(
Tail_, ([](const py::module *m) {
(void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
}));
  472. FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  473. int64_t tuple_size = SizeToLong(args_spec_list.size());
  474. std::ostringstream ss;
  475. ss << "▶make_tuple_" << tuple_size;
  476. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  477. fg->debug_info()->set_name(ss.str());
  478. std::vector<AnfNodePtr> params;
  479. params.push_back(NewValueNode(prim::kPrimMakeTuple));
  480. for (int64_t i = 0; i < tuple_size; ++i) {
  481. params.push_back(fg->add_parameter());
  482. }
  483. // make fprob first result, maketuple's forward result.
  484. AnfNodePtr out = fg->NewCNodeInOrder(params);
  485. // make fprob second result, maketuple's backward function.
  486. FuncGraphPtr b = std::make_shared<FuncGraph>();
  487. ss.clear();
  488. ss << "◀make_tuple_" << tuple_size;
  489. b->debug_info()->set_name(ss.str());
  490. AnfNodePtr dout = b->add_parameter();
  491. std::vector<AnfNodePtr> grads;
  492. grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  493. grads.push_back(NewValueNode(newenv));
  494. for (int64_t i = 0; i < tuple_size; ++i) {
  495. grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
  496. }
  497. b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  498. b->set_output(b->NewCNodeInOrder(grads));
  499. fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  500. fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  501. (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
  502. return fg;
  503. }
  504. FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  505. int64_t list_size = SizeToLong(args_spec_list.size());
  506. std::ostringstream ss;
  507. ss << "▶make_list_" << list_size;
  508. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  509. fg->debug_info()->set_name(ss.str());
  510. std::vector<AnfNodePtr> params;
  511. params.push_back(NewValueNode(prim::kPrimMakeList));
  512. for (int64_t i = 0; i < list_size; ++i) {
  513. params.push_back(fg->add_parameter());
  514. }
  515. // make fprob first result, maketuple's forward result.
  516. AnfNodePtr out = fg->NewCNodeInOrder(params);
  517. // make fprob second result, maketuple's backward function.
  518. FuncGraphPtr b = std::make_shared<FuncGraph>();
  519. ss.clear();
  520. ss << "◀make_list_" << list_size;
  521. b->debug_info()->set_name(ss.str());
  522. AnfNodePtr dout = b->add_parameter();
  523. std::vector<AnfNodePtr> grads;
  524. grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  525. grads.push_back(NewValueNode(newenv));
  526. for (int64_t i = 0; i < list_size; ++i) {
  527. grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
  528. }
  529. b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  530. b->set_output(b->NewCNodeInOrder(grads));
  531. fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  532. fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  533. (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
  534. return fg;
  535. }
// GradOperation: meta graph that turns a forward function into its gradient
// function.  Flags select which gradients the resulting function returns:
//   get_all         - gradients w.r.t. all inputs
//   get_by_list     - gradients w.r.t. the supplied weight list
//   sens_param      - caller provides the sensitivity (dout) explicitly
//   get_by_position - gradients w.r.t. inputs at the supplied positions
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param,
bool get_by_position)
: MetaFuncGraph(name),
get_all_(get_all),
get_by_list_(get_by_list),
sens_param_(sens_param),
get_by_position_(get_by_position) {
if (get_by_position) {
signatures_ =
// def grad(func:read, weight_list:ref, position_list:ref):
std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
{"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault},
{"position_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
} else if (get_by_list) {
signatures_ =
// def grad(func:read, weight_list:ref):
std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
{"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
}
}
// Build the child graph that calls the K-transformed function `k`, splits its
// result into (forward_result, bprop), and delegates gradient wiring to
// GradByParameter.  `weights` / `position` may be nullptr when the
// corresponding mode is not requested; `weight_args` is an alternative
// piecewise form of the weights node.
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights, const AnfNodePtr &position,
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
const std::vector<AnfNodePtr> &weight_args) {
FuncGraphPtr k_child = std::make_shared<FuncGraph>();
k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr weights_node = nullptr;
AnfNodePtr position_node = nullptr;
if (weights != nullptr) {
weights_node = weights;
} else if (!weight_args.empty()) {
// No single weights node supplied: assemble one from the individual args.
weights_node = k_child->NewCNodeInOrder(weight_args);
}
if (position != nullptr) {
position_node = position;
}
// Mirror each forward-graph parameter on the child graph, then call k.
std::vector<AnfNodePtr> inputs;
inputs.push_back(k);
for (size_t i = 0; i < forward_graph_params.size(); ++i) {
inputs.push_back(k_child->add_parameter());
}
auto k_app = k_child->NewCNodeInOrder(inputs);
auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem);
// k_app evaluates to a pair: element 0 is the forward result, element 1 the
// backpropagation function.
auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});
GradByParameter(k_child, f_app, bprop, weights_node, position_node, enable_tuple_grad);
return k_child;
}
// Do grad by the parameter of GradOperation.
// Applies `bprop` to the sensitivity (either a caller-supplied parameter when
// sens_param_ is set, or ones_like of the forward result) and selects which
// gradients become the graph output based on the get_* flags.
void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
const AnfNodePtr &weights, const AnfNodePtr &position, bool enable_tuple_grad) {
MS_EXCEPTION_IF_NULL(k_child);
AnfNodePtr bprop_arg = nullptr;
if (sens_param_) {
// Sensitivity is passed in by the caller as an extra parameter.
bprop_arg = k_child->add_parameter();
} else {
auto ones_like = prim::GetPythonOps("ones_like");
bprop_arg = k_child->NewCNodeInOrder({NewValueNode(ones_like), f_app});
}
// b_app = bprop(dout): (EnvInstance(grads wrt params), d_input0, d_input1, ...)
AnfNodePtr b_app = k_child->NewCNodeInOrder({bprop, bprop_arg});
CNodePtr fv_bprop = nullptr;
if (get_by_list_) {
// python code: grads = hyper_map(F.partial(env_get, env), weights)
AnfNodePtr env =
k_child->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(0))});
AnfNodePtr partial_env_get =
k_child->NewCNodeInOrder({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
fv_bprop = k_child->NewCNodeInOrder({NewValueNode(hyper_map), partial_env_get, weights});
}
CNodePtr inputs_bprop = nullptr;
if (get_by_position_) {
// Grad-by-position takes precedence and returns immediately.
TailPtr tail_grad_by_position = std::make_shared<Tail>("tail_grad_by_position", kGradByPosition);
inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_by_position), b_app, position});
k_child->set_output(inputs_bprop);
return;
}
if (get_all_) {
TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll);
inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_all), b_app});
}
// Gradients wrt inputs and parameters
if (fv_bprop != nullptr && inputs_bprop != nullptr) {
k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
return;
}
// Gradients wrt parameters
if (fv_bprop != nullptr) {
k_child->set_output(fv_bprop);
return;
}
// Gradients wrt inputs
if (inputs_bprop != nullptr) {
k_child->set_output(inputs_bprop);
return;
}
// Gradients wrt first input.
// b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
// so obtain first input grad by setting tail_type of Tail to kGradFirst.
TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
tail_grad_first->set_enable_tuple_grad(enable_tuple_grad);
k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
}
// Generate the graph.
// Wraps the forward func graph in J and returns a graph whose output is the
// grad-taking child graph produced by GetGrad.
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION)
      << "'GradOperation' requires a forward network or function as an input, while the input is empty.";
  }
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "'GradOperation' arg0 must be a 'Function' or 'Cell', but got "
                      << args_spec_list[0]->ToString();
  }

  // Waiting for implementation.
  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);

  FuncGraphPtr forward_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(forward_graph);
  // Keep the forward graph intact so J can be applied to it as a whole.
  forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);

  FuncGraphPtr grad_fg = nullptr;
  {
    // Scope the trace guard so only the graph construction is attributed.
    TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    grad_fg = std::make_shared<FuncGraph>();
  }
  auto nparam = forward_graph->parameters().size();

  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  grad_fg->debug_info()->set_name(ss.str());
  ParameterPtr param_graph = grad_fg->add_parameter();

  AnfNodePtr weights = nullptr;
  if (get_by_list_) {
    weights = grad_fg->add_parameter();
  }
  AnfNodePtr position = nullptr;
  if (get_by_position_) {
    // NOTE(review): when get_by_list_ is also set, this add_parameter()
    // overwrites the weights parameter added above, leaving the first one
    // unbound — confirm both flags are never set simultaneously by callers.
    weights = grad_fg->add_parameter();
    position = grad_fg->add_parameter();
  }

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimJ));
  inputs.push_back(param_graph);
  auto j = grad_fg->NewCNodeInOrder(inputs);
  // df is checked in GetGrad
  FuncGraphPtr k_child = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    k_child = GetGrad(j, weights, position, forward_graph->parameters(), forward_graph->has_flag("enable_tuple_grad"));
  }
  grad_fg->set_output(NewValueNode(k_child));
  return grad_fg;
}
// Expose GradOperation to Python as mindspore's GradOperation_ primitive; the
// 5-arg init mirrors the flags consumed by GradByParameter above.
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
                         (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
                           *m, "GradOperation_")
                           .def(py::init<std::string &>(), py::arg("fn"))
                           .def(py::init<std::string &, bool, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
                                py::arg("get_by_list"), py::arg("sens_param"), py::arg("get_by_position"));
                       }));
  696. // Generate the ListMap func graph.
  697. FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  698. size_t args_num = args_spec_list.size();
  699. // args: fn, list1, list2, ...
  700. if (args_num < 2) {
  701. MS_LOG(EXCEPTION) << "list_map takes at least two arguments";
  702. }
  703. for (size_t i = 1; i < args_num; ++i) {
  704. if (typeid(args_spec_list[i]) != typeid(AbstractBase)) {
  705. // The function currently not be use
  706. MS_LOG(EXCEPTION) << "list_map requires lists, not {t}'";
  707. }
  708. }
  709. FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
  710. fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  711. fg_ptr->debug_info()->set_name("list_map");
  712. AnfNodePtr fn = fg_ptr->add_parameter();
  713. std::vector<AnfNodePtr> lists;
  714. for (size_t i = 1; i < args_num; ++i) {
  715. lists.push_back(fg_ptr->add_parameter());
  716. }
  717. std::vector<AnfNodePtr> iters;
  718. (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
  719. return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("list_iter")), item});
  720. });
  721. std::vector<AnfNodePtr> nexts;
  722. (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
  723. return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("next")), item});
  724. });
  725. std::vector<AnfNodePtr> values;
  726. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
  727. return fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item});
  728. });
  729. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
  730. return fg_ptr->NewCNodeInOrder(
  731. {NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast<int64_t>(1))});
  732. });
  733. (void)values.insert(values.begin(), fn);
  734. AnfNodePtr cnode_graph = fg_ptr->NewCNodeInOrder(values);
  735. AnfNodePtr resl = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimMakeList), cnode_graph});
  736. FuncGraphPtr fgnext_ptr = std::make_shared<FuncGraph>();
  737. fgnext_ptr->debug_info()->set_name("body");
  738. FuncGraphPtr fgcond_ptr = std::make_shared<FuncGraph>();
  739. fgcond_ptr->debug_info()->set_name("cond");
  740. MakeCond(lists, fgnext_ptr, fgcond_ptr);
  741. MakeNext(lists, fgcond_ptr, fgnext_ptr);
  742. CNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(fgcond_ptr), fn, resl});
  743. auto inputs = output_cnode->inputs();
  744. (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  745. output_cnode->set_inputs(inputs);
  746. fg_ptr->set_output(output_cnode);
  747. return fg_ptr;
  748. }
  749. void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
  750. const FuncGraphPtr &fg_ptr) {
  751. MS_EXCEPTION_IF_NULL(fg_ptr);
  752. AnfNodePtr fn = fg_ptr->add_parameter();
  753. AnfNodePtr resl = fg_ptr->add_parameter();
  754. std::vector<AnfNodePtr> iters;
  755. (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
  756. [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
  757. std::vector<AnfNodePtr> hasnexts;
  758. (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) {
  759. return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("hasnext")), item});
  760. });
  761. // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
  762. FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
  763. fgtrue_ptr->debug_info()->set_name("ftrue");
  764. fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  765. CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNodeInOrder({NewValueNode(fgnext_ptr), fn, resl});
  766. auto inputs = fgtrue_output_cnode->inputs();
  767. (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  768. fgtrue_output_cnode->set_inputs(inputs);
  769. fgtrue_ptr->set_output(fgtrue_output_cnode);
  770. FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
  771. fgfalse_ptr->debug_info()->set_name("ffalse");
  772. fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  773. fgfalse_ptr->set_output(resl);
  774. AnfNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
  775. NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)});
  776. fgtrue_ptr->set_output(output_cnode);
  777. }
  778. void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
  779. const FuncGraphPtr &fg_ptr) {
  780. MS_EXCEPTION_IF_NULL(fg_ptr);
  781. AnfNodePtr fn = fg_ptr->add_parameter();
  782. std::vector<AnfNodePtr> iters;
  783. (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
  784. [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
  785. std::vector<AnfNodePtr> nexts;
  786. (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
  787. return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("next")), item});
  788. });
  789. std::vector<AnfNodePtr> values;
  790. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
  791. return fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item, nullptr});
  792. });
  793. iters.clear();
  794. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
  795. return fg_ptr->NewCNodeInOrder(
  796. {NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast<int64_t>(1))});
  797. });
  798. (void)values.insert(values.begin(), fn);
  799. AnfNodePtr cnode_graph = fg_ptr->NewCNodeInOrder(values);
  800. AnfNodePtr resl = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimListAppend), cnode_graph});
  801. CNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(fgcond_ptr), fn, resl});
  802. auto inputs = output_cnode->inputs();
  803. (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  804. output_cnode->set_inputs(inputs);
  805. fg_ptr->set_output(output_cnode);
  806. }
  807. FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  808. // args: tuple1, tuple2
  809. abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
  810. AbstractBasePtr abs_a = args_spec_list[0];
  811. AbstractBasePtr abs_b = args_spec_list[1];
  812. abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
  813. abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
  814. if (a_tuple == nullptr || b_tuple == nullptr) {
  815. TypePtrList types;
  816. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
  817. [](const AbstractBasePtr &arg) -> TypePtr {
  818. MS_EXCEPTION_IF_NULL(arg);
  819. return arg->BuildType();
  820. });
  821. auto stub = GenerateStubFunc(types);
  822. if (stub != nullptr) {
  823. MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
  824. << ", function: " << stub->ToString();
  825. return stub;
  826. }
  827. MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple, but " << args_spec_list[0]->ToString() << ", "
  828. << args_spec_list[1]->ToString();
  829. }
  830. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  831. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  832. AnfNodePtr p_tup_a = ret->add_parameter();
  833. AnfNodePtr p_tup_b = ret->add_parameter();
  834. std::vector<AnfNodePtr> elems;
  835. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  836. int64_t tuple_size = SizeToLong(a_tuple->size());
  837. for (int64_t i = 0; i < tuple_size; ++i) {
  838. elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)}));
  839. }
  840. tuple_size = SizeToLong(b_tuple->size());
  841. for (int64_t i = 0; i < tuple_size; ++i) {
  842. elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)}));
  843. }
  844. ret->set_output(ret->NewCNodeInOrder(elems));
  845. return ret;
  846. }
  847. int64_t GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
  848. MS_EXCEPTION_IF_NULL(scalar);
  849. return GetValue<int64_t>(scalar->BuildValue());
  850. }
  851. int64_t GetPositiveIndex(int64_t index, int64_t length) {
  852. if (index < 0) {
  853. index += length;
  854. }
  855. return index;
  856. }
  857. int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, const std::string &member_name) {
  858. MS_EXCEPTION_IF_NULL(member);
  859. if (member->isa<AbstractScalar>()) {
  860. return GetArgScalarValue(dyn_cast<AbstractScalar>(member), member_name);
  861. }
  862. if (member->isa<AbstractNone>()) {
  863. return default_value;
  864. }
  865. MS_LOG(EXCEPTION) << member_name << " should be a AbstractScalar or AbstractNone, but got " << member->ToString();
  866. }
// Normalize a Python-style slice against a tuple of known size, producing
// concrete start/stop/step out-params ready to drive an unrolled loop.
// Follows Python slice semantics: out-of-range bounds are clamped, negative
// indices count from the end, and step must be non-zero.
void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int64_t *start_index,
                                 int64_t *stop_index, int64_t *step_value) {
  MS_EXCEPTION_IF_NULL(tuple);
  MS_EXCEPTION_IF_NULL(slice);
  MS_EXCEPTION_IF_NULL(start_index);
  MS_EXCEPTION_IF_NULL(stop_index);
  MS_EXCEPTION_IF_NULL(step_value);
  const std::string start_name("Slice start index");
  const std::string stop_name("Slice stop index");
  const std::string step_name("Slice step value");
  int64_t tuple_size = SizeToLong(tuple->size());
  int64_t start_default = 0;
  int64_t stop_default = tuple_size;
  int64_t step_default = 1;

  // Resolve step first: its sign decides the defaults for start/stop.
  *step_value = CheckSliceMember(slice->step(), step_default, step_name);
  if (*step_value == 0) {
    MS_EXCEPTION(ValueError) << "TupleSlice require the step value could not be 0, but got 0.";
  }

  if (*step_value < 0) {
    // Reverse iteration: default to [size-1 .. -1).
    start_default = tuple_size - 1;
    stop_default = -1;
  }

  *start_index = CheckSliceMember(slice->start(), start_default, start_name);
  *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);

  // Clamp indices that reach past either end of the tuple.
  if (*start_index < -tuple_size) *start_index = 0;
  if (*stop_index > tuple_size) *stop_index = tuple_size;
  // Ranges entirely outside the tuple produce an empty slice.
  if (*start_index > tuple_size || *stop_index < -tuple_size) {
    *start_index = 0;
    *stop_index = 0;
  }

  *start_index = GetPositiveIndex(*start_index, tuple_size);
  if (!slice->stop()->isa<AbstractNone>()) {
    // Only an explicit stop is wrapped; the default (-1 for negative step)
    // must stay as-is to mean "one before the beginning".
    *stop_index = GetPositiveIndex(*stop_index, tuple_size);
  }
}
  902. FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  903. // slice a tuple
  904. // args: tuple, start index, end index, step
  905. const std::string op_name("TupleSlice");
  906. constexpr size_t arg_size = 2;
  907. abstract::CheckArgsSize(op_name, args_spec_list, arg_size);
  908. AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  909. AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);
  910. int64_t start_index;
  911. int64_t stop_index;
  912. int64_t step_value;
  913. GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
  914. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  915. ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  916. AnfNodePtr p_tuple = ret->add_parameter();
  917. (void)ret->add_parameter();
  918. std::vector<AnfNodePtr> elems;
  919. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  920. if (step_value > 0) {
  921. for (int64_t index = start_index; index < stop_index; index = index + step_value) {
  922. elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
  923. }
  924. } else {
  925. for (int64_t index = start_index; index > stop_index; index = index + step_value) {
  926. elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
  927. }
  928. }
  929. ret->set_output(ret->NewCNodeInOrder(elems));
  930. return ret;
  931. }
  932. FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  933. // select indexed item
  934. // args: tuple of items, index
  935. const std::string op_name = std::string("TupleGetItemTensor");
  936. const size_t inputs_size = 2;
  937. abstract::CheckArgsSize(op_name, args_spec_list, inputs_size);
  938. auto ret_graph = std::make_shared<FuncGraph>();
  939. ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  940. auto functions = ret_graph->add_parameter();
  941. auto index = ret_graph->add_parameter();
  942. ret_graph->set_output(ret_graph->NewCNodeInOrder({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
  943. return ret_graph;
  944. }
// Expose the tuple meta-func-graphs to Python; each takes only a name string.
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
                         (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
                           .def(py::init<std::string &>());
                       }));
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
                         (void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
                           .def(py::init<std::string &>());
                       }));
REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
                           *m, "TupleGetItemTensor_")
                           .def(py::init<std::string &>());
                       }));
  958. } // namespace prim
  959. } // namespace mindspore