
composite.cc

  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "operator/composite/composite.h"
  19. #include <algorithm>
  20. #include <utility>
  21. #include <sstream>
  22. #include "ir/anf.h"
  23. #include "pipeline/static_analysis/abstract_value.h"
  24. #include "pipeline/static_analysis/abstract_function.h"
  25. #include "pipeline/static_analysis/dshape.h"
  26. #include "pipeline/static_analysis/param_validator.h"
  27. #include "operator/cc_implementations.h"
  28. #include "optimizer/opt.h"
  29. #include "utils/symbolic.h"
  30. #include "pybind_api/api_register.h"
  31. #include "./common.h"
  32. #include "ir/signature.h"
  33. #include "debug/trace.h"
  34. namespace mindspore {
  35. // namespace to support composite operators definition
  36. namespace prim {
  37. using AbstractTensor = mindspore::abstract::AbstractTensor;
  38. using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
  39. using mindspore::abstract::AbstractAttribute;
  40. using mindspore::abstract::AbstractBase;
  41. using mindspore::abstract::AbstractClass;
  42. using mindspore::abstract::AbstractDictionary;
  43. using mindspore::abstract::AbstractDictionaryPtr;
  44. using mindspore::abstract::AbstractEllipsis;
  45. using mindspore::abstract::AbstractEllipsisPtr;
  46. using mindspore::abstract::AbstractFunction;
  47. using mindspore::abstract::AbstractFunctionPtr;
  48. using mindspore::abstract::AbstractList;
  49. using mindspore::abstract::AbstractNone;
  50. using mindspore::abstract::AbstractScalar;
  51. using mindspore::abstract::AbstractSlice;
  52. using mindspore::abstract::AbstractTuple;
  53. ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul},
  54. {"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod},
  55. {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt},
  56. {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe},
  57. {"__ge__", kPrimScalarGe}};
  58. const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail");
  59. // Copied from the Python API: reduce.
  60. // Apply a function of two arguments cumulatively to the items of a sequence,
  61. // from left to right, so as to reduce the sequence to a single value. For example,
  62. // reduce(lambda x, y: x + y, [1, 2, 3, 4, 5]) calculates ((((1 + 2) + 3) + 4) + 5).
  63. AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) {
  64. std::shared_ptr<Any> ret;
  65. size_t size = list.size();
  66. if (size < 2) {
  67. MS_LOG(EXCEPTION) << "Reduce requires at least 2 inputs, but got " << size << ".";
  68. }
  69. AnyPtrList input;
  70. input.push_back(list[0]);
  71. input.push_back(list[1]);
  72. ret = std::make_shared<Any>(func(input));
  73. for (size_t i = 2; i < size; ++i) {
  74. input.clear();
  75. input.push_back(ret);
  76. input.push_back(list[i]);
  77. ret = std::make_shared<Any>(func(input));
  78. }
  79. return ret;
  80. }
  81. AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector<AnfNodePtr> &list) {
  82. size_t size = list.size();
  83. if (size < 2) {
  84. MS_LOG(EXCEPTION) << "Reduce requires at least 2 inputs, but got " << size << ".";
  85. }
  86. std::vector<AnfNodePtr> input;
  87. input.push_back(list[0]);
  88. input.push_back(list[1]);
  89. AnfNodePtr ret = func(input);
  90. for (size_t i = 2; i < size; ++i) {
  91. input.clear();
  92. input.push_back(ret);
  93. input.push_back(list[i]);
  94. ret = func(input);
  95. }
  96. return ret;
  97. }
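// Usage sketch (illustrative only; `make_add` is a hypothetical helper that builds an addition CNode
// from two input nodes): Reduce(make_add, {a, b, c}) folds left and yields the node for (a + b) + c.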
  98. ValuePtr kCompositeHyperMap = std::make_shared<HyperMap>();
  99. void HyperMap::Init() {
  100. if (fn_leaf_) {
  101. name_ = "hyper_map[" + fn_leaf_->name() + "]";
  102. }
  103. signatures_ =
  104. // def hypermap(func:read, *args:ref):
  105. std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
  106. {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
  107. }
  108. HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
  109. : MetaFuncGraph("hyper_map"),
  110. fn_leaf_(fn_leaf),
  111. broadcast_(false),
  112. nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
  113. Init();
  114. }
  115. HyperMap::HyperMap(const HyperMap &h)
  116. : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
  117. Init();
  118. }
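// Leaf case of FullMake: the pivot type is not a List/Tuple/Class container, so the leaf function
// (the explicit fn_arg if given, otherwise the stored fn_leaf_) is applied to the arguments directly.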
  119. AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  120. const ArgsPairList &arg_map) {
  121. MS_EXCEPTION_IF_NULL(func_graph);
  122. std::vector<AnfNodePtr> inputs;
  123. if (fn_arg != nullptr) {
  124. inputs.push_back(fn_arg);
  125. } else {
  126. inputs.push_back(NewValueNode(fn_leaf_));
  127. }
  128. (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
  129. [](const std::pair<AnfNodePtr, Any> &item) { return item.first; });
  130. return func_graph->NewCNode(inputs);
  131. }
  132. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
  133. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  134. MS_EXCEPTION_IF_NULL(func_graph);
  135. MS_EXCEPTION_IF_NULL(type);
  136. std::size_t size = type->elements().size();
  137. bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
  138. auto lhs = std::static_pointer_cast<List>(item.second);
  139. MS_EXCEPTION_IF_NULL(lhs);
  140. return lhs->elements().size() != size;
  141. });
  142. if (is_not_same) {
  143. MS_LOG(EXCEPTION) << "Lists in HyperMap must have the same length.";
  144. }
  145. // Cannot use shared_from_base() (i.e. `this`) here: it would create a reference cycle between the
  146. // hypermap and the generated graph and cause a memory leak.
  147. auto fn_rec = std::make_shared<HyperMap>(*this);
  148. std::vector<AnfNodePtr> inputs;
  149. inputs.push_back(NewValueNode(prim::kPrimMakeList));
  150. for (int i = 0; i < SizeToInt(size); ++i) {
  151. std::vector<AnfNodePtr> inputs2;
  152. inputs2.push_back(NewValueNode(fn_rec));
  153. if (fn_arg != nullptr) {
  154. inputs2.push_back(fn_arg);
  155. }
  156. (void)std::transform(
  157. arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
  158. [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
  159. return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
  160. });
  161. inputs.push_back(func_graph->NewCNode(inputs2));
  162. }
  163. return func_graph->NewCNode(inputs);
  164. }
  165. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
  166. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  167. MS_EXCEPTION_IF_NULL(func_graph);
  168. MS_EXCEPTION_IF_NULL(type);
  169. std::size_t size = type->elements().size();
  170. bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
  171. auto lhs = std::static_pointer_cast<Tuple>(item.second);
  172. MS_EXCEPTION_IF_NULL(lhs);
  173. return lhs->elements().size() != size;
  174. });
  175. if (is_not_same) {
  176. MS_LOG(EXCEPTION) << "Tuples in HyperMap must have the same length.";
  177. }
  178. // Cannot use shared_from_base() (i.e. `this`) here: it would create a reference cycle between the
  179. // hypermap and the generated graph and cause a memory leak.
  180. auto fn_rec = std::make_shared<HyperMap>(*this);
  181. std::vector<AnfNodePtr> inputs;
  182. inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  183. for (int i = 0; i < SizeToInt(size); ++i) {
  184. std::vector<AnfNodePtr> inputs2;
  185. inputs2.push_back(NewValueNode(fn_rec));
  186. if (fn_arg != nullptr) {
  187. inputs2.push_back(fn_arg);
  188. }
  189. (void)std::transform(
  190. arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
  191. return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
  192. });
  193. inputs.push_back(func_graph->NewCNode(inputs2));
  194. }
  195. return func_graph->NewCNode(inputs);
  196. }
  197. AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
  198. const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  199. MS_EXCEPTION_IF_NULL(type);
  200. MS_EXCEPTION_IF_NULL(func_graph);
  201. std::vector<AnfNodePtr> inputs;
  202. inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  203. inputs.push_back(NewValueNode(type));
  204. // Cannot use shared_from_base() (i.e. `this`) here: it would create a reference cycle between the
  205. // hypermap and the generated graph and cause a memory leak.
  206. std::shared_ptr<mindspore::MetaFuncGraph> fn_rec = std::make_shared<HyperMap>(*this);
  207. std::size_t attrSize = type->GetAttributes().size();
  208. for (std::size_t i = 0; i < attrSize; ++i) {
  209. std::vector<AnfNodePtr> inputs2;
  210. inputs2.push_back(NewValueNode(fn_rec));
  211. if (fn_arg) {
  212. inputs2.push_back(fn_arg);
  213. }
  214. int j = 0;
  215. for (auto item : arg_map) {
  216. inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
  217. j++;
  218. }
  219. inputs.push_back(func_graph->NewCNode(inputs2));
  220. }
  221. return func_graph->NewCNode(inputs);
  222. }
  223. AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  224. bool found = false;
  225. TypeId id = kObjectTypeEnd;
  226. std::pair<AnfNodePtr, TypePtr> pair;
  227. for (auto &item : arg_map) {
  228. pair = item;
  229. id = item.second->type_id();
  230. if (nonleaf_.count(id)) {
  231. found = true;
  232. break;
  233. }
  234. }
  235. if (found) {
  236. // In a nonleaf situation, all arguments must have the same generic (container) type.
  237. bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
  238. if (item.first != pair.first) {
  239. return item.second->type_id() != pair.second->type_id();
  240. }
  241. return false;
  242. });
  243. if (is_not_same) {
  244. std::ostringstream oss;
  245. oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
  246. << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
  247. int idx = 0;
  248. for (auto &item : arg_map) {
  249. oss << ++idx << ": " << item.second->ToString() << "\n";
  250. }
  251. MS_LOG(EXCEPTION) << "HyperMap cannot match the types of all input arguments.\n" << oss.str();
  252. }
  253. }
  254. switch (id) {
  255. case kObjectTypeList: {
  256. auto type = std::static_pointer_cast<List>(pair.second);
  257. return FullMake(type, func_graph, fn_arg, arg_map);
  258. }
  259. case kObjectTypeTuple: {
  260. auto type = std::static_pointer_cast<Tuple>(pair.second);
  261. return FullMake(type, func_graph, fn_arg, arg_map);
  262. }
  263. case kObjectTypeClass: {
  264. auto type = std::static_pointer_cast<Class>(pair.second);
  265. return FullMake(type, func_graph, fn_arg, arg_map);
  266. }
  267. default:
  268. return FullMake(pair.second, func_graph, fn_arg, arg_map);
  269. }
  270. }
  271. ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
  272. TypePtr type_tensor = std::make_shared<TensorType>();
  273. bool flag = std::any_of(
  274. args_spec_list.begin(), args_spec_list.end(),
  275. [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
  276. if (flag && broadcast_) {
  277. ArgsPairList ret;
  278. for (auto &item : args_spec_list) {
  279. if (!IsSubType(item.second, type_tensor)) {
  280. TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
  281. ret.push_back(
  282. std::make_pair(func_graph->NewCNode({NewValueNode(prim::kPrimScalarToArray), item.first}), type_tensor_ele));
  283. } else {
  284. ret.push_back(std::make_pair(item.first, item.second));
  285. }
  286. }
  287. return ret;
  288. }
  289. return args_spec_list;
  290. }
  291. FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
  292. FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  293. ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  294. ptrGraph->debug_info()->set_name("hyper_map");
  295. AnfNodePtr ptrFnArg = nullptr;
  296. std::size_t i = 0;
  297. ArgsPairList argmap;
  298. ArgsPairList argmap2;
  299. if (fn_leaf_ == nullptr) {
  300. ptrFnArg = ptrGraph->add_parameter();
  301. i = 1;
  302. }
  303. std::size_t size = args_spec_list.size();
  304. for (; i < size; ++i) {
  305. argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
  306. }
  307. argmap2 = Harmonize(ptrGraph, argmap);
  308. ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2));
  309. return ptrGraph;
  310. }
  311. abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  312. if (fn_leaf_ == nullptr) {
  313. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  314. // Assert that hypermap's function param does not contain free variables
  315. if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
  316. auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
  317. auto func_graph = graph_func->func_graph();
  318. if (func_graph->parent() != nullptr) {
  319. MS_LOG(EXCEPTION) << "HyperMap does not support closures with free variables yet.";
  320. }
  321. }
  322. }
  323. AbstractBasePtrList broadened;
  324. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
  325. [](const AbstractBasePtr &arg) -> AbstractBasePtr {
  326. MS_EXCEPTION_IF_NULL(arg);
  327. return arg->Broaden();
  328. });
  329. return broadened;
  330. }
  331. REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
  332. (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
  333. .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
  334. .def(py::init<>());
  335. }));
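// Python-side usage sketch (illustrative only; the wrapper module paths and operator names below are
// assumptions, not taken from this file):
//   from mindspore.ops import composite as C
//   from mindspore.ops import functional as F
//   hyper_map = C.HyperMap()
//   out = hyper_map(F.tensor_add, (x1, x2), (y1, y2))  # applies tensor_add element-wise over the tuples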
  336. FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) {
  337. MS_EXCEPTION_IF_NULL(a_tuple);
  338. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  339. ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  340. ret->debug_info()->set_name("tail");
  341. AnfNodePtr ptrTup = ret->add_parameter();
  342. std::vector<AnfNodePtr> elems;
  343. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  344. int tuple_size = SizeToInt(a_tuple->size());
  345. for (int i = 1; i < tuple_size; ++i) {
  346. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)}));
  347. }
  348. ret->set_output(ret->NewCNode(elems));
  349. return ret;
  350. }
  351. FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) {
  352. MS_EXCEPTION_IF_NULL(a_list);
  353. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  354. ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  355. ret->debug_info()->set_name("tail");
  356. AnfNodePtr ptrList = ret->add_parameter();
  357. std::vector<AnfNodePtr> elems;
  358. elems.push_back(NewValueNode(prim::kPrimMakeList));
  359. int list_size = SizeToInt(a_list->size());
  360. for (int i = 1; i < list_size; ++i) {
  361. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)}));
  362. }
  363. ret->set_output(ret->NewCNode(elems));
  364. return ret;
  365. }
  366. FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  367. if (args_spec_list.size() != 1) {
  368. MS_LOG(EXCEPTION) << "Tail requires exactly one tuple or list argument.";
  369. }
  370. AbstractBasePtr a = args_spec_list[0];
  371. abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(a);
  372. if (a_tuple != nullptr) {
  373. return GenerateTupleFuncGraph(a_tuple);
  374. }
  375. abstract::AbstractListPtr a_list = dyn_cast<AbstractList>(a);
  376. if (a_list != nullptr) {
  377. return GenerateListFuncGraph(a_list);
  378. }
  379. MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but got: " << a->ToString();
  380. }
  381. REGISTER_PYBIND_DEFINE(
  382. Tail_, ([](const py::module *m) {
  383. (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
  384. }));
  385. FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  386. int tuple_size = SizeToInt(args_spec_list.size());
  387. std::ostringstream ss;
  388. ss << "▶make_tuple_" << tuple_size;
  389. FuncGraphPtr fg = std::make_shared<FuncGraph>();
  390. fg->debug_info()->set_name(ss.str());
  391. std::vector<AnfNodePtr> params;
  392. params.push_back(NewValueNode(prim::kPrimMakeTuple));
  393. for (int i = 0; i < tuple_size; ++i) {
  394. params.push_back(fg->add_parameter());
  395. }
  396. // Make fprop's first result: make_tuple's forward result.
  397. AnfNodePtr out = fg->NewCNode(params);
  398. // Make fprop's second result: make_tuple's backward function (bprop).
  399. FuncGraphPtr b = std::make_shared<FuncGraph>();
  400. ss.str("");  // reset the buffer contents; clear() alone would only reset the error flags
  401. ss << "◀make_tuple_" << tuple_size;
  402. b->debug_info()->set_name(ss.str());
  403. AnfNodePtr dout = b->add_parameter();
  404. std::vector<AnfNodePtr> grads;
  405. grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  406. grads.push_back(NewValueNode(newenv));
  407. for (int i = 0; i < tuple_size; ++i) {
  408. grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
  409. }
  410. b->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  411. b->set_output(b->NewCNode(grads));
  412. fg->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  413. fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  414. (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
  415. return fg;
  416. }
  417. GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
  418. : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
  419. if (get_by_list) {
  420. signatures_ =
  421. // def grad(func:read, weight_list:ref):
  422. std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
  423. {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
  424. }
  425. }
  426. FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
  427. const std::vector<AnfNodePtr> &params_list, bool applyJ) {
  428. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  429. ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  430. ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
  431. ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
  432. std::vector<AnfNodePtr> inputs;
  433. if (applyJ) {
  434. inputs.push_back(opsJ);
  435. inputs.push_back(node);
  436. node = ret->NewCNode(inputs);
  437. }
  438. std::vector<AnfNodePtr> params;
  439. for (size_t i = 0; i < params_list.size(); ++i) {
  440. params.push_back(ret->add_parameter());
  441. }
  442. inputs.clear();
  443. inputs.push_back(node);
  444. (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs));
  445. AnfNodePtr cnode = ret->NewCNode(inputs);
  446. inputs.clear();
  447. inputs.push_back(opsTupleItem);
  448. inputs.push_back(cnode);
  449. inputs.push_back(NewValueNode(0));
  450. auto out = ret->NewCNode(inputs);
  451. inputs.clear();
  452. inputs.push_back(opsTupleItem);
  453. inputs.push_back(cnode);
  454. inputs.push_back(NewValueNode(1));
  455. AnfNodePtr ptrBprop = ret->NewCNode(inputs);
  456. doGetGrad(ret, out, ptrBprop, weights, opsTupleItem);
  457. return ret;
  458. }
  459. void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
  460. ValueNodePtr opsTupleItem) {
  461. MS_EXCEPTION_IF_NULL(func_graph);
  462. AnfNodePtr ptrBPropArg = nullptr;
  463. if (sens_param_) {
  464. ptrBPropArg = func_graph->add_parameter();
  465. } else {
  466. auto ones_like = prim::GetPythonOps("ones_like");
  467. ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out});
  468. }
  469. AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg});
  470. CNodePtr fv_bprop = nullptr;
  471. if (get_by_list_) {
  472. // python code: grads = hyper_map(F.partial(env_get, env), weights)
  473. AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)});
  474. AnfNodePtr partial_env_get =
  475. func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
  476. MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
  477. fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights});
  478. }
  479. CNodePtr inputs_bprop = nullptr;
  480. if (get_all_) {
  481. inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp});
  482. }
  483. // Gradients wrt inputs and parameters
  484. if (fv_bprop != nullptr && inputs_bprop != nullptr) {
  485. func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
  486. return;
  487. }
  488. // Gradients wrt parameters
  489. if (fv_bprop != nullptr) {
  490. func_graph->set_output(fv_bprop);
  491. return;
  492. }
  493. // Gradients wrt inputs
  494. if (inputs_bprop != nullptr) {
  495. func_graph->set_output(inputs_bprop);
  496. return;
  497. }
  498. // Gradients wrt first input.
  499. // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
  500. func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)}));
  501. }
  502. // Generate the graph.
  503. FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  504. if (args_spec_list.size() < 1) {
  505. MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameter, but the input size is "
  506. << args_spec_list.size() << ".";
  507. }
  508. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  509. AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
  510. if (fn == nullptr) {
  511. MS_LOG(EXCEPTION) << "GradOperation arg0 must be an AbstractFunction, but got " << args_spec_list[0]->ToString();
  512. }
  513. // Waiting for implementation.
  514. auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  515. MS_EXCEPTION_IF_NULL(real_fn);
  516. FuncGraphPtr ptrGraph = real_fn->func_graph();
  517. MS_EXCEPTION_IF_NULL(ptrGraph);
  518. TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
  519. FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>();
  520. TraceManager::EndTrace();
  521. auto nparam = ptrGraph->parameters().size();
  522. std::ostringstream ss;
  523. ss << "grad{" << nparam << "}";
  524. dfBuilder->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  525. dfBuilder->debug_info()->set_name(ss.str());
  526. ParameterPtr param_graph = dfBuilder->add_parameter();
  527. AnfNodePtr weights = nullptr;
  528. if (get_by_list_) {
  529. weights = dfBuilder->add_parameter();
  530. }
  531. std::vector<AnfNodePtr> inputs;
  532. inputs.push_back(NewValueNode(prim::kPrimJ));
  533. inputs.push_back(param_graph);
  534. auto jf = dfBuilder->NewCNode(inputs);
  535. // df is checked in GetGrad
  536. TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
  537. auto df = GetGrad(jf, weights, ptrGraph->parameters());
  538. TraceManager::EndTrace();
  539. dfBuilder->set_output(NewValueNode(df));
  540. return dfBuilder;
  541. }
  542. REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
  543. (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
  544. *m, "GradOperation_")
  545. .def(py::init<std::string &>(), py::arg("fn"))
  546. .def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
  547. py::arg("get_by_list"), py::arg("sens_param"));
  548. }));
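// Python-side usage sketch (illustrative only; the wrapper module path and constructor arguments are
// assumptions based on the pybind definition above):
//   from mindspore.ops import composite as C
//   grad_all = C.GradOperation('grad_all', get_all=True)
//   grads = grad_all(net)(x, y)  # gradients of net's output w.r.t. all of its inputs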
  549. MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
  550. fn_cache_.clear();
  551. signatures_ = std::vector<Signature>({// def multitype(*args:ref):
  552. {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
  553. }
  554. void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
  555. MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ").";
  556. auto fn = fn_cache_.find(types);
  557. if (fn != fn_cache_.end()) {
  558. MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << "), already registered.";
  559. }
  560. fn_cache_[types] = s_fn;
  561. }
  562. void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
  563. MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ").";
  564. auto fn = fn_cache_.find(types);
  565. if (fn != fn_cache_.end()) {
  566. MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << "), already registered.";
  567. }
  568. fn_cache_py_[types] = py_fn;
  569. }
  570. void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) {
  571. TypePtrList types;
  572. for (auto &type_name : types_name) {
  573. auto type_ptr = StringToType(type_name);
  574. if (type_ptr == nullptr) {
  575. MS_LOG(EXCEPTION) << "Failed to convert type name '" << type_name << "' from string.";
  576. }
  577. types.push_back(type_ptr);
  578. }
  579. Register(types, py_fn);
  580. }
  581. void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
  582. std::vector<std::string> types_name;
  583. for (size_t it = 0; it < tuple.size(); ++it) {
  584. py::object name_py = tuple[it];
  585. if (py::isinstance<py::str>(name_py)) {
  586. types_name.push_back(name_py.cast<std::string>());
  587. continue;
  588. }
  589. MS_LOG(EXCEPTION) << "Registered type names must be strings.";
  590. }
  591. Register(types_name, py_fn);
  592. }
  593. static TypePtr UnwrapRef(const TypePtr &type) {
  594. if (type->isa<RefType>()) {
  595. return type->cast<RefTypePtr>()->subtype();
  596. }
  597. return type;
  598. }
  599. FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
  600. bool find_fn = false;
  601. py::function py_fn;
  602. for (auto &item : fn_cache_py_) {
  603. TypePtrList sign = item.first;
  604. if (sign.size() != types.size()) {
  605. continue;
  606. }
  607. bool match = true;
  608. for (size_t i = 0; i < sign.size(); ++i) {
  609. if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
  610. match = false;
  611. break;
  612. }
  613. }
  614. if (!match) {
  615. continue;
  616. }
  617. find_fn = true;
  618. py_fn = item.second;
  619. break;
  620. }
  621. std::ostringstream buffer;
  622. buffer << types;
  623. if (find_fn) {
  624. FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
  625. if (func_graph == nullptr) {
  626. MS_LOG(EXCEPTION) << "Failed to parse overloaded function " << buffer.str();
  627. }
  628. MS_LOG(DEBUG) << "Found overloaded function " << buffer.str() << ", function: " << func_graph->ToString();
  629. return func_graph;
  630. }
  631. std::ostringstream oss;
  632. oss << "There are " << fn_cache_py_.size() << " prototypes for the overloaded function `" << name_
  633. << "`, corresponding location info:\n";
  634. int idx = 0;
  635. for (auto &item : fn_cache_py_) {
  636. FuncGraphPtr func_graph = parse::ParsePythonCode(item.second);
  637. if (func_graph == nullptr) {
  638. MS_LOG(WARNING) << "Failed to parse Python code for function `" << name_ << "`.";
  639. continue;
  640. }
  641. oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
  642. }
  643. MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n"
  644. << oss.str();
  645. }
  646. REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) {
  647. (void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>(
  648. *m, "MultitypeFuncGraph_")
  649. .def(py::init<std::string &>())
  650. .def("register_fn", &MultitypeFuncGraph::PyRegister);
  651. }));
  652. // Generate the ListMap func graph.
  653. FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  654. size_t args_num = args_spec_list.size();
  655. // args: fn, list1, list2, ...
  656. if (args_num < 2) {
  657. MS_LOG(EXCEPTION) << "list_map takes at least two arguments";
  658. }
  659. for (size_t i = 1; i < args_num; ++i) {
  660. if (!args_spec_list[i]->isa<AbstractList>()) {
  661. // Note: ListMap is currently not in use; this check guards the intended list arguments.
  662. MS_LOG(EXCEPTION) << "list_map requires lists, but got " << args_spec_list[i]->ToString();
  663. }
  664. }
  665. FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
  666. fg_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  667. fg_ptr->debug_info()->set_name("list_map");
  668. AnfNodePtr fn = fg_ptr->add_parameter();
  669. std::vector<AnfNodePtr> lists;
  670. for (size_t i = 1; i < args_num; ++i) {
  671. lists.push_back(fg_ptr->add_parameter());
  672. }
  673. std::vector<AnfNodePtr> iters;
  674. (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
  675. return fg_ptr->NewCNode({NewValueNode(std::string("list_iter")), item});
  676. });
  677. std::vector<AnfNodePtr> nexts;
  678. (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
  679. return fg_ptr->NewCNode({NewValueNode(std::string("next")), item});
  680. });
  681. std::vector<AnfNodePtr> values;
  682. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
  683. return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(0)});
  684. });
  685. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
  686. return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)});
  687. });
  688. (void)values.insert(values.begin(), fn);
  689. AnfNodePtr cnode_graph = fg_ptr->NewCNode(values);
  690. AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimMakeList), cnode_graph});
  691. FuncGraphPtr fgnext_ptr = std::make_shared<FuncGraph>();
  692. fgnext_ptr->debug_info()->set_name("body");
  693. FuncGraphPtr fgcond_ptr = std::make_shared<FuncGraph>();
  694. fgcond_ptr->debug_info()->set_name("cond");
  695. MakeCond(lists, fgnext_ptr, fgcond_ptr);
  696. MakeNext(lists, fgcond_ptr, fgnext_ptr);
  697. CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl});
  698. auto inputs = output_cnode->inputs();
  699. (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  700. output_cnode->set_inputs(inputs);
  701. fg_ptr->set_output(output_cnode);
  702. return fg_ptr;
  703. }
  704. void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
  705. const FuncGraphPtr &fg_ptr) {
  706. MS_EXCEPTION_IF_NULL(fg_ptr);
  707. AnfNodePtr fn = fg_ptr->add_parameter();
  708. AnfNodePtr resl = fg_ptr->add_parameter();
  709. std::vector<AnfNodePtr> iters;
  710. (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
  711. [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
  712. std::vector<AnfNodePtr> hasnexts;
  713. (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) {
  714. return fg_ptr->NewCNode({NewValueNode(std::string("hasnext")), item});
  715. });
  716. // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
  717. FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
  718. fgtrue_ptr->debug_info()->set_name("ftrue");
  719. fgtrue_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  720. CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl});
  721. auto inputs = fgtrue_output_cnode->inputs();
  722. (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  723. fgtrue_output_cnode->set_inputs(inputs);
  724. fgtrue_ptr->set_output(fgtrue_output_cnode);
  725. FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
  726. fgfalse_ptr->debug_info()->set_name("ffalse");
  727. fgfalse_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  728. fgfalse_ptr->set_output(resl);
  729. AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
  730. NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)});
  731. fgtrue_ptr->set_output(output_cnode);
  732. }
  733. void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
  734. const FuncGraphPtr &fg_ptr) {
  735. MS_EXCEPTION_IF_NULL(fg_ptr);
  736. AnfNodePtr fn = fg_ptr->add_parameter();
  737. std::vector<AnfNodePtr> iters;
  738. (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
  739. [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
  740. std::vector<AnfNodePtr> nexts;
  741. (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
  742. return fg_ptr->NewCNode({NewValueNode(std::string("next")), item});
  743. });
  744. std::vector<AnfNodePtr> values;
  745. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
  746. return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(0)});
  747. });
  748. iters.clear();
  749. (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
  750. return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)});
  751. });
  752. (void)values.insert(values.begin(), fn);
  753. AnfNodePtr cnode_graph = fg_ptr->NewCNode(values);
  754. AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimListAppend), cnode_graph});
  755. CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl});
  756. auto inputs = output_cnode->inputs();
  757. (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  758. output_cnode->set_inputs(inputs);
  759. fg_ptr->set_output(output_cnode);
  760. }
  761. FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  762. // args: tuple1, tuple2
  763. abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
  764. AbstractBasePtr abs_a = args_spec_list[0];
  765. AbstractBasePtr abs_b = args_spec_list[1];
  766. abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
  767. abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
  768. if (a_tuple == nullptr || b_tuple == nullptr) {
  769. MS_LOG(EXCEPTION) << "TupleAdd arguments should be tuples, but got " << args_spec_list[0]->ToString() << ", "
  770. << args_spec_list[1]->ToString();
  771. }
  772. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  773. ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  774. AnfNodePtr p_tup_a = ret->add_parameter();
  775. AnfNodePtr p_tup_b = ret->add_parameter();
  776. std::vector<AnfNodePtr> elems;
  777. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  778. int tuple_size = SizeToInt(a_tuple->size());
  779. for (int i = 0; i < tuple_size; ++i) {
  780. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)}));
  781. }
  782. tuple_size = SizeToInt(b_tuple->size());
  783. for (int i = 0; i < tuple_size; ++i) {
  784. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)}));
  785. }
  786. ret->set_output(ret->NewCNode(elems));
  787. return ret;
  788. }
  789. int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
  790. MS_EXCEPTION_IF_NULL(scalar);
  791. return GetValue<int>(scalar->BuildValue());
  792. }
  793. bool CheckIndexInRange(int index, int min, int max) { return (index >= min && index <= max); }
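// Convert a possibly negative (Python-style) index into a non-negative one,
// e.g. index -1 with length 5 becomes 4.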
  794. int GetPositiveIndex(int index, int length) {
  795. if (index < 0) {
  796. index += length;
  797. }
  798. return index;
  799. }
  800. int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) {
  801. MS_EXCEPTION_IF_NULL(member);
  802. if (member->isa<AbstractScalar>()) {
  803. return GetArgScalarValue(dyn_cast<AbstractScalar>(member), member_name);
  804. }
  805. if (member->isa<AbstractNone>()) {
  806. return default_value;
  807. }
  808. MS_LOG(EXCEPTION) << member_name << " should be an AbstractScalar or AbstractNone, but got " << member->ToString();
  809. }
  810. void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index,
  811. int *stop_index, int *step_value) {
  812. MS_EXCEPTION_IF_NULL(tuple);
  813. MS_EXCEPTION_IF_NULL(slice);
  814. MS_EXCEPTION_IF_NULL(start_index);
  815. MS_EXCEPTION_IF_NULL(stop_index);
  816. MS_EXCEPTION_IF_NULL(step_value);
  817. const std::string start_name("Slice start index");
  818. const std::string stop_name("Slice stop index");
  819. const std::string step_name("Slice step value");
  820. int tuple_size = SizeToInt(tuple->size());
  821. int start_default = 0;
  822. int stop_default = tuple_size;
  823. int step_default = 1;
  824. *step_value = CheckSliceMember(slice->step(), step_default, step_name);
  825. if (*step_value == 0) {
  826. MS_LOG(EXCEPTION) << "TupleSlice requires a non-zero step value, but got 0.";
  827. }
  828. if (*step_value < 0) {
  829. start_default = tuple_size - 1;
  830. stop_default = -1;
  831. }
  832. *start_index = CheckSliceMember(slice->start(), start_default, start_name);
  833. *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
  834. if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) ||
  835. !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) {
  836. MS_LOG(EXCEPTION) << "TupleSlice: the start index " << *start_index << " or stop index " << *stop_index
  837. << " is out of range for tuple size " << tuple_size << ".";
  838. }
  839. *start_index = GetPositiveIndex(*start_index, tuple_size);
  840. if (!slice->stop()->isa<AbstractNone>()) {
  841. *stop_index = GetPositiveIndex(*stop_index, tuple_size);
  842. }
  843. }
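// Worked example: for a tuple of size 5 and the slice [::-1], the step is -1, so the defaults become
// start = 4 and stop = -1; TupleSlice below then emits indices 4, 3, 2, 1, 0.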
  844. FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  845. // slice a tuple
  846. // args: tuple, start index, end index, step
  847. const std::string op_name("TupleSlice");
  848. abstract::CheckArgsSize(op_name, args_spec_list, 2);
  849. AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  850. AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);
  851. int start_index;
  852. int stop_index;
  853. int step_value;
  854. GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
  855. FuncGraphPtr ret = std::make_shared<FuncGraph>();
  856. ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  857. AnfNodePtr p_tuple = ret->add_parameter();
  858. (void)ret->add_parameter();
  859. std::vector<AnfNodePtr> elems;
  860. elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  861. if (step_value > 0) {
  862. for (int index = start_index; index < stop_index; index = index + step_value) {
  863. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
  864. }
  865. } else {
  866. for (int index = start_index; index > stop_index; index = index + step_value) {
  867. elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
  868. }
  869. }
  870. ret->set_output(ret->NewCNode(elems));
  871. return ret;
  872. }
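// Pack a vector of 0/1 flags into an integer bit mask, where element i becomes bit i.
// For example, {1, 0, 1} -> 0b101 = 5. Used below to build StridedSlice's shrink_axis_mask.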
  873. int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) {
  874. unsigned int number_dec = 0;
  875. for (size_t index = 0; index < number_bin.size(); index++) {
  876. number_dec |= number_bin[index] << index;
  877. }
  878. return static_cast<int>(number_dec);
  879. }
  880. void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end,
  881. std::vector<int> *strides, int length) {
  882. MS_EXCEPTION_IF_NULL(slice);
  883. MS_EXCEPTION_IF_NULL(begin);
  884. MS_EXCEPTION_IF_NULL(end);
  885. MS_EXCEPTION_IF_NULL(strides);
  886. if (length <= 0) {
  887. MS_LOG(EXCEPTION) << "Cannot slice a dimension whose length is less than 1.";
  888. }
  889. int start_default = 0;
  890. int stop_default = length;
  891. int step_default = 1;
  892. int step_value = CheckSliceMember(slice->step(), step_default, "step");
  893. if (step_value < 0) {
  894. start_default = -1;
  895. stop_default = -(length + 1);
  896. }
  897. begin->push_back(CheckSliceMember(slice->start(), start_default, "begin"));
  898. end->push_back(CheckSliceMember(slice->stop(), stop_default, "stop"));
  899. strides->push_back(step_value);
  900. }
  901. int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape,
  902. std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
  903. MS_EXCEPTION_IF_NULL(slice_tuple);
  904. MS_EXCEPTION_IF_NULL(begin);
  905. MS_EXCEPTION_IF_NULL(end);
  906. MS_EXCEPTION_IF_NULL(strides);
  907. size_t slice_tuple_size = slice_tuple->size();
  908. size_t shape_size = shape.size();
  909. if (slice_tuple_size > shape_size) {
  910. MS_LOG(EXCEPTION) << "The number of slice elements should not exceed the rank of the tensor; "
  911. "the rank of the tensor is "
  912. << shape_size << ", but the number of slice elements is " << slice_tuple_size << ".";
  913. }
  914. std::vector<unsigned int> shrink;
  915. auto slice_tuple_eles = slice_tuple->elements();
  916. size_t ellipsis_num = 0;
  917. for (size_t index = 0; index < slice_tuple_size; index++) {
  918. if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
  919. AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
  920. ParseSlice(slice, begin, end, strides, shape[index]);
  921. shrink.push_back(0);
  922. continue;
  923. }
  924. if (slice_tuple_eles[index]->isa<AbstractScalar>()) {
  925. int ele_index = GetArgScalarValue(dyn_cast<AbstractScalar>(slice_tuple_eles[index]), "slice_tuple");
  926. begin->push_back(ele_index);
  927. end->push_back(ele_index + 1);
  928. strides->push_back(1);
  929. shrink.push_back(1);
  930. continue;
  931. }
  932. if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) {
  933. ellipsis_num++;
  934. if (ellipsis_num > 1) {
  935. MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis";
  936. }
  937. size_t ellipsis_len = shape_size - (slice_tuple_size - 1);
  938. begin->insert(begin->end(), ellipsis_len, 0);
  939. end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len);
  940. strides->insert(strides->end(), ellipsis_len, 1);
  941. shrink.insert(shrink.end(), ellipsis_len, 0);
  942. continue;
  943. }
  944. MS_LOG(EXCEPTION) << "A slice tuple may only contain slices, integers, or an ellipsis, but got "
  945. << slice_tuple_eles[index]->ToString();
  946. }
  947. if (ellipsis_num == 0) {
  948. for (size_t index = slice_tuple_size; index < shape_size; index++) {
  949. begin->push_back(0);
  950. end->push_back(shape[index]);
  951. strides->push_back(1);
  952. }
  953. }
  954. return ConvertBinaryToDecimal(shrink);
  955. }
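// A bare slice never removes an axis, so the returned shrink_axis_mask is 0; any dimensions beyond the
// first are taken in full.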
  956. int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape,
  957. std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
  958. MS_EXCEPTION_IF_NULL(begin);
  959. MS_EXCEPTION_IF_NULL(end);
  960. MS_EXCEPTION_IF_NULL(strides);
  961. size_t shape_size = shape.size();
  962. if (shape_size == 0) {
  963. MS_LOG(EXCEPTION) << "Cannot slice a scalar tensor.";
  964. }
  965. ParseSlice(slice, begin, end, strides, shape[0]);
  966. for (size_t index = 1; index < shape_size; index++) {
  967. begin->push_back(0);
  968. end->push_back(shape[index]);
  969. strides->push_back(1);
  970. }
  971. return 0;
  972. }
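// Indexing with a single number selects one element along the first axis and removes that axis,
// so the returned shrink_axis_mask is 1 (only bit 0 set).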
  973. int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape,
  974. std::vector<int> *begin, std::vector<int> *end,
  975. std::vector<int> *strides) {
  976. MS_EXCEPTION_IF_NULL(begin);
  977. MS_EXCEPTION_IF_NULL(end);
  978. MS_EXCEPTION_IF_NULL(strides);
  979. int ele_index = GetArgScalarValue(scalar, "slice_tuple");
  980. begin->push_back(ele_index);
  981. end->push_back(ele_index + 1);
  982. strides->push_back(1);
  983. for (size_t index = 1; index < shape.size(); index++) {
  984. begin->push_back(0);
  985. end->push_back(shape[index]);
  986. strides->push_back(1);
  987. }
  988. return 1;
  989. }
  990. FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  991. // slice a tensor
  992. // args: tensor, slice or slice tuple
  993. const std::string op_name = std::string("TensorSlice");
  994. abstract::CheckArgsSize(op_name, args_spec_list, 2);
  995. AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  996. FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
  997. ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  998. AnfNodePtr tensor_node = ret_graph->add_parameter();
  999. (void)ret_graph->add_parameter();
  1000. auto shape = tensorPtr->shape()->shape();
  1001. std::vector<int> begin;
  1002. std::vector<int> end;
  1003. std::vector<int> strides;
  1004. int shrink_axis_mask;
  1005. if (args_spec_list[1]->isa<AbstractTuple>()) {
  1006. AbstractTuplePtr tuple_ptr = dyn_cast<AbstractTuple>(args_spec_list[1]);
  1007. shrink_axis_mask = GenerateStridedSliceParametersFromTuple(tuple_ptr, shape, &begin, &end, &strides);
  1008. } else if (args_spec_list[1]->isa<AbstractSlice>()) {
  1009. AbstractSlicePtr slice_ptr = dyn_cast<AbstractSlice>(args_spec_list[1]);
  1010. shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
  1011. } else if (args_spec_list[1]->isa<AbstractScalar>()) {
  1012. AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
  1013. if (scalar_ptr->BuildValue()->isa<BoolImm>()) {
  1014. if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
  1015. return ExpandADim(ret_graph, tensor_node);
  1016. }
  1017. MS_LOG(EXCEPTION) << "TensorSlice does not support an index of False.";
  1018. }
  1019. shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
  1020. } else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
  1021. ret_graph->set_output(tensor_node);
  1022. return ret_graph;
  1023. } else if (args_spec_list[1]->isa<AbstractNone>()) {
  1024. return ExpandADim(ret_graph, tensor_node);
  1025. } else {
  1026. std::ostringstream args_info;
  1027. for (const auto &arg : args_spec_list) {
  1028. MS_EXCEPTION_IF_NULL(arg);
  1029. args_info << arg->ToString() << "\n";
  1030. }
  1031. MS_LOG(EXCEPTION)
  1032. << "TensorSlice requires the index to be one of [slice, ellipsis, int, bool, None, tuple], but got "
  1033. << args_info.str();
  1034. }
  1035. auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
  1036. auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
  1037. NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
  1038. ret_graph->set_output(ret_graph->NewCNode(
  1039. {PrimStridedSlice, tensor_node, NewValueNode(begin), NewValueNode(end), NewValueNode(strides)}));
  1040. return ret_graph;
  1041. }
  1042. FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
  1043. auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
  1044. ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
  1045. return ret_graph;
  1046. }
  1047. REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
  1048. (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
  1049. .def(py::init<std::string &>());
  1050. }));
  1051. REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
  1052. (void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
  1053. .def(py::init<std::string &>());
  1054. }));
  1055. REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
  1056. (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
  1057. .def(py::init<std::string &>());
  1058. }));
  1059. } // namespace prim
  1060. } // namespace mindspore