
composite.cc

/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "operator/composite/composite.h"
#include <algorithm>
#include <utility>
#include <sstream>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "pipeline/static_analysis/dshape.h"
#include "pipeline/static_analysis/param_validator.h"
#include "operator/cc_implementations.h"
#include "optimizer/opt.h"
#include "utils/symbolic.h"
#include "pybind_api/api_register.h"
#include "./common.h"
#include "ir/signature.h"
#include "debug/trace.h"

namespace mindspore {
// namespace to support composite operator definitions
namespace prim {
using AbstractTensor = mindspore::abstract::AbstractTensor;
using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractEllipsis;
using mindspore::abstract::AbstractEllipsisPtr;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractFunctionPtr;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractNone;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSlice;
using mindspore::abstract::AbstractTuple;

ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul},
                            {"__truediv__", nullptr},    {"__floordiv__", nullptr},   {"__mod__", kPrimScalarMod},
                            {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq},   {"__lt__", kPrimScalarLt},
                            {"__gt__", kPrimScalarGt},   {"__ne__", kPrimScalarNe},   {"__le__", kPrimScalarLe},
                            {"__ge__", kPrimScalarGe}};

const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail");

// Adapted from the Python API `reduce`:
// apply a function of two arguments cumulatively to the items of a sequence,
// from left to right, so as to reduce the sequence to a single value. For example,
// reduce(lambda x, y: x + y, [1, 2, 3, 4, 5]) calculates ((((1 + 2) + 3) + 4) + 5).
AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) {
  std::shared_ptr<Any> ret;
  size_t size = list.size();
  if (size < 2) {
    MS_LOG(EXCEPTION) << "Reduce requires at least 2 inputs, but got " << size << ".";
  }

  AnyPtrList input;
  input.push_back(list[0]);
  input.push_back(list[1]);
  ret = std::make_shared<Any>(func(input));

  for (size_t i = 2; i < size; ++i) {
    input.clear();
    input.push_back(ret);
    input.push_back(list[i]);
    ret = std::make_shared<Any>(func(input));
  }
  return ret;
}

AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector<AnfNodePtr> &list) {
  size_t size = list.size();
  if (size < 2) {
    MS_LOG(EXCEPTION) << "Reduce requires at least 2 inputs, but got " << size << ".";
  }

  std::vector<AnfNodePtr> input;
  input.push_back(list[0]);
  input.push_back(list[1]);
  AnfNodePtr ret = func(input);

  for (size_t i = 2; i < size; ++i) {
    input.clear();
    input.push_back(ret);
    input.push_back(list[i]);
    ret = func(input);
  }
  return ret;
}
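
// Usage sketch for Reduce (illustrative only, not part of this file's API):
// folding scalar addition over a node list builds a left-leaning call chain,
//   Reduce([&fg](const std::vector<AnfNodePtr> &in) {
//            return fg->NewCNode({NewValueNode(kPrimScalarAdd), in[0], in[1]});
//          },
//          nodes);  // nodes = {a, b, c} -> scalar_add(scalar_add(a, b), c)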

ValuePtr kCompositeHyperMap = std::make_shared<HyperMap>();

void HyperMap::Init() {
  if (fn_leaf_) {
    name_ = "hyper_map[" + fn_leaf_->name() + "]";
  }
  signatures_ =
    // def hypermap(func:read, *args:ref):
    std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                            {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}

HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
    : MetaFuncGraph("hyper_map"),
      fn_leaf_(fn_leaf),
      broadcast_(false),
      nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
  Init();
}

HyperMap::HyperMap(const HyperMap &h)
    : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
  Init();
}

AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                              const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> inputs;
  if (fn_arg != nullptr) {
    inputs.push_back(fn_arg);
  } else {
    inputs.push_back(NewValueNode(fn_leaf_));
  }

  (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
                       [](const std::pair<AnfNodePtr, Any> &item) { return item.first; });
  return func_graph->NewCNode(inputs);
}

AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);
  std::size_t size = type->elements().size();
  bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
    auto lhs = std::static_pointer_cast<List>(item.second);
    MS_EXCEPTION_IF_NULL(lhs);
    return lhs->elements().size() != size;
  });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "Lists in HyperMap should have the same length";
  }

  // Cannot use shared_from_base() (i.e. `this`) here, as that would create a reference cycle
  // between the hypermap and the generated graph, causing a memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeList));

  for (int i = 0; i < SizeToInt(size); ++i) {
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    (void)std::transform(
      arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
      [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
        return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
      });

    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}

AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);
  std::size_t size = type->elements().size();
  bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
    auto lhs = std::static_pointer_cast<Tuple>(item.second);
    MS_EXCEPTION_IF_NULL(lhs);
    return lhs->elements().size() != size;
  });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "Tuples in HyperMap should have the same length";
  }

  // Cannot use shared_from_base() (i.e. `this`) here, as that would create a reference cycle
  // between the hypermap and the generated graph, causing a memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeTuple));

  for (int i = 0; i < SizeToInt(size); ++i) {
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    (void)std::transform(
      arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
        return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
      });

    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}

AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(func_graph);

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  inputs.push_back(NewValueNode(type));

  // Cannot use shared_from_base() (i.e. `this`) here, as that would create a reference cycle
  // between the hypermap and the generated graph, causing a memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  std::size_t attrSize = type->GetAttributes().size();
  for (std::size_t i = 0; i < attrSize; ++i) {
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg) {
      inputs2.push_back(fn_arg);
    }

    int j = 0;
    for (auto item : arg_map) {
      inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
      j++;
    }

    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}
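
// Make dispatches on the first argument whose type is one of the registered
// nonleaf containers (List, Tuple, Class) and recurses through the matching
// FullMake overload; if no nonleaf type is found, the leaf FullMake above
// applies the mapped function directly to the arguments.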
AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  bool found = false;
  TypeId id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> pair;
  for (auto &item : arg_map) {
    pair = item;
    id = item.second->type_id();
    if (nonleaf_.count(id)) {
      found = true;
      break;
    }
  }

  if (found) {
    // In a nonleaf situation, all arguments must have the same generic type.
    bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
      if (item.first != pair.first) {
        return item.second->type_id() != pair.second->type_id();
      }
      return false;
    });
    if (is_not_same) {
      std::ostringstream oss;
      oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
      int idx = 0;
      for (auto &item : arg_map) {
        oss << ++idx << ": " << item.second->ToString() << "\n";
      }
      MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
    }
  }

  switch (id) {
    case kObjectTypeList: {
      auto type = std::static_pointer_cast<List>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeTuple: {
      auto type = std::static_pointer_cast<Tuple>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeClass: {
      auto type = std::static_pointer_cast<Class>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    default:
      return FullMake(pair.second, func_graph, fn_arg, arg_map);
  }
}
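
// Harmonize: when broadcasting is enabled and at least one argument is a
// tensor, wrap every non-tensor argument in scalar_to_array so that the leaf
// calls see tensor operands throughout.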
ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
  TypePtr type_tensor = std::make_shared<TensorType>();
  bool flag = std::any_of(
    args_spec_list.begin(), args_spec_list.end(),
    [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
  if (flag && broadcast_) {
    ArgsPairList ret;
    for (auto &item : args_spec_list) {
      if (!IsSubType(item.second, type_tensor)) {
        TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
        ret.push_back(
          std::make_pair(func_graph->NewCNode({NewValueNode(prim::kPrimScalarToArray), item.first}), type_tensor_ele));
      } else {
        ret.push_back(std::make_pair(item.first, item.second));
      }
    }
    return ret;
  }
  return args_spec_list;
}

FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
  FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  ptrGraph->debug_info()->set_name("hyper_map");

  AnfNodePtr ptrFnArg = nullptr;
  std::size_t i = 0;
  ArgsPairList argmap;
  ArgsPairList argmap2;
  if (fn_leaf_ == nullptr) {
    ptrFnArg = ptrGraph->add_parameter();
    i = 1;
  }

  std::size_t size = args_spec_list.size();
  for (; i < size; ++i) {
    argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
  }

  argmap2 = Harmonize(ptrGraph, argmap);
  ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2));
  return ptrGraph;
}

abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  if (fn_leaf_ == nullptr) {
    MS_EXCEPTION_IF_NULL(args_spec_list[0]);
    // Assert that hypermap's function param does not contain free variables
    if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
      auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
      auto func_graph = graph_func->func_graph();
      if (func_graph->parent() != nullptr) {
        MS_LOG(EXCEPTION) << "HyperMap does not support closures with free variables yet.";
      }
    }
  }

  AbstractBasePtrList broadened;
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         return arg->Broaden();
                       });
  return broadened;
}

REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                         (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
                           .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
                           .def(py::init<>());
                       }));
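
// Tail builds a graph that returns its tuple (or list) argument without the
// first element; illustratively, tail((a, b, c)) -> (b, c), since the element
// loops below start at index 1.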
FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) {
  MS_EXCEPTION_IF_NULL(a_tuple);

  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  ret->debug_info()->set_name("tail");
  AnfNodePtr ptrTup = ret->add_parameter();

  std::vector<AnfNodePtr> elems;
  elems.push_back(NewValueNode(prim::kPrimMakeTuple));

  int tuple_size = SizeToInt(a_tuple->size());
  for (int i = 1; i < tuple_size; ++i) {
    elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)}));
  }

  ret->set_output(ret->NewCNode(elems));
  return ret;
}

FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) {
  MS_EXCEPTION_IF_NULL(a_list);

  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  ret->debug_info()->set_name("tail");
  AnfNodePtr ptrList = ret->add_parameter();

  std::vector<AnfNodePtr> elems;
  elems.push_back(NewValueNode(prim::kPrimMakeList));

  int list_size = SizeToInt(a_list->size());
  for (int i = 1; i < list_size; ++i) {
    elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)}));
  }

  ret->set_output(ret->NewCNode(elems));
  return ret;
}

FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.size() != 1) {
    MS_LOG(EXCEPTION) << "Tail requires exactly one argument, a tuple or a list.";
  }

  AbstractBasePtr a = args_spec_list[0];
  abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(a);
  if (a_tuple != nullptr) {
    return GenerateTupleFuncGraph(a_tuple);
  }

  abstract::AbstractListPtr a_list = dyn_cast<AbstractList>(a);
  if (a_list != nullptr) {
    return GenerateListFuncGraph(a_list);
  }

  MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but got: " << a->ToString();
}

REGISTER_PYBIND_DEFINE(
  Tail_, ([](const py::module *m) {
    (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
  }));
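
// MakeTupleGradient generates the fprop graph of make_tuple: a pair of
// (forward result, bprop closure), where the bprop maps dout to
// (newenv, dout[0], ..., dout[n-1]), one gradient per tuple element.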
FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  int tuple_size = SizeToInt(args_spec_list.size());

  std::ostringstream ss;
  ss << "▶make_tuple_" << tuple_size;
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->debug_info()->set_name(ss.str());

  std::vector<AnfNodePtr> params;
  params.push_back(NewValueNode(prim::kPrimMakeTuple));
  for (int i = 0; i < tuple_size; ++i) {
    params.push_back(fg->add_parameter());
  }

  // The first result of fprop: make_tuple's forward result.
  AnfNodePtr out = fg->NewCNode(params);

  // The second result of fprop: make_tuple's backward function.
  FuncGraphPtr b = std::make_shared<FuncGraph>();
  ss.str("");  // reset the buffer contents; clear() alone would only reset the error flags
  ss << "◀make_tuple_" << tuple_size;
  b->debug_info()->set_name(ss.str());
  AnfNodePtr dout = b->add_parameter();

  std::vector<AnfNodePtr> grads;
  grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  grads.push_back(NewValueNode(newenv));
  for (int i = 0; i < tuple_size; ++i) {
    grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
  }

  b->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  b->set_output(b->NewCNode(grads));

  fg->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
  return fg;
}

GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
    : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
  if (get_by_list) {
    signatures_ =
      // def grad(func:read, weight_list:ref):
      std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
  }
}
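
// GetGrad builds a wrapper graph that calls the (J-transformed) function on
// the forwarded parameters, splits the resulting (forward_result, bprop)
// pair, and delegates to doGetGrad to choose which gradients to return.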
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
                                    const std::vector<AnfNodePtr> &params_list, bool applyJ) {
  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);

  ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
  ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);

  std::vector<AnfNodePtr> inputs;
  if (applyJ) {
    inputs.push_back(opsJ);
    inputs.push_back(node);
    node = ret->NewCNode(inputs);
  }

  std::vector<AnfNodePtr> params;
  for (size_t i = 0; i < params_list.size(); ++i) {
    params.push_back(ret->add_parameter());
  }

  inputs.clear();
  inputs.push_back(node);
  (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs));
  AnfNodePtr cnode = ret->NewCNode(inputs);

  inputs.clear();
  inputs.push_back(opsTupleItem);
  inputs.push_back(cnode);
  inputs.push_back(NewValueNode(0));
  auto out = ret->NewCNode(inputs);

  inputs.clear();
  inputs.push_back(opsTupleItem);
  inputs.push_back(cnode);
  inputs.push_back(NewValueNode(1));
  AnfNodePtr ptrBprop = ret->NewCNode(inputs);

  doGetGrad(ret, out, ptrBprop, weights, opsTupleItem);
  return ret;
}

void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
                              ValueNodePtr opsTupleItem) {
  MS_EXCEPTION_IF_NULL(func_graph);

  AnfNodePtr ptrBPropArg = nullptr;
  if (sens_param_) {
    ptrBPropArg = func_graph->add_parameter();
  } else {
    auto ones_like = prim::GetPythonOps("ones_like");
    ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out});
  }

  AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg});

  CNodePtr fv_bprop = nullptr;
  if (get_by_list_) {
    // python code: grads = hyper_map(F.partial(env_get, env), weights)
    AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)});
    AnfNodePtr partial_env_get =
      func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
    MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
    fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights});
  }

  CNodePtr inputs_bprop = nullptr;
  if (get_all_) {
    inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp});
  }

  // Gradients wrt inputs and parameters
  if (fv_bprop != nullptr && inputs_bprop != nullptr) {
    func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
    return;
  }

  // Gradients wrt parameters
  if (fv_bprop != nullptr) {
    func_graph->set_output(fv_bprop);
    return;
  }

  // Gradients wrt inputs
  if (inputs_bprop != nullptr) {
    func_graph->set_output(inputs_bprop);
    return;
  }

  // Gradients wrt first input.
  // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
  // so 1 selects the first input.
  func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)}));
}

// Generate the graph.
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.size() < 1) {
    MS_LOG(EXCEPTION) << "GradOperation requires at least 1 argument, while the input size is "
                      << args_spec_list.size() << ".";
  }

  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "GradOperation arg0 must be an AbstractFunction, but got " << args_spec_list[0]->ToString();
  }

  // Waiting for implementation.
  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);

  FuncGraphPtr ptrGraph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(ptrGraph);
  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
  FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>();
  TraceManager::EndTrace();
  auto nparam = ptrGraph->parameters().size();

  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  dfBuilder->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  dfBuilder->debug_info()->set_name(ss.str());
  ParameterPtr param_graph = dfBuilder->add_parameter();

  AnfNodePtr weights = nullptr;
  if (get_by_list_) {
    weights = dfBuilder->add_parameter();
  }

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimJ));
  inputs.push_back(param_graph);
  auto jf = dfBuilder->NewCNode(inputs);
  // df is checked in GetGrad
  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
  auto df = GetGrad(jf, weights, ptrGraph->parameters());
  TraceManager::EndTrace();
  dfBuilder->set_output(NewValueNode(df));

  return dfBuilder;
}

REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
                         (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
                           *m, "GradOperation_")
                           .def(py::init<std::string &>(), py::arg("fn"))
                           .def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
                                py::arg("get_by_list"), py::arg("sens_param"));
                       }));

MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
  fn_cache_.clear();
  signatures_ = std::vector<Signature>({// def multitype(*args:ref):
                                        {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}

void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
  MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ").";
  auto fn = fn_cache_.find(types);
  if (fn != fn_cache_.end()) {
    MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << "), already registered.";
  }
  fn_cache_[types] = s_fn;
}

void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
  MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ").";
  auto fn = fn_cache_.find(types);
  if (fn != fn_cache_.end()) {
    MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << "), already registered.";
  }
  fn_cache_py_[types] = py_fn;
}

void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) {
  TypePtrList types;
  for (auto &type_name : types_name) {
    auto type_ptr = StringToType(type_name);
    if (type_ptr == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to convert type name '" << type_name << "' to a type.";
    }
    types.push_back(type_ptr);
  }
  Register(types, py_fn);
}

void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
  std::vector<std::string> types_name;
  for (size_t it = 0; it < tuple.size(); ++it) {
    py::object name_py = tuple[it];
    if (py::isinstance<py::str>(name_py)) {
      types_name.push_back(name_py.cast<std::string>());
      continue;
    }
    MS_LOG(EXCEPTION) << "Registered type names must be strings.";
  }
  Register(types_name, py_fn);
}

static TypePtr UnwrapRef(const TypePtr &type) {
  if (type->isa<RefType>()) {
    return type->cast<RefTypePtr>()->subtype();
  }
  return type;
}
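
// Overload resolution: scan the registered Python prototypes for the first
// signature whose entries all match the (Ref-unwrapped) call types by
// identity or subclassing, then parse that Python function into a graph.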
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
  bool find_fn = false;
  py::function py_fn;
  for (auto &item : fn_cache_py_) {
    TypePtrList sign = item.first;
    if (sign.size() != types.size()) {
      continue;
    }
    bool match = true;
    for (size_t i = 0; i < sign.size(); ++i) {
      if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
        match = false;
        break;
      }
    }
    if (!match) {
      continue;
    }
    find_fn = true;
    py_fn = item.second;
    break;
  }
  std::ostringstream buffer;
  buffer << types;
  if (find_fn) {
    FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
    if (func_graph == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to parse overload function " << buffer.str();
    }
    MS_LOG(DEBUG) << "Found overload function " << buffer.str() << ", function: " << func_graph->ToString();
    return func_graph;
  }
  std::ostringstream oss;
  oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
      << "`, corresponding location info:\n";
  int idx = 0;
  for (auto &item : fn_cache_py_) {
    FuncGraphPtr func_graph = parse::ParsePythonCode(item.second);
    if (func_graph == nullptr) {
      MS_LOG(WARNING) << "Failed to parse Python code for function `" << name_ << "`.";
      continue;
    }
    oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
  }
  MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n"
                    << oss.str();
}

REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) {
                         (void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>(
                           *m, "MultitypeFuncGraph_")
                           .def(py::init<std::string &>())
                           .def("register_fn", &MultitypeFuncGraph::PyRegister);
                       }));

// Generate the ListMap func graph.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  size_t args_num = args_spec_list.size();
  // args: fn, list1, list2, ...
  if (args_num < 2) {
    MS_LOG(EXCEPTION) << "list_map takes at least two arguments";
  }

  for (size_t i = 1; i < args_num; ++i) {
    if (typeid(args_spec_list[i]) != typeid(AbstractBase)) {
      // This function is currently not used.
      MS_LOG(EXCEPTION) << "list_map requires lists as arguments, but got " << args_spec_list[i]->ToString();
    }
  }

  FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
  fg_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  fg_ptr->debug_info()->set_name("list_map");
  AnfNodePtr fn = fg_ptr->add_parameter();

  std::vector<AnfNodePtr> lists;
  for (size_t i = 1; i < args_num; ++i) {
    lists.push_back(fg_ptr->add_parameter());
  }

  std::vector<AnfNodePtr> iters;
  (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("list_iter")), item});
  });

  std::vector<AnfNodePtr> nexts;
  (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("next")), item});
  });

  std::vector<AnfNodePtr> values;
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
    // next() yields a (value, iterator) pair; item 0 is the value.
    return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(0)});
  });

  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)});
  });

  (void)values.insert(values.begin(), fn);
  AnfNodePtr cnode_graph = fg_ptr->NewCNode(values);
  AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimMakeList), cnode_graph});

  FuncGraphPtr fgnext_ptr = std::make_shared<FuncGraph>();
  fgnext_ptr->debug_info()->set_name("body");

  FuncGraphPtr fgcond_ptr = std::make_shared<FuncGraph>();
  fgcond_ptr->debug_info()->set_name("cond");

  MakeCond(lists, fgnext_ptr, fgcond_ptr);
  MakeNext(lists, fgcond_ptr, fgnext_ptr);

  CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl});

  auto inputs = output_cnode->inputs();
  (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  output_cnode->set_inputs(inputs);

  fg_ptr->set_output(output_cnode);
  return fg_ptr;
}
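
// MakeCond builds the loop-condition graph: it is meant to evaluate hasnext
// on every iterator and switch into either the body graph or a graph that
// returns the accumulated result. Note the switch condition is still the
// placeholder string "cond"; the hasnexts reduction sketched in the comment
// below is not wired in yet.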
void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
                       const FuncGraphPtr &fg_ptr) {
  MS_EXCEPTION_IF_NULL(fg_ptr);

  AnfNodePtr fn = fg_ptr->add_parameter();
  AnfNodePtr resl = fg_ptr->add_parameter();

  std::vector<AnfNodePtr> iters;
  (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
                       [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });

  std::vector<AnfNodePtr> hasnexts;
  (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("hasnext")), item});
  });

  // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
  FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
  fgtrue_ptr->debug_info()->set_name("ftrue");
  fgtrue_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true);

  CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl});
  auto inputs = fgtrue_output_cnode->inputs();
  (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  fgtrue_output_cnode->set_inputs(inputs);
  fgtrue_ptr->set_output(fgtrue_output_cnode);

  FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
  fgfalse_ptr->debug_info()->set_name("ffalse");
  fgfalse_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  fgfalse_ptr->set_output(resl);

  AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
                                              NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)});
  // The switch node is the output of the cond graph itself.
  fg_ptr->set_output(output_cnode);
}

void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
                       const FuncGraphPtr &fg_ptr) {
  MS_EXCEPTION_IF_NULL(fg_ptr);
  AnfNodePtr fn = fg_ptr->add_parameter();

  std::vector<AnfNodePtr> iters;
  (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
                       [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });

  std::vector<AnfNodePtr> nexts;
  (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(std::string("next")), item});
  });

  std::vector<AnfNodePtr> values;
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
    // next() yields a (value, iterator) pair; item 0 is the value.
    return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(0)});
  });

  iters.clear();
  (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)});
  });

  (void)values.insert(values.begin(), fn);
  AnfNodePtr cnode_graph = fg_ptr->NewCNode(values);
  AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimListAppend), cnode_graph});
  CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl});

  auto inputs = output_cnode->inputs();
  (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  output_cnode->set_inputs(inputs);
  fg_ptr->set_output(output_cnode);
}
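
// TupleAdd concatenates two tuples by emitting tuple_getitem for every
// element of both inputs; illustratively, (a, b) + (c, d) -> (a, b, c, d).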
FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  // args: tuple1, tuple2
  abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
  AbstractBasePtr abs_a = args_spec_list[0];
  AbstractBasePtr abs_b = args_spec_list[1];

  abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
  abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
  if (a_tuple == nullptr || b_tuple == nullptr) {
    MS_LOG(EXCEPTION) << "TupleAdd arguments should be tuples, but got " << args_spec_list[0]->ToString() << ", "
                      << args_spec_list[1]->ToString();
  }

  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  AnfNodePtr p_tup_a = ret->add_parameter();
  AnfNodePtr p_tup_b = ret->add_parameter();

  std::vector<AnfNodePtr> elems;
  elems.push_back(NewValueNode(prim::kPrimMakeTuple));

  int tuple_size = SizeToInt(a_tuple->size());
  for (int i = 0; i < tuple_size; ++i) {
    elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)}));
  }

  tuple_size = SizeToInt(b_tuple->size());
  for (int i = 0; i < tuple_size; ++i) {
    elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)}));
  }

  ret->set_output(ret->NewCNode(elems));
  return ret;
}

int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
  MS_EXCEPTION_IF_NULL(scalar);
  return GetValue<int>(scalar->BuildValue());
}

bool CheckIndexInRange(int index, int min, int max) { return (index >= min && index <= max); }

int GetPositiveIndex(int index, int length) {
  if (index < 0) {
    index += length;
  }
  return index;
}

int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) {
  MS_EXCEPTION_IF_NULL(member);

  if (member->isa<AbstractScalar>()) {
    return GetArgScalarValue(dyn_cast<AbstractScalar>(member), member_name);
  }

  if (member->isa<AbstractNone>()) {
    return default_value;
  }

  MS_LOG(EXCEPTION) << member_name << " should be an AbstractScalar or AbstractNone, but got " << member->ToString();
}
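
// Normalizes a Python-style slice against the tuple length, mirroring Python
// semantics: for a 5-element tuple, t[::-1] resolves to start=4, stop=-1,
// step=-1 (the stop stays at -1 because a None stop is not converted to a
// positive index), and t[1:] resolves to start=1, stop=5, step=1.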
void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index,
                                 int *stop_index, int *step_value) {
  MS_EXCEPTION_IF_NULL(tuple);
  MS_EXCEPTION_IF_NULL(slice);
  MS_EXCEPTION_IF_NULL(start_index);
  MS_EXCEPTION_IF_NULL(stop_index);
  MS_EXCEPTION_IF_NULL(step_value);
  const std::string start_name("Slice start index");
  const std::string stop_name("Slice stop index");
  const std::string step_name("Slice step value");

  int tuple_size = SizeToInt(tuple->size());
  int start_default = 0;
  int stop_default = tuple_size;
  int step_default = 1;

  *step_value = CheckSliceMember(slice->step(), step_default, step_name);
  if (*step_value == 0) {
    MS_LOG(EXCEPTION) << "TupleSlice requires a non-zero step value, but got 0.";
  }
  if (*step_value < 0) {
    start_default = tuple_size - 1;
    stop_default = -1;
  }

  *start_index = CheckSliceMember(slice->start(), start_default, start_name);
  *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);

  if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) ||
      !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) {
    MS_LOG(EXCEPTION) << "TupleSlice start index " << *start_index << " or stop index " << *stop_index
                      << " is out of range; tuple size is " << tuple_size << ".";
  }

  *start_index = GetPositiveIndex(*start_index, tuple_size);
  if (!slice->stop()->isa<AbstractNone>()) {
    *stop_index = GetPositiveIndex(*stop_index, tuple_size);
  }
}

FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  // slice a tuple
  // args: tuple, start index, end index, step
  const std::string op_name("TupleSlice");
  abstract::CheckArgsSize(op_name, args_spec_list, 2);
  AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);

  int start_index;
  int stop_index;
  int step_value;
  GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);

  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  AnfNodePtr p_tuple = ret->add_parameter();
  (void)ret->add_parameter();

  std::vector<AnfNodePtr> elems;
  elems.push_back(NewValueNode(prim::kPrimMakeTuple));
  if (step_value > 0) {
    for (int index = start_index; index < stop_index; index = index + step_value) {
      elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
    }
  } else {
    for (int index = start_index; index > stop_index; index = index + step_value) {
      elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
    }
  }

  ret->set_output(ret->NewCNode(elems));
  return ret;
}
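
// Packs a little-endian bit vector into an int, e.g. {1, 0, 1} -> 0b101 = 5;
// used below to encode one shrink_axis mask bit per sliced dimension.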
int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) {
  unsigned int number_dec = 0;
  for (size_t index = 0; index < number_bin.size(); index++) {
    number_dec |= number_bin[index] << index;
  }
  return static_cast<int>(number_dec);
}

void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end,
                std::vector<int> *strides, int length) {
  MS_EXCEPTION_IF_NULL(slice);
  MS_EXCEPTION_IF_NULL(begin);
  MS_EXCEPTION_IF_NULL(end);
  MS_EXCEPTION_IF_NULL(strides);
  if (length <= 0) {
    MS_LOG(EXCEPTION) << "Cannot slice a dimension whose length is less than 1";
  }

  int start_default = 0;
  int stop_default = length;
  int step_default = 1;
  int step_value = CheckSliceMember(slice->step(), step_default, "step");
  if (step_value < 0) {
    start_default = -1;
    stop_default = -(length + 1);
  }

  begin->push_back(CheckSliceMember(slice->start(), start_default, "begin"));
  end->push_back(CheckSliceMember(slice->stop(), stop_default, "stop"));
  strides->push_back(step_value);
}

int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape,
                                            std::vector<int> *begin, std::vector<int> *end,
                                            std::vector<int> *strides) {
  MS_EXCEPTION_IF_NULL(slice_tuple);
  MS_EXCEPTION_IF_NULL(begin);
  MS_EXCEPTION_IF_NULL(end);
  MS_EXCEPTION_IF_NULL(strides);

  size_t slice_tuple_size = slice_tuple->size();
  size_t shape_size = shape.size();
  if (slice_tuple_size > shape_size) {
    MS_LOG(EXCEPTION) << "The number of slice indices should not exceed the rank of the tensor; "
                         "the rank of the tensor is "
                      << shape_size << ", but the number of slice indices is " << slice_tuple_size;
  }

  std::vector<unsigned int> shrink;
  auto slice_tuple_eles = slice_tuple->elements();
  size_t ellipsis_num = 0;

  for (size_t index = 0; index < slice_tuple_size; index++) {
    if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
      AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
      ParseSlice(slice, begin, end, strides, shape[index]);
      shrink.push_back(0);
      continue;
    }

    if (slice_tuple_eles[index]->isa<AbstractScalar>()) {
      int ele_index = GetArgScalarValue(dyn_cast<AbstractScalar>(slice_tuple_eles[index]), "slice_tuple");
      begin->push_back(ele_index);
      end->push_back(ele_index + 1);
      strides->push_back(1);
      shrink.push_back(1);
      continue;
    }

    if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) {
      ellipsis_num++;
      if (ellipsis_num > 1) {
        MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis";
      }
      size_t ellipsis_len = shape_size - (slice_tuple_size - 1);
      begin->insert(begin->end(), ellipsis_len, 0);
      end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len);
      strides->insert(strides->end(), ellipsis_len, 1);
      shrink.insert(shrink.end(), ellipsis_len, 0);
      continue;
    }

    MS_LOG(EXCEPTION) << "A slice tuple may only contain slices, integers, or an ellipsis, but got "
                      << slice_tuple_eles[index]->ToString();
  }

  if (ellipsis_num == 0) {
    for (size_t index = slice_tuple_size; index < shape_size; index++) {
      begin->push_back(0);
      end->push_back(shape[index]);
      strides->push_back(1);
    }
  }
  return ConvertBinaryToDecimal(shrink);
}

int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape,
                                            std::vector<int> *begin, std::vector<int> *end,
                                            std::vector<int> *strides) {
  MS_EXCEPTION_IF_NULL(begin);
  MS_EXCEPTION_IF_NULL(end);
  MS_EXCEPTION_IF_NULL(strides);
  size_t shape_size = shape.size();
  if (shape_size == 0) {
    MS_LOG(EXCEPTION) << "Cannot slice a scalar tensor";
  }

  ParseSlice(slice, begin, end, strides, shape[0]);

  for (size_t index = 1; index < shape_size; index++) {
    begin->push_back(0);
    end->push_back(shape[index]);
    strides->push_back(1);
  }

  return 0;
}

int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape,
                                             std::vector<int> *begin, std::vector<int> *end,
                                             std::vector<int> *strides) {
  MS_EXCEPTION_IF_NULL(begin);
  MS_EXCEPTION_IF_NULL(end);
  MS_EXCEPTION_IF_NULL(strides);
  int ele_index = GetArgScalarValue(scalar, "slice_tuple");

  begin->push_back(ele_index);
  end->push_back(ele_index + 1);
  strides->push_back(1);

  for (size_t index = 1; index < shape.size(); index++) {
    begin->push_back(0);
    end->push_back(shape[index]);
    strides->push_back(1);
  }

  return 1;
}

FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) {
  auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
  ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
  return ret_graph;
}
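
// TensorSlice lowers tensor indexing onto the StridedSlice primitive: slice,
// int and ellipsis indices become (begin, end, strides) vectors plus a
// shrink-axis mask, while a True or None index inserts a leading axis via
// expand_dims instead.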
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  // slice a tensor
  // args: tensor, slice or slice tuple
  const std::string op_name = std::string("TensorSlice");
  abstract::CheckArgsSize(op_name, args_spec_list, 2);
  AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);

  FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
  ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  AnfNodePtr tensor_node = ret_graph->add_parameter();
  (void)ret_graph->add_parameter();

  auto shape = tensorPtr->shape()->shape();
  std::vector<int> begin;
  std::vector<int> end;
  std::vector<int> strides;
  int shrink_axis_mask;

  if (args_spec_list[1]->isa<AbstractTuple>()) {
    AbstractTuplePtr tuple_ptr = dyn_cast<AbstractTuple>(args_spec_list[1]);
    shrink_axis_mask = GenerateStridedSliceParametersFromTuple(tuple_ptr, shape, &begin, &end, &strides);
  } else if (args_spec_list[1]->isa<AbstractSlice>()) {
    AbstractSlicePtr slice_ptr = dyn_cast<AbstractSlice>(args_spec_list[1]);
    shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
  } else if (args_spec_list[1]->isa<AbstractScalar>()) {
    AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
    if (scalar_ptr->BuildValue()->isa<BoolImm>()) {
      if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
        return ExpandADim(ret_graph, tensor_node);
      }
      MS_LOG(EXCEPTION) << "TensorSlice does not support an index of False.";
    }
    shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
  } else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
    ret_graph->set_output(tensor_node);
    return ret_graph;
  } else if (args_spec_list[1]->isa<AbstractNone>()) {
    return ExpandADim(ret_graph, tensor_node);
  } else {
    std::ostringstream args_info;
    for (const auto &arg : args_spec_list) {
      MS_EXCEPTION_IF_NULL(arg);
      args_info << arg->ToString() << "\n";
    }
    MS_LOG(EXCEPTION)
      << "TensorSlice requires the index to be one of [slice, ellipsis, int, bool, none, tuple], but got "
      << args_info.str();
  }

  auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
  auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
                                               NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
  ret_graph->set_output(ret_graph->NewCNode(
    {PrimStridedSlice, tensor_node, NewValueNode(begin), NewValueNode(end), NewValueNode(strides)}));
  return ret_graph;
}
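
// TupleGetItemTensor: indexing a tuple of functions with a tensor index is
// lowered to switch_layer, which selects the branch graph at run time.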
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  // select indexed item
  // args: tuple of items, index
  const std::string op_name = std::string("TupleGetItemTensor");
  abstract::CheckArgsSize(op_name, args_spec_list, 2);
  AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  AbstractBasePtrList branches = branches_abs->elements();

  if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
    FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
    ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
    AnfNodePtr functions = ret_graph->add_parameter();
    auto index = ret_graph->add_parameter();

    ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
    return ret_graph;
  }

  MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support indexing with " << branches_abs->ToString() << ".";
}

REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
                         (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
                           .def(py::init<std::string &>());
                       }));

REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
                         (void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
                           .def(py::init<std::string &>());
                       }));

REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
                         (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
                           .def(py::init<std::string &>());
                       }));

REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
                           *m, "TupleGetItemTensor_")
                           .def(py::init<std::string &>());
                       }));

}  // namespace prim
}  // namespace mindspore