You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim_structures.cc 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/static_analysis/prim.h"
  19. #include "pipeline/static_analysis/utils.h"
  20. #include "pipeline/static_analysis/param_validator.h"
  21. #include "operator/ops.h"
  22. #include "utils/convert_utils.h"
  23. namespace mindspore {
  24. namespace abstract {
  25. AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  26. const AbstractBasePtrList &args_spec_list) {
  27. // Inputs: two scalars whose value is a string.
  28. const std::string op_name = primitive->name();
  29. CheckArgsSize(op_name, args_spec_list, 2);
  30. AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  31. AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  32. ValuePtr value_x = scalar_x->BuildValue();
  33. ValuePtr value_y = scalar_y->BuildValue();
  34. if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
  35. MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
  36. << ", param1: " << value_y->ToString();
  37. }
  38. bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value());
  39. return std::make_shared<AbstractScalar>(ret);
  40. }
  41. AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  42. const AbstractBasePtrList &args_spec_list) {
  43. // Inputs: two scalars whose value is a string.
  44. const std::string op_name = primitive->name();
  45. CheckArgsSize(op_name, args_spec_list, 2);
  46. AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  47. AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  48. ValuePtr value_x = scalar_x->BuildValue();
  49. ValuePtr value_y = scalar_y->BuildValue();
  50. if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
  51. MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
  52. << ", param1: " << value_y->ToString();
  53. }
  54. std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
  55. return std::make_shared<AbstractScalar>(ret);
  56. }
  57. AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
  58. const AbstractBasePtrList &args_spec_list) {
  59. return std::make_shared<AbstractTuple>(args_spec_list);
  60. }
  61. AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
  62. const AbstractBasePtrList &args_spec_list) {
  63. return std::make_shared<AbstractList>(args_spec_list);
  64. }
  65. AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  66. const AbstractBasePtrList &args_spec_list) {
  67. // Inputs: two tuples.
  68. const std::string op_name = primitive->name();
  69. CheckArgsSize(op_name, args_spec_list, 2);
  70. AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  71. AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  72. size_t keys_size = keys->size();
  73. if (values->size() != keys_size) {
  74. MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
  75. }
  76. std::vector<AbstractAttribute> key_value;
  77. AbstractScalarPtr key;
  78. AbstractBasePtrList key_list = keys->elements();
  79. AbstractBasePtrList value_list = values->elements();
  80. for (size_t index = 0; index < keys_size; index++) {
  81. key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
  82. ValuePtr keyPtr = key->BuildValue();
  83. MS_EXCEPTION_IF_NULL(keyPtr);
  84. if (!keyPtr->isa<StringImm>()) {
  85. MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
  86. }
  87. std::string key_string = GetValue<std::string>(keyPtr);
  88. key_value.emplace_back(key_string, value_list[index]);
  89. }
  90. return std::make_shared<AbstractDictionary>(key_value);
  91. }
  92. AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  93. const AbstractBasePtrList &args_spec_list) {
  94. // Inputs: a string and an object of a subclass of AbstractBase.
  95. const std::string op_name = primitive->name();
  96. CheckArgsSize(op_name, args_spec_list, 2);
  97. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  98. ValuePtr keyPtr = key->BuildValue();
  99. if (!keyPtr->isa<StringImm>()) {
  100. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
  101. }
  102. std::string key_string = GetValue<std::string>(keyPtr);
  103. return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
  104. }
  105. AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  106. const AbstractBasePtrList &args_spec_list) {
  107. // Inputs: a string and a keyword.
  108. const std::string op_name = primitive->name();
  109. CheckArgsSize(op_name, args_spec_list, 2);
  110. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  111. AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
  112. ValuePtr key_value = key->BuildValue();
  113. if (!key_value->isa<StringImm>()) {
  114. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  115. }
  116. std::string key_input = GetValue<std::string>(key_value);
  117. std::string key_actual = kwarg->get_key();
  118. if (key_actual != key_input) {
  119. MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
  120. << key_input << ", AbstractKeywordArg' key is " << key_actual;
  121. }
  122. return kwarg->get_arg();
  123. }
  124. AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  125. const AbstractBasePtrList &args_spec_list) {
  126. // Inputs: three scalars whose value is an int32 number.
  127. CheckArgsSize(primitive->name(), args_spec_list, 3);
  128. size_t args_size = args_spec_list.size();
  129. for (size_t index = 0; index < args_size; index++) {
  130. MS_EXCEPTION_IF_NULL(args_spec_list[index]);
  131. if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
  132. MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
  133. }
  134. if (args_spec_list[index]->isa<AbstractScalar>() &&
  135. !dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
  136. MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number.";
  137. }
  138. }
  139. // Slice: start, end, step
  140. return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
  141. }
  142. // Eval the return type of make_record
  143. AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  144. const AbstractBasePtrList &args_spec_list) {
  145. // Inputs: at lease two objects of a subclass of AbstractBase.
  146. if (args_spec_list.size() < 2) {
  147. MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
  148. << args_spec_list.size() << ".";
  149. }
  150. // args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
  151. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  152. TypePtr type = args_spec_list[0]->GetTypeTrack();
  153. MS_EXCEPTION_IF_NULL(type);
  154. if (type->type_id() != kMetaTypeTypeType) {
  155. MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
  156. }
  157. ValuePtr value_track = args_spec_list[0]->GetValueTrack();
  158. MS_EXCEPTION_IF_NULL(value_track);
  159. TypePtr type_ptr = value_track->cast<TypePtr>();
  160. if (type_ptr == nullptr) {
  161. MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
  162. }
  163. auto cls = dyn_cast<Class>(type_ptr);
  164. MS_EXCEPTION_IF_NULL(cls);
  165. ClassAttrVector attributes = cls->GetAttributes();
  166. CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
  167. std::vector<AbstractAttribute> abs_attributes;
  168. for (size_t i = 0; i < attributes.size(); i++) {
  169. AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
  170. abs_attributes.push_back(elem);
  171. }
  172. return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
  173. }
  174. template <typename T>
  175. AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  176. // Inputs: a tuple or list and a scalar whose value is an int32 number.
  177. CheckArgsSize(op_name, args_spec_list, 2);
  178. auto queue = CheckArg<T>(op_name, args_spec_list, 0);
  179. AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  180. ValuePtr index_value = index->BuildValue();
  181. if (!index_value->isa<Int32Imm>()) {
  182. MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
  183. << index_value->ToString();
  184. }
  185. int idx_v = GetValue<int>(index_value);
  186. std::size_t nelems = queue->elements().size();
  187. if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
  188. MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
  189. << SizeToInt(nelems) << "), but got " << idx_v << ".";
  190. }
  191. std::size_t uidx_v = 0;
  192. if (idx_v >= 0) {
  193. uidx_v = IntToSize(idx_v);
  194. } else {
  195. uidx_v = IntToSize(idx_v + SizeToInt(nelems));
  196. }
  197. return queue->elements()[uidx_v];
  198. }
  199. template <typename T>
  200. AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  201. // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
  202. CheckArgsSize(op_name, args_spec_list, 3);
  203. auto queue = CheckArg<T>(op_name, args_spec_list, 0);
  204. AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  205. ValuePtr index_value = index->BuildValue();
  206. if (!index_value->isa<Int32Imm>()) {
  207. MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
  208. << index_value->ToString();
  209. }
  210. int idx_v = GetValue<int>(index_value);
  211. if (idx_v < 0) {
  212. MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
  213. << ".";
  214. }
  215. size_t uidx_v = IntToSize(idx_v);
  216. AbstractBasePtrList elements = queue->elements();
  217. std::size_t nelems = elements.size();
  218. if (uidx_v >= nelems) {
  219. MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
  220. << ".";
  221. }
  222. elements[uidx_v] = args_spec_list[2];
  223. return std::make_shared<T>(elements);
  224. }
  225. AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  226. const AbstractBasePtrList &args_spec_list) {
  227. return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
  228. }
  229. AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  230. const AbstractBasePtrList &args_spec_list) {
  231. return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
  232. }
  233. AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  234. const AbstractBasePtrList &args_spec_list) {
  235. return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
  236. }
  237. AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  238. const AbstractBasePtrList &args_spec_list) {
  239. return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
  240. }
  241. AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  242. const AbstractBasePtrList &args_spec_list) {
  243. // Inputs: a dict and a scalar whose value is a string.
  244. const std::string op_name = primitive->name();
  245. CheckArgsSize(op_name, args_spec_list, 2);
  246. AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
  247. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  248. ValuePtr key_value = key->BuildValue();
  249. if (!key_value->isa<StringImm>()) {
  250. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  251. }
  252. auto key_str = GetValue<std::string>(key_value);
  253. std::vector<AbstractAttribute> dict_elems = dict->elements();
  254. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  255. [key_str](const AbstractAttribute &item) { return item.first == key_str; });
  256. if (it == dict_elems.end()) {
  257. MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
  258. }
  259. return it->second;
  260. }
  261. AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  262. const AbstractBasePtrList &args_spec_list) {
  263. // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
  264. const std::string op_name = primitive->name();
  265. CheckArgsSize(op_name, args_spec_list, 3);
  266. AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
  267. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  268. ValuePtr key_value = key->BuildValue();
  269. if (!key_value->isa<StringImm>()) {
  270. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  271. }
  272. std::string key_str = GetValue<std::string>(key_value);
  273. std::vector<AbstractAttribute> dict_elems = dict->elements();
  274. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  275. [key_str](AbstractAttribute &item) { return item.first == key_str; });
  276. MS_EXCEPTION_IF_NULL(args_spec_list[2]);
  277. auto new_ele = std::make_pair(key_str, args_spec_list[2]);
  278. if (it != dict_elems.end()) {
  279. int index = it - dict_elems.begin();
  280. dict_elems[IntToSize(index)] = new_ele;
  281. } else {
  282. dict_elems.push_back(new_ele);
  283. }
  284. return std::make_shared<AbstractDictionary>(dict_elems);
  285. }
  286. AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  287. const AbstractBasePtrList &args_spec_list) {
  288. // Inputs: a list and an object of a subclass of AbstractBase.
  289. const std::string op_name = primitive->name();
  290. CheckArgsSize(op_name, args_spec_list, 2);
  291. AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
  292. (void)AbstractJoin(list->elements());
  293. return list;
  294. }
  295. template <typename T>
  296. AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  297. // Inputs: a tuple or list or dict.
  298. CheckArgsSize(op_name, args_spec_list, 1);
  299. auto arg = CheckArg<T>(op_name, args_spec_list, 0);
  300. return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
  301. }
  302. AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  303. const AbstractBasePtrList &args_spec_list) {
  304. return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
  305. }
  306. AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  307. const AbstractBasePtrList &args_spec_list) {
  308. return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
  309. }
  310. AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  311. const AbstractBasePtrList &args_spec_list) {
  312. return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
  313. }
  314. AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
  315. const AbstractBasePtrList &args_spec_list) {
  316. return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  317. }
  318. AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
  319. const AbstractBasePtrList &args_spec_list) {
  320. // Inputs: fn, list1, list2, ...
  321. MS_EXCEPTION_IF_NULL(engine);
  322. if (args_spec_list.size() <= 1) {
  323. MS_LOG(EXCEPTION) << "List_map requires at least 1 list. while the input size is " << args_spec_list.size() << ".";
  324. }
  325. AbstractFunctionPtr fn = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
  326. // check args from 1.
  327. CheckArgsSpec<AbstractList>(AbstractBasePtrList(args_spec_list.begin() + 1, args_spec_list.end()));
  328. AbstractBasePtrList subargs;
  329. for (std::size_t i = 1; i < args_spec_list.size(); i++) {
  330. AbstractListPtr l_ptr = dyn_cast<AbstractList>(args_spec_list[i]);
  331. if (l_ptr == nullptr) {
  332. MS_LOG(EXCEPTION) << "Argument[" << i << "] of list_map should be a list.";
  333. }
  334. subargs.push_back(AbstractJoin(l_ptr->elements()));
  335. }
  336. EvalResultPtr engin_exc = engine->Execute(fn, subargs);
  337. AbstractBasePtrList result;
  338. for (std::size_t i = 1; i < args_spec_list.size(); i++) {
  339. result.push_back(engin_exc->abstract());
  340. }
  341. return std::make_shared<AbstractList>(result);
  342. }
  343. AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
  344. const AbstractBasePtrList &args_spec_list) {
  345. // Inputs: a fn, a list and an object of a subclass of a AbstractBase.
  346. MS_EXCEPTION_IF_NULL(engine);
  347. const std::string op_name = primitive->name();
  348. CheckArgsSize(op_name, args_spec_list, 3);
  349. AbstractFunctionPtr fn = CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
  350. AbstractListPtr lst = CheckArg<AbstractList>(op_name, args_spec_list, 1);
  351. AbstractBasePtr dflt = args_spec_list[2];
  352. AbstractBasePtr list_type = AbstractJoin(lst->elements());
  353. auto result1 = engine->Execute(fn, lst->elements());
  354. auto result2 = engine->Execute(fn, {dflt, list_type});
  355. MS_EXCEPTION_IF_NULL(result1->abstract());
  356. MS_EXCEPTION_IF_NULL(result2->abstract());
  357. return result1->abstract()->Join(result2->abstract());
  358. }
  359. AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  360. const AbstractBasePtrList &args_spec_list) {
  361. // Inputs: a tuple
  362. const std::string op_name = primitive->name();
  363. CheckArgsSize(op_name, args_spec_list, 1);
  364. AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  365. auto tuple_elements = input->elements();
  366. AbstractBasePtrList elem_list;
  367. (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
  368. [](const AbstractBasePtr &elem) { return elem->Clone(); });
  369. return std::make_shared<AbstractTuple>(elem_list);
  370. }
  371. AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
  372. const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
  373. size_t x_rank = x_shape->size();
  374. std::set<int> axis_set;
  375. auto axis_data = axis_value_ptr->value();
  376. if (axis_data.empty()) {
  377. int size = 1;
  378. AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
  379. return std::make_shared<AbstractTuple>(values);
  380. }
  381. for (auto &elem : axis_data) {
  382. int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
  383. (void)axis_set.insert(e_value);
  384. }
  385. auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
  386. if (x_shp_data.size() < x_rank) {
  387. MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
  388. }
  389. AbstractBasePtrList values;
  390. for (size_t i = 0; i < x_rank; i++) {
  391. if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
  392. auto axis_v = MakeValue(1);
  393. values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
  394. } else {
  395. int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value();
  396. auto dim = MakeValue(dim_value);
  397. values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
  398. }
  399. }
  400. return std::make_shared<AbstractTuple>(values);
  401. }
  402. AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  403. const AbstractBasePtrList &args_spec_list) {
  404. // Inputs: x_shape, axis
  405. const std::string op_name = primitive->name();
  406. CheckArgsSize(op_name, args_spec_list, 2);
  407. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  408. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  409. auto x_shp_value = shape_x->BuildValue();
  410. if (x_shp_value->isa<AnyValue>()) {
  411. MS_LOG(EXCEPTION) << op_name
  412. << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
  413. }
  414. // Axis can be scalar, tuple or None
  415. AbstractTuplePtr axis = nullptr;
  416. if (args_spec_list[1]->isa<AbstractScalar>()) {
  417. MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar";
  418. AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])};
  419. axis = std::make_shared<AbstractTuple>(axis_list);
  420. } else if (args_spec_list[1]->isa<AbstractTuple>()) {
  421. MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple";
  422. axis = args_spec_list[1]->cast<AbstractTuplePtr>();
  423. } else {
  424. MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got "
  425. << args_spec_list[1]->ToString();
  426. }
  427. auto axis_value = axis->BuildValue();
  428. if (axis_value->isa<AnyValue>()) {
  429. MS_LOG(EXCEPTION) << op_name
  430. << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
  431. }
  432. auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
  433. MS_EXCEPTION_IF_NULL(axis_value_ptr);
  434. return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
  435. }
  436. AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  437. const AbstractBasePtrList &args_spec_list) {
  438. // Inputs: two tuples.
  439. const std::string op_name = primitive->name();
  440. CheckArgsSize(op_name, args_spec_list, 2);
  441. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  442. AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  443. MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString();
  444. auto div_shp_value = div_shp->BuildValue();
  445. if (div_shp_value->isa<AnyValue>()) {
  446. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[0]->ToString();
  447. }
  448. auto shpx_value = shape_x->BuildValue();
  449. if (shpx_value->isa<AnyValue>()) {
  450. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[1]->ToString();
  451. }
  452. if (div_shp->size() != shape_x->size()) {
  453. MS_LOG(EXCEPTION) << "tileshape elems shape must the same div_shp: " << div_shp->size()
  454. << ", shapex: " << shape_x->size() << ".";
  455. }
  456. auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
  457. auto div_shp_data = div_shp_value->cast<ValueTuplePtr>()->value();
  458. AbstractBasePtrList values;
  459. for (size_t i = 0; i < div_shp_data.size(); i++) {
  460. if (div_shp_data[i]->cast<Int32ImmPtr>() == nullptr) {
  461. MS_LOG(EXCEPTION) << "div_shp_shape data should be an int32 number, but it's " << args_spec_list[1]->ToString();
  462. }
  463. int shapex_value = GetValue<int>(shpx_data[i]);
  464. int div_value = GetValue<int>(div_shp_data[i]);
  465. MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
  466. if (div_value == 0) {
  467. MS_LOG(EXCEPTION) << "error: division value should not be 0!";
  468. }
  469. if ((shapex_value % div_value) != 0) {
  470. MS_LOG(EXCEPTION) << "div_shp_shape data shapex must div int:" << shapex_value << " div_value: " << div_value;
  471. }
  472. int result = shapex_value / div_value;
  473. auto result_v = MakeValue(result);
  474. values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
  475. }
  476. return std::make_shared<AbstractTuple>(values);
  477. }
  478. AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  479. const AbstractBasePtrList &args_spec_list) {
  480. // Inputs: a tuple
  481. const std::string op_name = primitive->name();
  482. CheckArgsSize(op_name, args_spec_list, 1);
  483. AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  484. py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
  485. py::array data = py::array(data_tuple);
  486. auto tensor = std::make_shared<tensor::Tensor>(data);
  487. auto ret = tensor->ToAbstract();
  488. ret->set_value(tensor);
  489. MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
  490. return ret;
  491. }
  492. AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  493. const AbstractBasePtrList &args_spec_list) {
  494. // Inputs: a tuple
  495. // example: tuple = (1, 2, 3), shape_mul(tuple) = 1*2*3 = 6
  496. const std::string op_name = primitive->name();
  497. CheckArgsSize(op_name, args_spec_list, 1);
  498. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  499. auto shpx_value = shape_x->BuildValue();
  500. if (shpx_value->isa<AnyValue>()) {
  501. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << shape_x->ToString();
  502. }
  503. auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
  504. int result = 1;
  505. for (size_t i = 0; i < shpx_data.size(); i++) {
  506. int value = GetValue<int>(shpx_data[i]);
  507. IntMulWithOverflowCheck(result, value, &result);
  508. }
  509. auto result_v = MakeValue(result);
  510. MS_LOG(DEBUG) << "shape mul result:" << result_v->ToString();
  511. return std::make_shared<AbstractScalar>(result_v, result_v->type());
  512. }
  513. template <typename T>
  514. AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  515. // Inputs: two tuples or two lists.
  516. CheckArgsSize(op_name, args_spec_list, 2);
  517. auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
  518. auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
  519. ValuePtr x_value = input_x->BuildValue();
  520. ValuePtr y_value = input_y->BuildValue();
  521. return std::make_shared<AbstractScalar>(*x_value == *y_value);
  522. }
  523. AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  524. const AbstractBasePtrList &args_spec_list) {
  525. return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
  526. }
  527. AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  528. const AbstractBasePtrList &args_spec_list) {
  529. return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
  530. }
  // Parameters of a make_range call, filled in by CalcSlidePara.
  // CalcSlidePara assigns only the fields its argument count provides, so each member has a default.
  // The defaults match the {0, 1, 0} initialiser used by InferImplMakeRange. Member order
  // (start, step, stop) is relied on by that brace initialiser and must not change.
  struct SlideInfo {
    int start = 0;  // first value of the range
    int step = 1;   // increment between consecutive values
    int stop = 0;   // exclusive end of the range
  };
  536. void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
  537. int arg1 = 0;
  538. int arg2 = 0;
  539. if (!args_spec_list.empty()) {
  540. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  541. auto arg_value = args_spec_list[0]->BuildValue();
  542. if (!arg_value->isa<Int32Imm>()) {
  543. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  544. }
  545. arg1 = GetValue<int>(arg_value);
  546. }
  547. if (args_spec_list.size() >= 2) {
  548. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  549. auto arg_value = args_spec_list[1]->BuildValue();
  550. if (!arg_value->isa<Int32Imm>()) {
  551. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  552. }
  553. arg2 = GetValue<int>(arg_value);
  554. }
  555. if (args_spec_list.size() == 3) {
  556. MS_EXCEPTION_IF_NULL(args_spec_list[2]);
  557. auto arg_value = args_spec_list[2]->BuildValue();
  558. if (!arg_value->isa<Int32Imm>()) {
  559. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  560. }
  561. slide->step = GetValue<int>(arg_value);
  562. slide->start = arg1;
  563. slide->stop = arg2;
  564. }
  565. if (args_spec_list.size() == 2) {
  566. slide->start = arg1;
  567. slide->stop = arg2;
  568. }
  569. if (args_spec_list.size() == 1) {
  570. slide->stop = arg1;
  571. }
  572. }
  573. AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
  574. const AbstractBasePtrList &args_spec_list) {
  575. if (args_spec_list.empty()) {
  576. MS_LOG(EXCEPTION) << "Cannot make range from empty input.";
  577. }
  578. if (args_spec_list.size() > 3) {
  579. MS_LOG(EXCEPTION) << "Error args size of make range operational.";
  580. }
  581. SlideInfo slide = {0, 1, 0};
  582. CalcSlidePara(args_spec_list, &slide);
  583. if (slide.step == 0) {
  584. MS_LOG(EXCEPTION) << "Error, step value is 0.";
  585. }
  586. AbstractBasePtrList args;
  587. if (slide.start <= slide.stop) {
  588. if (slide.step <= 0) {
  589. MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
  590. }
  591. for (int i = slide.start; i < slide.stop; i += slide.step) {
  592. args.push_back(abstract::FromValue(i));
  593. }
  594. } else {
  595. if (slide.step >= 0) {
  596. MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
  597. }
  598. for (int i = slide.start; i > slide.stop; i += slide.step) {
  599. args.push_back(abstract::FromValue(i));
  600. }
  601. }
  602. return std::make_shared<AbstractTuple>(args);
  603. }
  604. AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  605. const AbstractBasePtrList &args_spec_list) {
  606. // Inputs: a tensor
  607. CheckArgsSize(primitive->name(), args_spec_list, 1);
  608. return args_spec_list[0]->Clone();
  609. }
  610. } // namespace abstract
  611. } // namespace mindspore