You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

prim_structures.cc 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  #include "pipeline/jit/static_analysis/prim.h"

#include <cstdint>

#include "abstract/utils.h"
#include "abstract/param_validator.h"
#include "frontend/operator/ops.h"
#include "utils/convert_utils.h"
#include "ir/tensor_py.h"
using mindspore::tensor::TensorPy;
  25. namespace mindspore {
  26. namespace abstract {
  27. AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  28. const AbstractBasePtrList &args_spec_list) {
  29. // Inputs: two scalars whose value is a string.
  30. const std::string op_name = primitive->name();
  31. CheckArgsSize(op_name, args_spec_list, 2);
  32. AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  33. AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  34. ValuePtr value_x = scalar_x->BuildValue();
  35. ValuePtr value_y = scalar_y->BuildValue();
  36. if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
  37. MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
  38. << ", param1: " << value_y->ToString();
  39. }
  40. bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value());
  41. return std::make_shared<AbstractScalar>(ret);
  42. }
  43. AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  44. const AbstractBasePtrList &args_spec_list) {
  45. // Inputs: two scalars whose value is a string.
  46. const std::string op_name = primitive->name();
  47. CheckArgsSize(op_name, args_spec_list, 2);
  48. AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  49. AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  50. ValuePtr value_x = scalar_x->BuildValue();
  51. ValuePtr value_y = scalar_y->BuildValue();
  52. if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
  53. MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
  54. << ", param1: " << value_y->ToString();
  55. }
  56. std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
  57. return std::make_shared<AbstractScalar>(ret);
  58. }
  59. AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
  60. const AbstractBasePtrList &args_spec_list) {
  61. return std::make_shared<AbstractTuple>(args_spec_list);
  62. }
  63. AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
  64. const AbstractBasePtrList &args_spec_list) {
  65. return std::make_shared<AbstractList>(args_spec_list);
  66. }
  67. AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  68. const AbstractBasePtrList &args_spec_list) {
  69. // Inputs: two tuples.
  70. const std::string op_name = primitive->name();
  71. CheckArgsSize(op_name, args_spec_list, 2);
  72. AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  73. AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  74. size_t keys_size = keys->size();
  75. if (values->size() != keys_size) {
  76. MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
  77. }
  78. std::vector<AbstractAttribute> key_value;
  79. AbstractScalarPtr key;
  80. AbstractBasePtrList key_list = keys->elements();
  81. AbstractBasePtrList value_list = values->elements();
  82. for (size_t index = 0; index < keys_size; index++) {
  83. key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
  84. ValuePtr keyPtr = key->BuildValue();
  85. MS_EXCEPTION_IF_NULL(keyPtr);
  86. if (!keyPtr->isa<StringImm>()) {
  87. MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
  88. }
  89. std::string key_string = GetValue<std::string>(keyPtr);
  90. key_value.emplace_back(key_string, value_list[index]);
  91. }
  92. return std::make_shared<AbstractDictionary>(key_value);
  93. }
  94. AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  95. const AbstractBasePtrList &args_spec_list) {
  96. // Inputs: a string and an object of a subclass of AbstractBase.
  97. const std::string op_name = primitive->name();
  98. CheckArgsSize(op_name, args_spec_list, 2);
  99. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  100. ValuePtr keyPtr = key->BuildValue();
  101. if (!keyPtr->isa<StringImm>()) {
  102. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
  103. }
  104. std::string key_string = GetValue<std::string>(keyPtr);
  105. return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
  106. }
  107. AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  108. const AbstractBasePtrList &args_spec_list) {
  109. // Inputs: a string and a keyword.
  110. const std::string op_name = primitive->name();
  111. CheckArgsSize(op_name, args_spec_list, 2);
  112. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  113. AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
  114. ValuePtr key_value = key->BuildValue();
  115. if (!key_value->isa<StringImm>()) {
  116. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  117. }
  118. std::string key_input = GetValue<std::string>(key_value);
  119. std::string key_actual = kwarg->get_key();
  120. if (key_actual != key_input) {
  121. MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
  122. << key_input << ", AbstractKeywordArg' key is " << key_actual;
  123. }
  124. return kwarg->get_arg();
  125. }
  126. AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  127. const AbstractBasePtrList &args_spec_list) {
  128. // Inputs: three scalars whose value is an int32 number.
  129. CheckArgsSize(primitive->name(), args_spec_list, 3);
  130. size_t args_size = args_spec_list.size();
  131. for (size_t index = 0; index < args_size; index++) {
  132. MS_EXCEPTION_IF_NULL(args_spec_list[index]);
  133. if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
  134. MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
  135. }
  136. if (args_spec_list[index]->isa<AbstractScalar>() &&
  137. !dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
  138. MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number.";
  139. }
  140. }
  141. // Slice: start, end, step
  142. return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
  143. }
  144. // Eval the return type of make_record
  145. AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  146. const AbstractBasePtrList &args_spec_list) {
  147. // Inputs: at lease two objects of a subclass of AbstractBase.
  148. if (args_spec_list.size() < 2) {
  149. MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
  150. << args_spec_list.size() << ".";
  151. }
  152. // args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
  153. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  154. TypePtr type = args_spec_list[0]->GetTypeTrack();
  155. MS_EXCEPTION_IF_NULL(type);
  156. if (type->type_id() != kMetaTypeTypeType) {
  157. MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
  158. }
  159. ValuePtr value_track = args_spec_list[0]->GetValueTrack();
  160. MS_EXCEPTION_IF_NULL(value_track);
  161. TypePtr type_ptr = value_track->cast<TypePtr>();
  162. if (type_ptr == nullptr) {
  163. MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
  164. }
  165. auto cls = dyn_cast<Class>(type_ptr);
  166. MS_EXCEPTION_IF_NULL(cls);
  167. ClassAttrVector attributes = cls->GetAttributes();
  168. CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
  169. std::vector<AbstractAttribute> abs_attributes;
  170. for (size_t i = 0; i < attributes.size(); i++) {
  171. AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
  172. abs_attributes.push_back(elem);
  173. }
  174. return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
  175. }
  176. template <typename T>
  177. AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  178. // Inputs: a tuple or list and a scalar whose value is an int32 number.
  179. CheckArgsSize(op_name, args_spec_list, 2);
  180. auto queue = CheckArg<T>(op_name, args_spec_list, 0);
  181. AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  182. ValuePtr index_value = index->BuildValue();
  183. if (!index_value->isa<Int32Imm>()) {
  184. // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
  185. // and continue
  186. if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
  187. return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
  188. }
  189. MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
  190. << index_value->ToString();
  191. }
  192. int idx_v = GetValue<int>(index_value);
  193. std::size_t nelems = queue->elements().size();
  194. if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
  195. MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
  196. << SizeToInt(nelems) << "), but got " << idx_v << ".";
  197. }
  198. std::size_t uidx_v = 0;
  199. if (idx_v >= 0) {
  200. uidx_v = IntToSize(idx_v);
  201. } else {
  202. uidx_v = IntToSize(idx_v + SizeToInt(nelems));
  203. }
  204. return queue->elements()[uidx_v];
  205. }
  206. template <typename T>
  207. AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  208. // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
  209. CheckArgsSize(op_name, args_spec_list, 3);
  210. auto queue = CheckArg<T>(op_name, args_spec_list, 0);
  211. AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  212. ValuePtr index_value = index->BuildValue();
  213. if (!index_value->isa<Int32Imm>()) {
  214. MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
  215. << index_value->ToString();
  216. }
  217. int idx_v = GetValue<int>(index_value);
  218. if (idx_v < 0) {
  219. MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
  220. << ".";
  221. }
  222. size_t uidx_v = IntToSize(idx_v);
  223. AbstractBasePtrList elements = queue->elements();
  224. std::size_t nelems = elements.size();
  225. if (uidx_v >= nelems) {
  226. MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
  227. << ".";
  228. }
  229. elements[uidx_v] = args_spec_list[2];
  230. return std::make_shared<T>(elements);
  231. }
  232. AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  233. const AbstractBasePtrList &args_spec_list) {
  234. return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
  235. }
  236. AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  237. const AbstractBasePtrList &args_spec_list) {
  238. return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
  239. }
  240. AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  241. const AbstractBasePtrList &args_spec_list) {
  242. return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
  243. }
  244. AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  245. const AbstractBasePtrList &args_spec_list) {
  246. return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
  247. }
  248. AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  249. const AbstractBasePtrList &args_spec_list) {
  250. // Inputs: a dict and a scalar whose value is a string.
  251. const std::string op_name = primitive->name();
  252. CheckArgsSize(op_name, args_spec_list, 2);
  253. AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
  254. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  255. ValuePtr key_value = key->BuildValue();
  256. if (!key_value->isa<StringImm>()) {
  257. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  258. }
  259. auto key_str = GetValue<std::string>(key_value);
  260. std::vector<AbstractAttribute> dict_elems = dict->elements();
  261. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  262. [key_str](const AbstractAttribute &item) { return item.first == key_str; });
  263. if (it == dict_elems.end()) {
  264. MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
  265. }
  266. return it->second;
  267. }
  268. AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  269. const AbstractBasePtrList &args_spec_list) {
  270. // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
  271. const std::string op_name = primitive->name();
  272. CheckArgsSize(op_name, args_spec_list, 3);
  273. AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
  274. AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  275. ValuePtr key_value = key->BuildValue();
  276. if (!key_value->isa<StringImm>()) {
  277. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  278. }
  279. std::string key_str = GetValue<std::string>(key_value);
  280. std::vector<AbstractAttribute> dict_elems = dict->elements();
  281. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  282. [key_str](AbstractAttribute &item) { return item.first == key_str; });
  283. MS_EXCEPTION_IF_NULL(args_spec_list[2]);
  284. auto new_ele = std::make_pair(key_str, args_spec_list[2]);
  285. if (it != dict_elems.end()) {
  286. int index = it - dict_elems.begin();
  287. dict_elems[IntToSize(index)] = new_ele;
  288. } else {
  289. dict_elems.push_back(new_ele);
  290. }
  291. return std::make_shared<AbstractDictionary>(dict_elems);
  292. }
  293. AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  294. const AbstractBasePtrList &args_spec_list) {
  295. // Inputs: a list and an object of a subclass of AbstractBase.
  296. const std::string op_name = primitive->name();
  297. CheckArgsSize(op_name, args_spec_list, 2);
  298. AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
  299. (void)AbstractJoin(list->elements());
  300. return list;
  301. }
  302. template <typename T>
  303. AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  304. // Inputs: a tuple or list or dict.
  305. CheckArgsSize(op_name, args_spec_list, 1);
  306. auto arg = CheckArg<T>(op_name, args_spec_list, 0);
  307. return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
  308. }
  309. AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  310. const AbstractBasePtrList &args_spec_list) {
  311. return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
  312. }
  313. AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  314. const AbstractBasePtrList &args_spec_list) {
  315. return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
  316. }
  317. AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  318. const AbstractBasePtrList &args_spec_list) {
  319. return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
  320. }
  321. AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
  322. const AbstractBasePtrList &args_spec_list) {
  323. return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  324. }
  325. AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
  326. const AbstractBasePtrList &args_spec_list) {
  327. // Inputs: fn, list1, list2, ...
  328. MS_EXCEPTION_IF_NULL(engine);
  329. if (args_spec_list.size() <= 1) {
  330. MS_LOG(EXCEPTION) << "List_map requires at least 1 list. while the input size is " << args_spec_list.size() << ".";
  331. }
  332. AbstractFunctionPtr fn = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
  333. // check args from 1.
  334. CheckArgsSpec<AbstractList>(AbstractBasePtrList(args_spec_list.begin() + 1, args_spec_list.end()));
  335. AbstractBasePtrList subargs;
  336. for (std::size_t i = 1; i < args_spec_list.size(); i++) {
  337. AbstractListPtr l_ptr = dyn_cast<AbstractList>(args_spec_list[i]);
  338. if (l_ptr == nullptr) {
  339. MS_LOG(EXCEPTION) << "Argument[" << i << "] of list_map should be a list.";
  340. }
  341. subargs.push_back(AbstractJoin(l_ptr->elements()));
  342. }
  343. EvalResultPtr engin_exc = engine->Execute(fn, subargs);
  344. AbstractBasePtrList result;
  345. for (std::size_t i = 1; i < args_spec_list.size(); i++) {
  346. result.push_back(engin_exc->abstract());
  347. }
  348. return std::make_shared<AbstractList>(result);
  349. }
  350. AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
  351. const AbstractBasePtrList &args_spec_list) {
  352. // Inputs: a fn, a list and an object of a subclass of a AbstractBase.
  353. MS_EXCEPTION_IF_NULL(engine);
  354. const std::string op_name = primitive->name();
  355. CheckArgsSize(op_name, args_spec_list, 3);
  356. AbstractFunctionPtr fn = CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
  357. AbstractListPtr lst = CheckArg<AbstractList>(op_name, args_spec_list, 1);
  358. AbstractBasePtr dflt = args_spec_list[2];
  359. AbstractBasePtr list_type = AbstractJoin(lst->elements());
  360. auto result1 = engine->Execute(fn, lst->elements());
  361. auto result2 = engine->Execute(fn, {dflt, list_type});
  362. MS_EXCEPTION_IF_NULL(result1->abstract());
  363. MS_EXCEPTION_IF_NULL(result2->abstract());
  364. return result1->abstract()->Join(result2->abstract());
  365. }
  366. AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  367. const AbstractBasePtrList &args_spec_list) {
  368. // Inputs: a tuple
  369. const std::string op_name = primitive->name();
  370. CheckArgsSize(op_name, args_spec_list, 1);
  371. AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  372. auto tuple_elements = input->elements();
  373. AbstractBasePtrList elem_list;
  374. (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
  375. [](const AbstractBasePtr &elem) { return elem->Clone(); });
  376. return std::make_shared<AbstractTuple>(elem_list);
  377. }
  378. AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
  379. const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
  380. size_t x_rank = x_shape->size();
  381. std::set<int> axis_set;
  382. auto axis_data = axis_value_ptr->value();
  383. if (axis_data.empty()) {
  384. int size = 1;
  385. AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
  386. return std::make_shared<AbstractTuple>(values);
  387. }
  388. for (auto &elem : axis_data) {
  389. int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
  390. (void)axis_set.insert(e_value);
  391. }
  392. auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
  393. if (x_shp_data.size() < x_rank) {
  394. MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
  395. }
  396. AbstractBasePtrList values;
  397. for (size_t i = 0; i < x_rank; i++) {
  398. if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
  399. auto axis_v = MakeValue(1);
  400. values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
  401. } else {
  402. int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value();
  403. auto dim = MakeValue(dim_value);
  404. values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
  405. }
  406. }
  407. return std::make_shared<AbstractTuple>(values);
  408. }
  409. AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  410. const AbstractBasePtrList &args_spec_list) {
  411. // Inputs: x_shape, axis
  412. const std::string op_name = primitive->name();
  413. CheckArgsSize(op_name, args_spec_list, 2);
  414. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  415. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  416. auto x_shp_value = shape_x->BuildValue();
  417. if (x_shp_value->isa<AnyValue>()) {
  418. MS_LOG(EXCEPTION) << op_name
  419. << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
  420. }
  421. // Axis can be scalar, tuple or None
  422. AbstractTuplePtr axis = nullptr;
  423. if (args_spec_list[1]->isa<AbstractScalar>()) {
  424. MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar";
  425. AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])};
  426. axis = std::make_shared<AbstractTuple>(axis_list);
  427. } else if (args_spec_list[1]->isa<AbstractTuple>()) {
  428. MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple";
  429. axis = args_spec_list[1]->cast<AbstractTuplePtr>();
  430. } else {
  431. MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got "
  432. << args_spec_list[1]->ToString();
  433. }
  434. auto axis_value = axis->BuildValue();
  435. if (axis_value->isa<AnyValue>()) {
  436. MS_LOG(EXCEPTION) << op_name
  437. << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
  438. }
  439. auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
  440. MS_EXCEPTION_IF_NULL(axis_value_ptr);
  441. return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
  442. }
  443. AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  444. const AbstractBasePtrList &args_spec_list) {
  445. // Inputs: two tuples.
  446. const std::string op_name = primitive->name();
  447. CheckArgsSize(op_name, args_spec_list, 2);
  448. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  449. AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  450. MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString();
  451. auto div_shp_value = div_shp->BuildValue();
  452. if (div_shp_value->isa<AnyValue>()) {
  453. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[0]->ToString();
  454. }
  455. auto shpx_value = shape_x->BuildValue();
  456. if (shpx_value->isa<AnyValue>()) {
  457. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[1]->ToString();
  458. }
  459. if (div_shp->size() != shape_x->size()) {
  460. MS_LOG(EXCEPTION) << "tileshape elems shape must the same div_shp: " << div_shp->size()
  461. << ", shapex: " << shape_x->size() << ".";
  462. }
  463. auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
  464. auto div_shp_data = div_shp_value->cast<ValueTuplePtr>()->value();
  465. AbstractBasePtrList values;
  466. for (size_t i = 0; i < div_shp_data.size(); i++) {
  467. if (div_shp_data[i]->cast<Int32ImmPtr>() == nullptr) {
  468. MS_LOG(EXCEPTION) << "div_shp_shape data should be an int32 number, but it's " << args_spec_list[1]->ToString();
  469. }
  470. int shapex_value = GetValue<int>(shpx_data[i]);
  471. int div_value = GetValue<int>(div_shp_data[i]);
  472. MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
  473. if (div_value == 0) {
  474. MS_LOG(EXCEPTION) << "error: division value should not be 0!";
  475. }
  476. if ((shapex_value % div_value) != 0) {
  477. MS_LOG(EXCEPTION) << "div_shp_shape data shapex must div int:" << shapex_value << " div_value: " << div_value;
  478. }
  479. int result = shapex_value / div_value;
  480. auto result_v = MakeValue(result);
  481. values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
  482. }
  483. return std::make_shared<AbstractTuple>(values);
  484. }
  485. AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  486. const AbstractBasePtrList &args_spec_list) {
  487. // Inputs: a tuple
  488. const std::string op_name = primitive->name();
  489. CheckArgsSize(op_name, args_spec_list, 1);
  490. AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  491. py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
  492. py::array data = py::array(data_tuple);
  493. auto tensor = TensorPy::MakeTensor(data);
  494. auto ret = tensor->ToAbstract();
  495. ret->set_value(tensor);
  496. MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
  497. return ret;
  498. }
  499. AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  500. const AbstractBasePtrList &args_spec_list) {
  501. // Inputs: a tuple
  502. // example: tuple = (1, 2, 3), shape_mul(tuple) = 1*2*3 = 6
  503. const std::string op_name = primitive->name();
  504. CheckArgsSize(op_name, args_spec_list, 1);
  505. AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  506. auto shpx_value = shape_x->BuildValue();
  507. if (shpx_value->isa<AnyValue>()) {
  508. MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << shape_x->ToString();
  509. }
  510. auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
  511. int result = 1;
  512. for (size_t i = 0; i < shpx_data.size(); i++) {
  513. int value = GetValue<int>(shpx_data[i]);
  514. result = IntMulWithOverflowCheck(result, value);
  515. }
  516. auto result_v = MakeValue(result);
  517. MS_LOG(DEBUG) << "shape mul result:" << result_v->ToString();
  518. return std::make_shared<AbstractScalar>(result_v, result_v->type());
  519. }
  520. template <typename T>
  521. AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  522. // Inputs: two tuples or two lists.
  523. CheckArgsSize(op_name, args_spec_list, 2);
  524. auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
  525. auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
  526. ValuePtr x_value = input_x->BuildValue();
  527. ValuePtr y_value = input_y->BuildValue();
  528. return std::make_shared<AbstractScalar>(*x_value == *y_value);
  529. }
  530. AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  531. const AbstractBasePtrList &args_spec_list) {
  532. return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
  533. }
  534. AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  535. const AbstractBasePtrList &args_spec_list) {
  536. return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
  537. }
  // Range parameters decoded by CalcSlidePara and consumed by InferImplMakeRange.
  // The member order (start, step, stop) is the order used by aggregate initialization.
  struct SlideInfo {
    int start{0};  // first value produced
    int step{1};   // increment between values; must be non-zero
    int stop{0};   // exclusive bound
  };
  543. void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
  544. int arg1 = 0;
  545. int arg2 = 0;
  546. if (!args_spec_list.empty()) {
  547. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  548. auto arg_value = args_spec_list[0]->BuildValue();
  549. if (!arg_value->isa<Int32Imm>()) {
  550. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  551. }
  552. arg1 = GetValue<int>(arg_value);
  553. }
  554. if (args_spec_list.size() >= 2) {
  555. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  556. auto arg_value = args_spec_list[1]->BuildValue();
  557. if (!arg_value->isa<Int32Imm>()) {
  558. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  559. }
  560. arg2 = GetValue<int>(arg_value);
  561. }
  562. if (args_spec_list.size() == 3) {
  563. MS_EXCEPTION_IF_NULL(args_spec_list[2]);
  564. auto arg_value = args_spec_list[2]->BuildValue();
  565. if (!arg_value->isa<Int32Imm>()) {
  566. MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
  567. }
  568. slide->step = GetValue<int>(arg_value);
  569. slide->start = arg1;
  570. slide->stop = arg2;
  571. }
  572. if (args_spec_list.size() == 2) {
  573. slide->start = arg1;
  574. slide->stop = arg2;
  575. }
  576. if (args_spec_list.size() == 1) {
  577. slide->stop = arg1;
  578. }
  579. }
  580. AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
  581. const AbstractBasePtrList &args_spec_list) {
  582. if (args_spec_list.empty()) {
  583. MS_LOG(EXCEPTION) << "Cannot make range from empty input.";
  584. }
  585. if (args_spec_list.size() > 3) {
  586. MS_LOG(EXCEPTION) << "Error args size of make range operational.";
  587. }
  588. SlideInfo slide = {0, 1, 0};
  589. CalcSlidePara(args_spec_list, &slide);
  590. if (slide.step == 0) {
  591. MS_LOG(EXCEPTION) << "Error, step value is 0.";
  592. }
  593. AbstractBasePtrList args;
  594. if (slide.start <= slide.stop) {
  595. if (slide.step <= 0) {
  596. MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
  597. }
  598. for (int i = slide.start; i < slide.stop; i += slide.step) {
  599. args.push_back(abstract::FromValue(i));
  600. }
  601. } else {
  602. if (slide.step >= 0) {
  603. MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
  604. }
  605. for (int i = slide.start; i > slide.stop; i += slide.step) {
  606. args.push_back(abstract::FromValue(i));
  607. }
  608. }
  609. return std::make_shared<AbstractTuple>(args);
  610. }
  611. AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  612. const AbstractBasePtrList &args_spec_list) {
  613. // Inputs: a tensor
  614. CheckArgsSize(primitive->name(), args_spec_list, 1);
  615. return args_spec_list[0]->Clone();
  616. }
  617. } // namespace abstract
  618. } // namespace mindspore