You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

prim_structures.cc 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
#include "pipeline/static_analysis/prim.h"

#include <cstdint>

#include "pipeline/static_analysis/utils.h"
#include "pipeline/static_analysis/param_validator.h"
#include "operator/ops.h"
#include "utils/convert_utils.h"
#include "ir/tensor_py.h"
using mindspore::tensor::TensorPy;
  25. namespace mindspore {
  26. namespace abstract {
  27. AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  28. const AbstractBasePtrList &args_spec_list) {
  29. // Inputs: two scalars whose value is a string.
  30. const std::string op_name = primitive->name();
  31. CheckArgsSize(op_name, args_spec_list, 2);
  32. AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  33. AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  34. ValuePtr value_x = scalar_x->BuildValue();
  35. ValuePtr value_y = scalar_y->BuildValue();
  36. if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
  37. MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
  38. << ", param1: " << value_y->ToString();
  39. }
  40. bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value());
  41. return std::make_shared<AbstractScalar>(ret);
  42. }
  43. AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  44. const AbstractBasePtrList &args_spec_list) {
  45. // Inputs: two scalars whose value is a string.
  46. const std::string op_name = primitive->name();
  47. CheckArgsSize(op_name, args_spec_list, 2);
  48. AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  49. AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  50. ValuePtr value_x = scalar_x->BuildValue();
  51. ValuePtr value_y = scalar_y->BuildValue();
  52. if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
  53. MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
  54. << ", param1: " << value_y->ToString();
  55. }
  56. std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
  57. return std::make_shared<AbstractScalar>(ret);
  58. }
  59. AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
  60. const AbstractBasePtrList &args_spec_list) {
  61. return std::make_shared<AbstractTuple>(args_spec_list);
  62. }
  63. AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
  64. const AbstractBasePtrList &args_spec_list) {
  65. return std::make_shared<AbstractList>(args_spec_list);
  66. }
  67. AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  68. const AbstractBasePtrList &args_spec_list) {
  69. // Inputs: two tuples.
  70. const std::string op_name = primitive->name();
  71. CheckArgsSize(op_name, args_spec_list, 2);
  72. AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  73. AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  74. size_t keys_size = keys->size();
  75. if (values->size() != keys_size) {
  76. MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
  77. }
  78. std::vector<AbstractAttribute> key_value;
  79. AbstractScalarPtr key;
  80. AbstractBasePtrList key_list = keys->elements();
  81. AbstractBasePtrList value_list = values->elements();
  82. for (size_t index = 0; index < keys_size; index++) {
  83. key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
  84. ValuePtr keyPtr = key->BuildValue();
  85. MS_EXCEPTION_IF_NULL(keyPtr);
  86. if (!keyPtr->isa<StringImm>()) {
  87. MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
  88. }
  89. std::string key_string = GetValue<std::string>(keyPtr);
  90. key_value.emplace_back(key_string, value_list[index]);
  91. }
  92. return std::make_shared<AbstractDictionary>(key_value);
  93. }
  94. AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  95. const AbstractBasePtrList &args_spec_list) {
  96. // Inputs: a string and an object of a subclass of AbstractBase.
  97. const std::string op_name = primitive->name();
  98. CheckArgsSize(op_name, args_spec_list, 2);
  99. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  100. ValuePtr keyPtr = key->BuildValue();
  101. if (!keyPtr->isa<StringImm>()) {
  102. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
  103. }
  104. std::string key_string = GetValue<std::string>(keyPtr);
  105. return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
  106. }
  107. AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  108. const AbstractBasePtrList &args_spec_list) {
  109. // Inputs: a string and a keyword.
  110. const std::string op_name = primitive->name();
  111. CheckArgsSize(op_name, args_spec_list, 2);
  112. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  113. AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
  114. ValuePtr key_value = key->BuildValue();
  115. if (!key_value->isa<StringImm>()) {
  116. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  117. }
  118. std::string key_input = GetValue<std::string>(key_value);
  119. std::string key_actual = kwarg->get_key();
  120. if (key_actual != key_input) {
  121. MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
  122. << key_input << ", AbstractKeywordArg' key is " << key_actual;
  123. }
  124. return kwarg->get_arg();
  125. }
  126. AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  127. const AbstractBasePtrList &args_spec_list) {
  128. // Inputs: three scalars whose value is an int32 number.
  129. CheckArgsSize(primitive->name(), args_spec_list, 3);
  130. size_t args_size = args_spec_list.size();
  131. for (size_t index = 0; index < args_size; index++) {
  132. MS_EXCEPTION_IF_NULL(args_spec_list[index]);
  133. if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
  134. MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
  135. }
  136. if (args_spec_list[index]->isa<AbstractScalar>() &&
  137. !dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
  138. MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number.";
  139. }
  140. }
  141. // Slice: start, end, step
  142. return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
  143. }
  144. // Eval the return type of make_record
  145. AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  146. const AbstractBasePtrList &args_spec_list) {
  147. // Inputs: at lease two objects of a subclass of AbstractBase.
  148. if (args_spec_list.size() < 2) {
  149. MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
  150. << args_spec_list.size() << ".";
  151. }
  152. // args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
  153. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  154. TypePtr type = args_spec_list[0]->GetTypeTrack();
  155. MS_EXCEPTION_IF_NULL(type);
  156. if (type->type_id() != kMetaTypeTypeType) {
  157. MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
  158. }
  159. ValuePtr value_track = args_spec_list[0]->GetValueTrack();
  160. MS_EXCEPTION_IF_NULL(value_track);
  161. TypePtr type_ptr = value_track->cast<TypePtr>();
  162. if (type_ptr == nullptr) {
  163. MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
  164. }
  165. auto cls = dyn_cast<Class>(type_ptr);
  166. MS_EXCEPTION_IF_NULL(cls);
  167. ClassAttrVector attributes = cls->GetAttributes();
  168. CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
  169. std::vector<AbstractAttribute> abs_attributes;
  170. for (size_t i = 0; i < attributes.size(); i++) {
  171. AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
  172. abs_attributes.push_back(elem);
  173. }
  174. return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
  175. }
  176. template <typename T>
  177. AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  178. // Inputs: a tuple or list and a scalar whose value is an int32 number.
  179. CheckArgsSize(op_name, args_spec_list, 2);
  180. auto queue = CheckArg<T>(op_name, args_spec_list, 0);
  181. AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  182. ValuePtr index_value = index->BuildValue();
  183. if (!index_value->isa<Int32Imm>()) {
  184. MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
  185. << index_value->ToString();
  186. }
  187. int idx_v = GetValue<int>(index_value);
  188. std::size_t nelems = queue->elements().size();
  189. if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
  190. MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
  191. << SizeToInt(nelems) << "), but got " << idx_v << ".";
  192. }
  193. std::size_t uidx_v = 0;
  194. if (idx_v >= 0) {
  195. uidx_v = IntToSize(idx_v);
  196. } else {
  197. uidx_v = IntToSize(idx_v + SizeToInt(nelems));
  198. }
  199. return queue->elements()[uidx_v];
  200. }
  201. template <typename T>
  202. AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  203. // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
  204. CheckArgsSize(op_name, args_spec_list, 3);
  205. auto queue = CheckArg<T>(op_name, args_spec_list, 0);
  206. AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  207. ValuePtr index_value = index->BuildValue();
  208. if (!index_value->isa<Int32Imm>()) {
  209. MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
  210. << index_value->ToString();
  211. }
  212. int idx_v = GetValue<int>(index_value);
  213. if (idx_v < 0) {
  214. MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
  215. << ".";
  216. }
  217. size_t uidx_v = IntToSize(idx_v);
  218. AbstractBasePtrList elements = queue->elements();
  219. std::size_t nelems = elements.size();
  220. if (uidx_v >= nelems) {
  221. MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
  222. << ".";
  223. }
  224. elements[uidx_v] = args_spec_list[2];
  225. return std::make_shared<T>(elements);
  226. }
  227. AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  228. const AbstractBasePtrList &args_spec_list) {
  229. return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
  230. }
  231. AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  232. const AbstractBasePtrList &args_spec_list) {
  233. return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
  234. }
  235. AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  236. const AbstractBasePtrList &args_spec_list) {
  237. return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
  238. }
  239. AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  240. const AbstractBasePtrList &args_spec_list) {
  241. return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
  242. }
  243. AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  244. const AbstractBasePtrList &args_spec_list) {
  245. // Inputs: a dict and a scalar whose value is a string.
  246. const std::string op_name = primitive->name();
  247. CheckArgsSize(op_name, args_spec_list, 2);
  248. AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
  249. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  250. ValuePtr key_value = key->BuildValue();
  251. if (!key_value->isa<StringImm>()) {
  252. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  253. }
  254. auto key_str = GetValue<std::string>(key_value);
  255. std::vector<AbstractAttribute> dict_elems = dict->elements();
  256. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  257. [key_str](const AbstractAttribute &item) { return item.first == key_str; });
  258. if (it == dict_elems.end()) {
  259. MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
  260. }
  261. return it->second;
  262. }
  263. AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  264. const AbstractBasePtrList &args_spec_list) {
  265. // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
  266. const std::string op_name = primitive->name();
  267. CheckArgsSize(op_name, args_spec_list, 3);
  268. AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
  269. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  270. ValuePtr key_value = key->BuildValue();
  271. if (!key_value->isa<StringImm>()) {
  272. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  273. }
  274. std::string key_str = GetValue<std::string>(key_value);
  275. std::vector<AbstractAttribute> dict_elems = dict->elements();
  276. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  277. [key_str](AbstractAttribute &item) { return item.first == key_str; });
  278. MS_EXCEPTION_IF_NULL(args_spec_list[2]);
  279. auto new_ele = std::make_pair(key_str, args_spec_list[2]);
  280. if (it != dict_elems.end()) {
  281. int index = it - dict_elems.begin();
  282. dict_elems[IntToSize(index)] = new_ele;
  283. } else {
  284. dict_elems.push_back(new_ele);
  285. }
  286. return std::make_shared<AbstractDictionary>(dict_elems);
  287. }
  288. AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  289. const AbstractBasePtrList &args_spec_list) {
  290. // Inputs: a list and an object of a subclass of AbstractBase.
  291. const std::string op_name = primitive->name();
  292. CheckArgsSize(op_name, args_spec_list, 2);
  293. AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
  294. (void)AbstractJoin(list->elements());
  295. return list;
  296. }
  297. template <typename T>
  298. AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  299. // Inputs: a tuple or list or dict.
  300. CheckArgsSize(op_name, args_spec_list, 1);
  301. auto arg = CheckArg<T>(op_name, args_spec_list, 0);
  302. return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
  303. }
  304. AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  305. const AbstractBasePtrList &args_spec_list) {
  306. return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
  307. }
  308. AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  309. const AbstractBasePtrList &args_spec_list) {
  310. return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
  311. }
  312. AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  313. const AbstractBasePtrList &args_spec_list) {
  314. return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
  315. }
  316. AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
  317. const AbstractBasePtrList &args_spec_list) {
  318. return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  319. }
  320. AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
  321. const AbstractBasePtrList &args_spec_list) {
  322. // Inputs: fn, list1, list2, ...
  323. MS_EXCEPTION_IF_NULL(engine);
  324. if (args_spec_list.size() <= 1) {
  325. MS_LOG(EXCEPTION) << "List_map requires at least 1 list. while the input size is " << args_spec_list.size() << ".";
  326. }
  327. AbstractFunctionPtr fn = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
  328. // check args from 1.
  329. CheckArgsSpec<AbstractList>(AbstractBasePtrList(args_spec_list.begin() + 1, args_spec_list.end()));
  330. AbstractBasePtrList subargs;
  331. for (std::size_t i = 1; i < args_spec_list.size(); i++) {
  332. AbstractListPtr l_ptr = dyn_cast<AbstractList>(args_spec_list[i]);
  333. if (l_ptr == nullptr) {
  334. MS_LOG(EXCEPTION) << "Argument[" << i << "] of list_map should be a list.";
  335. }
  336. subargs.push_back(AbstractJoin(l_ptr->elements()));
  337. }
  338. EvalResultPtr engin_exc = engine->Execute(fn, subargs);
  339. AbstractBasePtrList result;
  340. for (std::size_t i = 1; i < args_spec_list.size(); i++) {
  341. result.push_back(engin_exc->abstract());
  342. }
  343. return std::make_shared<AbstractList>(result);
  344. }
  345. AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
  346. const AbstractBasePtrList &args_spec_list) {
  347. // Inputs: a fn, a list and an object of a subclass of a AbstractBase.
  348. MS_EXCEPTION_IF_NULL(engine);
  349. const std::string op_name = primitive->name();
  350. CheckArgsSize(op_name, args_spec_list, 3);
  351. AbstractFunctionPtr fn = CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
  352. AbstractListPtr lst = CheckArg<AbstractList>(op_name, args_spec_list, 1);
  353. AbstractBasePtr dflt = args_spec_list[2];
  354. AbstractBasePtr list_type = AbstractJoin(lst->elements());
  355. auto result1 = engine->Execute(fn, lst->elements());
  356. auto result2 = engine->Execute(fn, {dflt, list_type});
  357. MS_EXCEPTION_IF_NULL(result1->abstract());
  358. MS_EXCEPTION_IF_NULL(result2->abstract());
  359. return result1->abstract()->Join(result2->abstract());
  360. }
  361. AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  362. const AbstractBasePtrList &args_spec_list) {
  363. // Inputs: a tuple
  364. const std::string op_name = primitive->name();
  365. CheckArgsSize(op_name, args_spec_list, 1);
  366. AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  367. auto tuple_elements = input->elements();
  368. AbstractBasePtrList elem_list;
  369. (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
  370. [](const AbstractBasePtr &elem) { return elem->Clone(); });
  371. return std::make_shared<AbstractTuple>(elem_list);
  372. }
  373. AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
  374. const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
  375. size_t x_rank = x_shape->size();
  376. std::set<int> axis_set;
  377. auto axis_data = axis_value_ptr->value();
  378. if (axis_data.empty()) {
  379. int size = 1;
  380. AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
  381. return std::make_shared<AbstractTuple>(values);
  382. }
  383. for (auto &elem : axis_data) {
  384. int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
  385. (void)axis_set.insert(e_value);
  386. }
  387. auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
  388. if (x_shp_data.size() < x_rank) {
  389. MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
  390. }
  391. AbstractBasePtrList values;
  392. for (size_t i = 0; i < x_rank; i++) {
  393. if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
  394. auto axis_v = MakeValue(1);
  395. values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
  396. } else {
  397. int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value();
  398. auto dim = MakeValue(dim_value);
  399. values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
  400. }
  401. }
  402. return std::make_shared<AbstractTuple>(values);
  403. }
  404. AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  405. const AbstractBasePtrList &args_spec_list) {
  406. // Inputs: x_shape, axis
  407. const std::string op_name = primitive->name();
  408. CheckArgsSize(op_name, args_spec_list, 2);
  409. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  410. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  411. auto x_shp_value = shape_x->BuildValue();
  412. if (x_shp_value->isa<AnyValue>()) {
  413. MS_LOG(EXCEPTION) << op_name
  414. << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
  415. }
  416. // Axis can be scalar, tuple or None
  417. AbstractTuplePtr axis = nullptr;
  418. if (args_spec_list[1]->isa<AbstractScalar>()) {
  419. MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar";
  420. AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])};
  421. axis = std::make_shared<AbstractTuple>(axis_list);
  422. } else if (args_spec_list[1]->isa<AbstractTuple>()) {
  423. MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple";
  424. axis = args_spec_list[1]->cast<AbstractTuplePtr>();
  425. } else {
  426. MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got "
  427. << args_spec_list[1]->ToString();
  428. }
  429. auto axis_value = axis->BuildValue();
  430. if (axis_value->isa<AnyValue>()) {
  431. MS_LOG(EXCEPTION) << op_name
  432. << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
  433. }
  434. auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
  435. MS_EXCEPTION_IF_NULL(axis_value_ptr);
  436. return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
  437. }
  438. AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  439. const AbstractBasePtrList &args_spec_list) {
  440. // Inputs: two tuples.
  441. const std::string op_name = primitive->name();
  442. CheckArgsSize(op_name, args_spec_list, 2);
  443. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  444. AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  445. MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString();
  446. auto div_shp_value = div_shp->BuildValue();
  447. if (div_shp_value->isa<AnyValue>()) {
  448. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[0]->ToString();
  449. }
  450. auto shpx_value = shape_x->BuildValue();
  451. if (shpx_value->isa<AnyValue>()) {
  452. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[1]->ToString();
  453. }
  454. if (div_shp->size() != shape_x->size()) {
  455. MS_LOG(EXCEPTION) << "tileshape elems shape must the same div_shp: " << div_shp->size()
  456. << ", shapex: " << shape_x->size() << ".";
  457. }
  458. auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
  459. auto div_shp_data = div_shp_value->cast<ValueTuplePtr>()->value();
  460. AbstractBasePtrList values;
  461. for (size_t i = 0; i < div_shp_data.size(); i++) {
  462. if (div_shp_data[i]->cast<Int32ImmPtr>() == nullptr) {
  463. MS_LOG(EXCEPTION) << "div_shp_shape data should be an int32 number, but it's " << args_spec_list[1]->ToString();
  464. }
  465. int shapex_value = GetValue<int>(shpx_data[i]);
  466. int div_value = GetValue<int>(div_shp_data[i]);
  467. MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
  468. if (div_value == 0) {
  469. MS_LOG(EXCEPTION) << "error: division value should not be 0!";
  470. }
  471. if ((shapex_value % div_value) != 0) {
  472. MS_LOG(EXCEPTION) << "div_shp_shape data shapex must div int:" << shapex_value << " div_value: " << div_value;
  473. }
  474. int result = shapex_value / div_value;
  475. auto result_v = MakeValue(result);
  476. values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
  477. }
  478. return std::make_shared<AbstractTuple>(values);
  479. }
  480. AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  481. const AbstractBasePtrList &args_spec_list) {
  482. // Inputs: a tuple
  483. const std::string op_name = primitive->name();
  484. CheckArgsSize(op_name, args_spec_list, 1);
  485. AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  486. py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
  487. py::array data = py::array(data_tuple);
  488. auto tensor = TensorPy::MakeTensor(data);
  489. auto ret = tensor->ToAbstract();
  490. ret->set_value(tensor);
  491. MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
  492. return ret;
  493. }
  494. AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  495. const AbstractBasePtrList &args_spec_list) {
  496. // Inputs: a tuple
  497. // example: tuple = (1, 2, 3), shape_mul(tuple) = 1*2*3 = 6
  498. const std::string op_name = primitive->name();
  499. CheckArgsSize(op_name, args_spec_list, 1);
  500. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  501. auto shpx_value = shape_x->BuildValue();
  502. if (shpx_value->isa<AnyValue>()) {
  503. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << shape_x->ToString();
  504. }
  505. auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
  506. int result = 1;
  507. for (size_t i = 0; i < shpx_data.size(); i++) {
  508. int value = GetValue<int>(shpx_data[i]);
  509. IntMulWithOverflowCheck(result, value, &result);
  510. }
  511. auto result_v = MakeValue(result);
  512. MS_LOG(DEBUG) << "shape mul result:" << result_v->ToString();
  513. return std::make_shared<AbstractScalar>(result_v, result_v->type());
  514. }
  515. template <typename T>
  516. AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  517. // Inputs: two tuples or two lists.
  518. CheckArgsSize(op_name, args_spec_list, 2);
  519. auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
  520. auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
  521. ValuePtr x_value = input_x->BuildValue();
  522. ValuePtr y_value = input_y->BuildValue();
  523. return std::make_shared<AbstractScalar>(*x_value == *y_value);
  524. }
  525. AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  526. const AbstractBasePtrList &args_spec_list) {
  527. return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
  528. }
  529. AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  530. const AbstractBasePtrList &args_spec_list) {
  531. return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
  532. }
// Parameters of a range() call, filled in by CalcSlidePara.
// Field order (start, step, stop) matters: the struct is aggregate-initialized positionally.
struct SlideInfo {
  int start;  // first value produced
  int step;   // increment between produced values
  int stop;   // bound that is never produced (the range loops compare with < or >)
};
  538. void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
  539. int arg1 = 0;
  540. int arg2 = 0;
  541. if (!args_spec_list.empty()) {
  542. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  543. auto arg_value = args_spec_list[0]->BuildValue();
  544. if (!arg_value->isa<Int32Imm>()) {
  545. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  546. }
  547. arg1 = GetValue<int>(arg_value);
  548. }
  549. if (args_spec_list.size() >= 2) {
  550. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  551. auto arg_value = args_spec_list[1]->BuildValue();
  552. if (!arg_value->isa<Int32Imm>()) {
  553. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  554. }
  555. arg2 = GetValue<int>(arg_value);
  556. }
  557. if (args_spec_list.size() == 3) {
  558. MS_EXCEPTION_IF_NULL(args_spec_list[2]);
  559. auto arg_value = args_spec_list[2]->BuildValue();
  560. if (!arg_value->isa<Int32Imm>()) {
  561. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  562. }
  563. slide->step = GetValue<int>(arg_value);
  564. slide->start = arg1;
  565. slide->stop = arg2;
  566. }
  567. if (args_spec_list.size() == 2) {
  568. slide->start = arg1;
  569. slide->stop = arg2;
  570. }
  571. if (args_spec_list.size() == 1) {
  572. slide->stop = arg1;
  573. }
  574. }
  575. AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
  576. const AbstractBasePtrList &args_spec_list) {
  577. if (args_spec_list.empty()) {
  578. MS_LOG(EXCEPTION) << "Cannot make range from empty input.";
  579. }
  580. if (args_spec_list.size() > 3) {
  581. MS_LOG(EXCEPTION) << "Error args size of make range operational.";
  582. }
  583. SlideInfo slide = {0, 1, 0};
  584. CalcSlidePara(args_spec_list, &slide);
  585. if (slide.step == 0) {
  586. MS_LOG(EXCEPTION) << "Error, step value is 0.";
  587. }
  588. AbstractBasePtrList args;
  589. if (slide.start <= slide.stop) {
  590. if (slide.step <= 0) {
  591. MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
  592. }
  593. for (int i = slide.start; i < slide.stop; i += slide.step) {
  594. args.push_back(abstract::FromValue(i));
  595. }
  596. } else {
  597. if (slide.step >= 0) {
  598. MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
  599. }
  600. for (int i = slide.start; i > slide.stop; i += slide.step) {
  601. args.push_back(abstract::FromValue(i));
  602. }
  603. }
  604. return std::make_shared<AbstractTuple>(args);
  605. }
  606. AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  607. const AbstractBasePtrList &args_spec_list) {
  608. // Inputs: a tensor
  609. CheckArgsSize(primitive->name(), args_spec_list, 1);
  610. return args_spec_list[0]->Clone();
  611. }
  612. } // namespace abstract
  613. } // namespace mindspore