You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

prim_structures.cc 30 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  #include "pipeline/static_analysis/prim.h"

  #include <algorithm>
  #include <cstdint>
  #include <iterator>
  #include <memory>
  #include <set>
  #include <string>
  #include <vector>

  #include "pipeline/static_analysis/utils.h"
  #include "pipeline/static_analysis/param_validator.h"
  #include "operator/ops.h"
  #include "utils/convert_utils.h"
  23. namespace mindspore {
  24. namespace abstract {
  25. AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  26. const AbstractBasePtrList &args_spec_list) {
  27. // Inputs: two scalars whose value is a string.
  28. const std::string op_name = primitive->name();
  29. CheckArgsSize(op_name, args_spec_list, 2);
  30. AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  31. AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  32. ValuePtr value_x = scalar_x->BuildValue();
  33. ValuePtr value_y = scalar_y->BuildValue();
  34. if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
  35. MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
  36. << ", param1: " << value_y->ToString();
  37. }
  38. bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value());
  39. return std::make_shared<AbstractScalar>(ret);
  40. }
  41. AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  42. const AbstractBasePtrList &args_spec_list) {
  43. // Inputs: two scalars whose value is a string.
  44. const std::string op_name = primitive->name();
  45. CheckArgsSize(op_name, args_spec_list, 2);
  46. AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  47. AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  48. ValuePtr value_x = scalar_x->BuildValue();
  49. ValuePtr value_y = scalar_y->BuildValue();
  50. if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
  51. MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
  52. << ", param1: " << value_y->ToString();
  53. }
  54. std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
  55. return std::make_shared<AbstractScalar>(ret);
  56. }
  57. AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
  58. const AbstractBasePtrList &args_spec_list) {
  59. return std::make_shared<AbstractTuple>(args_spec_list);
  60. }
  61. AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
  62. const AbstractBasePtrList &args_spec_list) {
  63. return std::make_shared<AbstractList>(args_spec_list);
  64. }
  65. AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  66. const AbstractBasePtrList &args_spec_list) {
  67. // Inputs: two tuples.
  68. const std::string op_name = primitive->name();
  69. CheckArgsSize(op_name, args_spec_list, 2);
  70. AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  71. AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  72. size_t keys_size = keys->size();
  73. if (values->size() != keys_size) {
  74. MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
  75. }
  76. std::vector<AbstractAttribute> key_value;
  77. AbstractScalarPtr key;
  78. AbstractBasePtrList key_list = keys->elements();
  79. AbstractBasePtrList value_list = values->elements();
  80. for (size_t index = 0; index < keys_size; index++) {
  81. key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
  82. ValuePtr keyPtr = key->BuildValue();
  83. MS_EXCEPTION_IF_NULL(keyPtr);
  84. if (!keyPtr->isa<StringImm>()) {
  85. MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
  86. }
  87. std::string key_string = GetValue<std::string>(keyPtr);
  88. key_value.emplace_back(key_string, value_list[index]);
  89. }
  90. return std::make_shared<AbstractDictionary>(key_value);
  91. }
  92. AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  93. const AbstractBasePtrList &args_spec_list) {
  94. // Inputs: a string and an object of a subclass of AbstractBase.
  95. const std::string op_name = primitive->name();
  96. CheckArgsSize(op_name, args_spec_list, 2);
  97. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  98. ValuePtr keyPtr = key->BuildValue();
  99. if (!keyPtr->isa<StringImm>()) {
  100. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
  101. }
  102. std::string key_string = GetValue<std::string>(keyPtr);
  103. return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
  104. }
  105. AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  106. const AbstractBasePtrList &args_spec_list) {
  107. // Inputs: a string and a keyword.
  108. const std::string op_name = primitive->name();
  109. CheckArgsSize(op_name, args_spec_list, 2);
  110. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  111. AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
  112. ValuePtr key_value = key->BuildValue();
  113. if (!key_value->isa<StringImm>()) {
  114. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  115. }
  116. std::string key_input = GetValue<std::string>(key_value);
  117. std::string key_actual = kwarg->get_key();
  118. if (key_actual != key_input) {
  119. MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
  120. << key_input << ", AbstractKeywordArg' key is " << key_actual;
  121. }
  122. return kwarg->get_arg();
  123. }
  124. AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  125. const AbstractBasePtrList &args_spec_list) {
  126. // Inputs: three scalars whose value is an int32 number.
  127. CheckArgsSize(primitive->name(), args_spec_list, 3);
  128. size_t args_size = args_spec_list.size();
  129. for (size_t index = 0; index < args_size; index++) {
  130. MS_EXCEPTION_IF_NULL(args_spec_list[index]);
  131. if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
  132. MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
  133. }
  134. if (args_spec_list[index]->isa<AbstractScalar>() &&
  135. !dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
  136. MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number.";
  137. }
  138. }
  139. // Slice: start, end, step
  140. return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
  141. }
  142. // Eval the return type of make_record
  143. AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  144. const AbstractBasePtrList &args_spec_list) {
  145. // Inputs: at lease two objects of a subclass of AbstractBase.
  146. if (args_spec_list.size() < 2) {
  147. MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
  148. << args_spec_list.size() << ".";
  149. }
  150. // args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
  151. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  152. TypePtr type = args_spec_list[0]->GetTypeTrack();
  153. MS_EXCEPTION_IF_NULL(type);
  154. if (type->type_id() != kMetaTypeTypeType) {
  155. MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
  156. }
  157. ValuePtr value_track = args_spec_list[0]->GetValueTrack();
  158. MS_EXCEPTION_IF_NULL(value_track);
  159. TypePtr type_ptr = value_track->cast<TypePtr>();
  160. if (type_ptr == nullptr) {
  161. MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
  162. }
  163. auto cls = dyn_cast<Class>(type_ptr);
  164. MS_EXCEPTION_IF_NULL(cls);
  165. ClassAttrVector attributes = cls->GetAttributes();
  166. CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
  167. std::vector<AbstractAttribute> abs_attributes;
  168. for (size_t i = 0; i < attributes.size(); i++) {
  169. AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
  170. abs_attributes.push_back(elem);
  171. }
  172. return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
  173. }
  174. template <typename T>
  175. AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  176. // Inputs: a tuple or list and a scalar whose value is an int32 number.
  177. CheckArgsSize(op_name, args_spec_list, 2);
  178. auto queue = CheckArg<T>(op_name, args_spec_list, 0);
  179. AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  180. ValuePtr index_value = index->BuildValue();
  181. if (!index_value->isa<Int32Imm>()) {
  182. MS_LOG(EXCEPTION) << op_name << " evaluator index should be an int32 number, but got " << index_value->ToString();
  183. }
  184. int idx_v = GetValue<int>(index_value);
  185. std::size_t nelems = queue->elements().size();
  186. if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
  187. MS_LOG(EXCEPTION) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
  188. << SizeToInt(nelems) << "), but got " << idx_v << ".";
  189. }
  190. std::size_t uidx_v = 0;
  191. if (idx_v >= 0) {
  192. uidx_v = IntToSize(idx_v);
  193. } else {
  194. uidx_v = IntToSize(idx_v + SizeToInt(nelems));
  195. }
  196. return queue->elements()[uidx_v];
  197. }
  198. template <typename T>
  199. AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  200. // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
  201. CheckArgsSize(op_name, args_spec_list, 3);
  202. auto queue = CheckArg<T>(op_name, args_spec_list, 0);
  203. AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  204. ValuePtr index_value = index->BuildValue();
  205. if (!index_value->isa<Int32Imm>()) {
  206. MS_LOG(EXCEPTION) << op_name << " evaluator index should be an int32 number, but got " << index_value->ToString();
  207. }
  208. int idx_v = GetValue<int>(index_value);
  209. if (idx_v < 0) {
  210. MS_LOG(EXCEPTION) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v << ".";
  211. }
  212. size_t uidx_v = IntToSize(idx_v);
  213. AbstractBasePtrList elements = queue->elements();
  214. std::size_t nelems = elements.size();
  215. if (uidx_v >= nelems) {
  216. MS_LOG(EXCEPTION) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 << ".";
  217. }
  218. elements[uidx_v] = args_spec_list[2];
  219. return std::make_shared<T>(elements);
  220. }
  221. AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  222. const AbstractBasePtrList &args_spec_list) {
  223. return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
  224. }
  225. AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  226. const AbstractBasePtrList &args_spec_list) {
  227. return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
  228. }
  229. AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  230. const AbstractBasePtrList &args_spec_list) {
  231. return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
  232. }
  233. AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  234. const AbstractBasePtrList &args_spec_list) {
  235. return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
  236. }
  237. AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  238. const AbstractBasePtrList &args_spec_list) {
  239. // Inputs: a dict and a scalar whose value is a string.
  240. const std::string op_name = primitive->name();
  241. CheckArgsSize(op_name, args_spec_list, 2);
  242. AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
  243. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  244. ValuePtr key_value = key->BuildValue();
  245. if (!key_value->isa<StringImm>()) {
  246. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  247. }
  248. auto key_str = GetValue<std::string>(key_value);
  249. std::vector<AbstractAttribute> dict_elems = dict->elements();
  250. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  251. [key_str](const AbstractAttribute &item) { return item.first == key_str; });
  252. if (it == dict_elems.end()) {
  253. MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
  254. }
  255. return it->second;
  256. }
  257. AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  258. const AbstractBasePtrList &args_spec_list) {
  259. // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
  260. const std::string op_name = primitive->name();
  261. CheckArgsSize(op_name, args_spec_list, 3);
  262. AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
  263. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  264. ValuePtr key_value = key->BuildValue();
  265. if (!key_value->isa<StringImm>()) {
  266. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  267. }
  268. std::string key_str = GetValue<std::string>(key_value);
  269. std::vector<AbstractAttribute> dict_elems = dict->elements();
  270. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  271. [key_str](AbstractAttribute &item) { return item.first == key_str; });
  272. MS_EXCEPTION_IF_NULL(args_spec_list[2]);
  273. auto new_ele = std::make_pair(key_str, args_spec_list[2]);
  274. if (it != dict_elems.end()) {
  275. int index = it - dict_elems.begin();
  276. dict_elems[IntToSize(index)] = new_ele;
  277. } else {
  278. dict_elems.push_back(new_ele);
  279. }
  280. return std::make_shared<AbstractDictionary>(dict_elems);
  281. }
  282. AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  283. const AbstractBasePtrList &args_spec_list) {
  284. // Inputs: a list and an object of a subclass of AbstractBase.
  285. const std::string op_name = primitive->name();
  286. CheckArgsSize(op_name, args_spec_list, 2);
  287. AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
  288. (void)AbstractJoin(list->elements());
  289. return list;
  290. }
  291. template <typename T>
  292. AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  293. // Inputs: a tuple or list or dict.
  294. CheckArgsSize(op_name, args_spec_list, 1);
  295. auto arg = CheckArg<T>(op_name, args_spec_list, 0);
  296. return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
  297. }
  298. AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  299. const AbstractBasePtrList &args_spec_list) {
  300. return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
  301. }
  302. AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  303. const AbstractBasePtrList &args_spec_list) {
  304. return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
  305. }
  306. AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  307. const AbstractBasePtrList &args_spec_list) {
  308. return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
  309. }
  310. AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
  311. const AbstractBasePtrList &args_spec_list) {
  312. return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  313. }
  314. AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
  315. const AbstractBasePtrList &args_spec_list) {
  316. // Inputs: fn, list1, list2, ...
  317. MS_EXCEPTION_IF_NULL(engine);
  318. if (args_spec_list.size() <= 1) {
  319. MS_LOG(EXCEPTION) << "List_map requires at least 1 list. while the input size is " << args_spec_list.size() << ".";
  320. }
  321. AbstractFunctionPtr fn = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
  322. // check args from 1.
  323. CheckArgsSpec<AbstractList>(AbstractBasePtrList(args_spec_list.begin() + 1, args_spec_list.end()));
  324. AbstractBasePtrList subargs;
  325. for (std::size_t i = 1; i < args_spec_list.size(); i++) {
  326. AbstractListPtr l_ptr = dyn_cast<AbstractList>(args_spec_list[i]);
  327. if (l_ptr == nullptr) {
  328. MS_LOG(EXCEPTION) << "Argument[" << i << "] of list_map should be a list.";
  329. }
  330. subargs.push_back(AbstractJoin(l_ptr->elements()));
  331. }
  332. AbstractBasePtr engin_exc = engine->Execute(fn, subargs);
  333. AbstractBasePtrList result;
  334. for (std::size_t i = 1; i < args_spec_list.size(); i++) {
  335. result.push_back(engin_exc);
  336. }
  337. return std::make_shared<AbstractList>(result);
  338. }
  339. AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
  340. const AbstractBasePtrList &args_spec_list) {
  341. // Inputs: a fn, a list and an object of a subclass of a AbstractBase.
  342. MS_EXCEPTION_IF_NULL(engine);
  343. const std::string op_name = primitive->name();
  344. CheckArgsSize(op_name, args_spec_list, 3);
  345. AbstractFunctionPtr fn = CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
  346. AbstractListPtr lst = CheckArg<AbstractList>(op_name, args_spec_list, 1);
  347. AbstractBasePtr dflt = args_spec_list[2];
  348. AbstractBasePtr list_type = AbstractJoin(lst->elements());
  349. auto result1 = engine->Execute(fn, lst->elements());
  350. auto result2 = engine->Execute(fn, {dflt, list_type});
  351. MS_EXCEPTION_IF_NULL(result1);
  352. return result1->Join(result2);
  353. }
  354. AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  355. const AbstractBasePtrList &args_spec_list) {
  356. // Inputs: a tuple
  357. const std::string op_name = primitive->name();
  358. CheckArgsSize(op_name, args_spec_list, 1);
  359. AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  360. auto tuple_elements = input->elements();
  361. AbstractBasePtrList elem_list;
  362. (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
  363. [](const AbstractBasePtr &elem) { return elem->Clone(); });
  364. return std::make_shared<AbstractTuple>(elem_list);
  365. }
  366. AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
  367. const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
  368. size_t x_rank = x_shape->size();
  369. std::set<int> axis_set;
  370. auto axis_data = axis_value_ptr->value();
  371. if (axis_data.empty()) {
  372. int size = 1;
  373. AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
  374. return std::make_shared<AbstractTuple>(values);
  375. }
  376. for (auto &elem : axis_data) {
  377. int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
  378. (void)axis_set.insert(e_value);
  379. }
  380. auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
  381. if (x_shp_data.size() < x_rank) {
  382. MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
  383. }
  384. AbstractBasePtrList values;
  385. for (size_t i = 0; i < x_rank; i++) {
  386. if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
  387. auto axis_v = MakeValue(1);
  388. values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
  389. } else {
  390. int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value();
  391. auto dim = MakeValue(dim_value);
  392. values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
  393. }
  394. }
  395. return std::make_shared<AbstractTuple>(values);
  396. }
  397. AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  398. const AbstractBasePtrList &args_spec_list) {
  399. // Inputs: x_shape, axis
  400. const std::string op_name = primitive->name();
  401. CheckArgsSize(op_name, args_spec_list, 2);
  402. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  403. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  404. auto x_shp_value = shape_x->BuildValue();
  405. if (x_shp_value->isa<AnyValue>()) {
  406. MS_LOG(EXCEPTION) << op_name
  407. << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
  408. }
  409. // Axis can be scalar, tuple or None
  410. AbstractTuplePtr axis = nullptr;
  411. if (args_spec_list[1]->isa<AbstractScalar>()) {
  412. MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar";
  413. AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])};
  414. axis = std::make_shared<AbstractTuple>(axis_list);
  415. } else if (args_spec_list[1]->isa<AbstractTuple>()) {
  416. MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple";
  417. axis = args_spec_list[1]->cast<AbstractTuplePtr>();
  418. } else {
  419. MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got "
  420. << args_spec_list[1]->ToString();
  421. }
  422. auto axis_value = axis->BuildValue();
  423. if (axis_value->isa<AnyValue>()) {
  424. MS_LOG(EXCEPTION) << op_name
  425. << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
  426. }
  427. auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
  428. MS_EXCEPTION_IF_NULL(axis_value_ptr);
  429. return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
  430. }
  431. AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  432. const AbstractBasePtrList &args_spec_list) {
  433. // Inputs: two tuples.
  434. const std::string op_name = primitive->name();
  435. CheckArgsSize(op_name, args_spec_list, 2);
  436. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  437. AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  438. MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString();
  439. auto div_shp_value = div_shp->BuildValue();
  440. if (div_shp_value->isa<AnyValue>()) {
  441. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[0]->ToString();
  442. }
  443. auto shpx_value = shape_x->BuildValue();
  444. if (shpx_value->isa<AnyValue>()) {
  445. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[1]->ToString();
  446. }
  447. if (div_shp->size() != shape_x->size()) {
  448. MS_LOG(EXCEPTION) << "tileshape elems shape must the same div_shp: " << div_shp->size()
  449. << ", shapex: " << shape_x->size() << ".";
  450. }
  451. auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
  452. auto div_shp_data = div_shp_value->cast<ValueTuplePtr>()->value();
  453. AbstractBasePtrList values;
  454. for (size_t i = 0; i < div_shp_data.size(); i++) {
  455. if (div_shp_data[i]->cast<Int32ImmPtr>() == nullptr) {
  456. MS_LOG(EXCEPTION) << "div_shp_shape data should be an int32 number, but it's " << args_spec_list[1]->ToString();
  457. }
  458. int shapex_value = GetValue<int>(shpx_data[i]);
  459. int div_value = GetValue<int>(div_shp_data[i]);
  460. MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
  461. if (div_value == 0) {
  462. MS_LOG(EXCEPTION) << "error: division value should not be 0!";
  463. }
  464. if ((shapex_value % div_value) != 0) {
  465. MS_LOG(EXCEPTION) << "div_shp_shape data shapex must div int:" << shapex_value << " div_value: " << div_value;
  466. }
  467. int result = shapex_value / div_value;
  468. auto result_v = MakeValue(result);
  469. values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
  470. }
  471. return std::make_shared<AbstractTuple>(values);
  472. }
  473. AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  474. const AbstractBasePtrList &args_spec_list) {
  475. // Inputs: a tuple
  476. const std::string op_name = primitive->name();
  477. CheckArgsSize(op_name, args_spec_list, 1);
  478. AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  479. py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
  480. py::array data = py::array(data_tuple);
  481. auto tensor = std::make_shared<tensor::Tensor>(data);
  482. auto ret = tensor->ToAbstract();
  483. ret->set_value(tensor);
  484. MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
  485. return ret;
  486. }
  487. AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  488. const AbstractBasePtrList &args_spec_list) {
  489. // Inputs: a tuple
  490. // example: tuple = (1, 2, 3), shape_mul(tuple) = 1*2*3 = 6
  491. const std::string op_name = primitive->name();
  492. CheckArgsSize(op_name, args_spec_list, 1);
  493. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  494. auto shpx_value = shape_x->BuildValue();
  495. if (shpx_value->isa<AnyValue>()) {
  496. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << shape_x->ToString();
  497. }
  498. auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
  499. int result = 1;
  500. for (size_t i = 0; i < shpx_data.size(); i++) {
  501. int value = GetValue<int>(shpx_data[i]);
  502. IntMulWithOverflowCheck(result, value, &result);
  503. }
  504. auto result_v = MakeValue(result);
  505. MS_LOG(DEBUG) << "shape mul result:" << result_v->ToString();
  506. return std::make_shared<AbstractScalar>(result_v, result_v->type());
  507. }
  508. template <typename T>
  509. AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  510. // Inputs: two tuples or two lists.
  511. CheckArgsSize(op_name, args_spec_list, 2);
  512. auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
  513. auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
  514. ValuePtr x_value = input_x->BuildValue();
  515. ValuePtr y_value = input_y->BuildValue();
  516. return std::make_shared<AbstractScalar>(*x_value == *y_value);
  517. }
  518. AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  519. const AbstractBasePtrList &args_spec_list) {
  520. return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
  521. }
  522. AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  523. const AbstractBasePtrList &args_spec_list) {
  524. return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
  525. }
  // Parameters of a range-like integer sequence, aggregate-initialized in declaration order
  // {start, step, stop}.
  struct SlideInfo {
    int start;  // first value produced
    int step;   // increment added between consecutive values
    int stop;   // exclusive bound of the sequence
  };
  531. void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
  532. int arg1 = 0;
  533. int arg2 = 0;
  534. if (!args_spec_list.empty()) {
  535. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  536. auto arg_value = args_spec_list[0]->BuildValue();
  537. if (!arg_value->isa<Int32Imm>()) {
  538. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  539. }
  540. arg1 = GetValue<int>(arg_value);
  541. }
  542. if (args_spec_list.size() >= 2) {
  543. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  544. auto arg_value = args_spec_list[1]->BuildValue();
  545. if (!arg_value->isa<Int32Imm>()) {
  546. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  547. }
  548. arg2 = GetValue<int>(arg_value);
  549. }
  550. if (args_spec_list.size() == 3) {
  551. MS_EXCEPTION_IF_NULL(args_spec_list[2]);
  552. auto arg_value = args_spec_list[2]->BuildValue();
  553. if (!arg_value->isa<Int32Imm>()) {
  554. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  555. }
  556. slide->step = GetValue<int>(arg_value);
  557. slide->start = arg1;
  558. slide->stop = arg2;
  559. }
  560. if (args_spec_list.size() == 2) {
  561. slide->start = arg1;
  562. slide->stop = arg2;
  563. }
  564. if (args_spec_list.size() == 1) {
  565. slide->stop = arg1;
  566. }
  567. }
  568. AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
  569. const AbstractBasePtrList &args_spec_list) {
  570. if (args_spec_list.empty()) {
  571. MS_LOG(EXCEPTION) << "Cannot make range from empty input.";
  572. }
  573. if (args_spec_list.size() > 3) {
  574. MS_LOG(EXCEPTION) << "Error args size of make range operational.";
  575. }
  576. SlideInfo slide = {0, 1, 0};
  577. CalcSlidePara(args_spec_list, &slide);
  578. if (slide.step == 0) {
  579. MS_LOG(EXCEPTION) << "Error, step value is 0.";
  580. }
  581. AbstractBasePtrList args;
  582. if (slide.start <= slide.stop) {
  583. if (slide.step <= 0) {
  584. MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
  585. }
  586. for (int i = slide.start; i < slide.stop; i += slide.step) {
  587. args.push_back(abstract::FromValue(i));
  588. }
  589. } else {
  590. if (slide.step >= 0) {
  591. MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
  592. }
  593. for (int i = slide.start; i > slide.stop; i += slide.step) {
  594. args.push_back(abstract::FromValue(i));
  595. }
  596. }
  597. return std::make_shared<AbstractTuple>(args);
  598. }
  599. AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  600. const AbstractBasePtrList &args_spec_list) {
  601. // Inputs: a tensor
  602. CheckArgsSize(primitive->name(), args_spec_list, 1);
  603. return args_spec_list[0]->Clone();
  604. }
  605. } // namespace abstract
  606. } // namespace mindspore