You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim_test.cc 45 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <memory>
  18. #include "pybind11/pybind11.h"
  19. #include "common/common_test.h"
  20. #include "common/py_func_graph_fetcher.h"
  21. #include "ir/manager.h"
  22. #include "pipeline/jit/static_analysis/prim.h"
  23. #include "pipeline/static_analysis/helper.h"
  24. #include "frontend/operator/ops.h"
  25. #include "debug/draw.h"
  26. #include "ir/tensor.h"
  27. #include "utils/symbolic.h"
  28. namespace mindspore {
  29. namespace abstract {
  30. namespace py = pybind11;
  31. namespace python_adapter = mindspore::parse::python_adapter;
  32. class UTPrimUtils {
  33. public:
  34. using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
  35. using AbstractTuplePtr = std::shared_ptr<AbstractTuple>;
  36. static const std::shared_ptr<Float> kF32;
  37. static const std::shared_ptr<Float> kF64;
  38. static const std::shared_ptr<Int> kI16;
  39. static const std::shared_ptr<Int> kI64;
  40. static const std::shared_ptr<UInt> kU64;
  41. static std::shared_ptr<AbstractType> TypeToAbstract(TypePtr t) { return std::make_shared<AbstractType>(t); }
  42. static AbstractTensorPtr ArrayFloat64Of(std::initializer_list<int> shp) {
  43. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
  44. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  45. }
  46. static AbstractTensorPtr ArrayFloat32Of(std::initializer_list<int> shp) {
  47. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kFloat32);
  48. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  49. }
  50. static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int> shp) {
  51. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  52. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  53. }
  54. static AbstractTuplePtr ShapeOf(std::initializer_list<int> vals) {
  55. AbstractBasePtrList te;
  56. for (auto v : vals) {
  57. te.push_back(std::make_shared<AbstractScalar>(v));
  58. }
  59. return std::make_shared<AbstractTuple>(te);
  60. }
  61. static AbstractListPtr ListShapeOf(std::initializer_list<int> vals) {
  62. AbstractBasePtrList te;
  63. for (auto v : vals) {
  64. te.push_back(std::make_shared<AbstractScalar>(v));
  65. }
  66. return std::make_shared<AbstractList>(te);
  67. }
  68. };
  69. const std::shared_ptr<Float> UTPrimUtils::kF64 = std::make_shared<Float>(64);
  70. const std::shared_ptr<Float> UTPrimUtils::kF32 = std::make_shared<Float>(32);
  71. const std::shared_ptr<Int> UTPrimUtils::kI16 = std::make_shared<Int>(16);
  72. const std::shared_ptr<Int> UTPrimUtils::kI64 = std::make_shared<Int>(64);
  73. const std::shared_ptr<UInt> UTPrimUtils::kU64 = std::make_shared<UInt>(64);
  namespace {
  // The helper below is disabled ("skip ut test cases temporarily"). It is only
  // referenced by the commented-out tests further down in this file.
  //
  // AbstractBasePtr ArrayOfTensor(const TypePtr &t, std::initializer_list<int> shp) {
  //   auto shape = std::vector<int>(shp);
  //   auto tensor = std::make_shared<tensor::Tensor>(t->type_id(), shape);
  //   return ToAbstract(tensor);
  // }
  }  // namespace
  83. class TestPrim : public UT::Common {
  84. public:
  85. TestPrim() : getPyFun("gtest_input.pipeline.infer", true) {}
  86. void SetUp();
  87. void TearDown();
  88. AnalysisEnginePtr engine_;
  89. UT::PyFuncGraphFetcher getPyFun;
  90. };
  91. void TestPrim::SetUp() { engine_ = SetupAnalysisEngine(); }
  92. void TestPrim::TearDown() {
  93. // destroy resource
  94. }
  95. static FuncGraphPtr MakeFuncGraph(const PrimitivePtr prim, unsigned int nparam) {
  96. // build the func_graph manually, eg:
  97. // MakeFuncGraph(std::make_shared<Primitive>("scalar_add"), 2) means:
  98. /* python source code:
  99. * @mindspore
  100. * def f(x, y):
  101. * return x + y
  102. */
  103. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  104. std::vector<AnfNodePtr> inputs;
  105. inputs.push_back(NewValueNode(prim));
  106. for (unsigned int i = 0; i < nparam; i++) {
  107. inputs.push_back(func_graph->add_parameter());
  108. }
  109. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  110. inputs.clear();
  111. inputs.push_back(NewValueNode(prim::kPrimReturn));
  112. inputs.push_back(cnode_prim);
  113. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  114. func_graph->set_return(cnode_return);
  115. return func_graph;
  116. }
  117. TEST_F(TestPrim, test_typeof) {
  118. AbstractBasePtrList args_spec_list;
  119. int v1 = 1;
  120. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  121. args_spec_list.push_back(abstract_v1);
  122. auto prim_typeof = std::make_shared<Primitive>("typeof");
  123. FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1);
  124. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  125. res->dump();
  126. TypePtr res_value = res->GetValueTrack()->cast<TypePtr>();
  127. res_value->dump();
  128. ASSERT_TRUE(*res_value == Int(32));
  129. }
  130. TEST_F(TestPrim, test_list_map) {
  131. AbstractBasePtrList args_spec_list;
  132. AbstractBasePtr abstract_v1 = FromValue(1, false);
  133. AbstractBasePtr abstract_u1 = FromValue(1, false);
  134. auto abstract_list1 = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_u1}));
  135. AbstractBasePtr abstract_v2 = FromValue(2, false);
  136. AbstractBasePtr abstract_u2 = FromValue(2, false);
  137. auto abstract_list2 = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v2, abstract_u2}));
  138. auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
  139. AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
  140. args_spec_list.push_back(abstract_func);
  141. args_spec_list.push_back(abstract_list1);
  142. args_spec_list.push_back(abstract_list2);
  143. auto prim_list_map = std::make_shared<Primitive>("list_map");
  144. FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3);
  145. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  146. auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({FromValue(3, false), FromValue(3, false)}));
  147. res->dump();
  148. MS_LOG(INFO) << "result res: " << res->ToString();
  149. MS_LOG(INFO) << "result expected: " << expected->ToString();
  150. ASSERT_TRUE(*res == *expected);
  151. }
  152. TEST_F(TestPrim, test_list_reduce) {
  153. AbstractBasePtrList args_spec_list;
  154. int v1 = 1;
  155. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  156. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  157. auto abstract_list = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_v2}));
  158. auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
  159. AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
  160. args_spec_list.push_back(abstract_func);
  161. args_spec_list.push_back(abstract_list);
  162. args_spec_list.push_back(abstract_v1);
  163. auto prim_list_reduce = std::make_shared<Primitive>("list_reduce");
  164. FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3);
  165. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  166. res->dump();
  167. TypePtr res_type = res->GetTypeTrack();
  168. res_type->dump();
  169. ASSERT_TRUE(*res_type == Int(32));
  170. }
  171. TEST_F(TestPrim, test_scalar_to_array) {
  172. AbstractBasePtrList args_spec_list;
  173. int v1 = 1;
  174. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  175. args_spec_list.push_back(abstract_v1);
  176. auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array");
  177. FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1);
  178. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  179. res->dump();
  180. TypePtr res_type = res->BuildType();
  181. res_type->dump();
  182. ASSERT_TRUE(*res_type == TensorType(std::make_shared<Int>(32)));
  183. }
  184. TEST_F(TestPrim, test_array_to_scalar) {
  185. AbstractBasePtrList args_spec_list;
  186. int v1 = 1;
  187. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  188. auto abstract_a1 = std::make_shared<AbstractTensor>(abstract_v1, std::make_shared<Shape>());
  189. args_spec_list.push_back(abstract_a1);
  190. auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar");
  191. FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1);
  192. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  193. res->dump();
  194. TypePtr res_type = res->BuildType();
  195. res_type->dump();
  196. ASSERT_TRUE(*res_type == Int(32));
  197. }
  198. TEST_F(TestPrim, test_J_1) {
  199. AbstractBasePtrList args_spec_list;
  200. int v1 = 1;
  201. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  202. args_spec_list.push_back(abstract_v1);
  203. auto prim_J = std::make_shared<Primitive>("J");
  204. FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1);
  205. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  206. AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res);
  207. ASSERT_TRUE(res_J != nullptr);
  208. ASSERT_TRUE(*(res_J->element()) == *abstract_v1);
  209. }
  210. TEST_F(TestPrim, test_J_2) {
  211. // def add(x):
  212. // return x + x
  213. // def f(x):
  214. // return J(add)(x)
  215. std::vector<AnfNodePtr> inputs;
  216. FuncGraphPtr func_graph1 = std::make_shared<FuncGraph>();
  217. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  218. auto x = func_graph1->add_parameter();
  219. inputs.push_back(x);
  220. inputs.push_back(x);
  221. CNodePtr cnode1 = func_graph1->NewCNode(inputs);
  222. func_graph1->set_return(cnode1);
  223. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  224. inputs.clear();
  225. auto x1 = func_graph->add_parameter();
  226. inputs.clear();
  227. inputs.push_back(NewValueNode(prim::kPrimJ));
  228. inputs.push_back(NewValueNode(func_graph1));
  229. CNodePtr jf = func_graph->NewCNode(inputs);
  230. inputs.clear();
  231. inputs.push_back(jf);
  232. inputs.push_back(x1);
  233. CNodePtr jf_jx = func_graph->NewCNode(inputs);
  234. inputs.clear();
  235. inputs.push_back(NewValueNode(prim::kPrimReturn));
  236. inputs.push_back(jf_jx);
  237. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  238. func_graph->set_return(cnode_return);
  239. draw::Draw("test_J_2.dot", func_graph);
  240. int v1 = 1;
  241. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  242. AbstractBasePtrList args_spec_list = {abstract_v1};
  243. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  244. res->dump();
  245. AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res);
  246. ASSERT_TRUE(res_J != nullptr);
  247. auto res_J_0 = res_J->elements()[0];
  248. ASSERT_TRUE(res_J_0 != nullptr);
  249. ASSERT_TRUE(*res_J_0 == *(FromValue(2, false)));
  250. AbstractFunctionPtr res_J_1 = dyn_cast<AbstractFunction>(res_J->elements()[1]);
  251. ASSERT_TRUE(res_J_1 != nullptr);
  252. }
  253. TEST_F(TestPrim, test_dot) {
  254. auto dot = std::make_shared<Primitive>("dot");
  255. FuncGraphPtr func_graph = MakeFuncGraph(dot, 2);
  256. auto a1 = UTPrimUtils::ArrayFloat64Of({2, 3});
  257. auto a2 = UTPrimUtils::ArrayFloat64Of({3, 4});
  258. std::vector<int> expectedA = {2, 4};
  259. auto expected = UTPrimUtils::ArrayFloat64Of({2, 4});
  260. AbstractBasePtrList args_spec_list = {a1, a2};
  261. AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
  262. ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack())));
  263. }
  264. // tail half
  265. TEST_F(TestPrim, test_switch1) {
  266. PrimitivePtr switch_ = std::make_shared<Primitive>("switch");
  267. FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3);
  268. AbstractBasePtr arg0 = FromValue(true, false);
  269. AbstractBasePtr arg1 = FromValue(1, false);
  270. AbstractBasePtr arg2 = FromValue(2, false);
  271. AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
  272. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  273. ASSERT_TRUE(*res == *arg1);
  274. }
  275. TEST_F(TestPrim, test_switch2) {
  276. PrimitivePtr switch_ = std::make_shared<Primitive>("switch");
  277. FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3);
  278. AbstractBasePtr arg0 = FromValue(false, false);
  279. AbstractBasePtr arg1 = FromValue(1, false);
  280. AbstractBasePtr arg2 = FromValue(2, false);
  281. AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
  282. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  283. MS_LOG(INFO) << "make result res: " << res->ToString();
  284. MS_LOG(INFO) << "make result arg2: " << arg2->ToString();
  285. ASSERT_TRUE(*res == *arg2);
  286. }
  287. TEST_F(TestPrim, test_identity) {
  288. PrimitivePtr identity = std::make_shared<Primitive>("identity");
  289. FuncGraphPtr func_graph = MakeFuncGraph(identity, 1);
  290. AbstractBasePtr abstract_v1 = FromValue(1, false);
  291. AbstractBasePtrList args_spec_list = {abstract_v1};
  292. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  293. ASSERT_TRUE(*res == *abstract_v1);
  294. }
  295. TEST_F(TestPrim, test_broadcast_shape) {
  296. PrimitivePtr broadcast_shape = std::make_shared<Primitive>("broadcast_shape");
  297. FuncGraphPtr func_graph = MakeFuncGraph(broadcast_shape, 2);
  298. auto a = UTPrimUtils::ShapeOf({Shape::SHP_ANY, Shape::SHP_ANY});
  299. auto b = UTPrimUtils::ShapeOf({Shape::SHP_ANY});
  300. std::vector<Any> expected{Shape::SHP_ANY, Shape::SHP_ANY};
  301. AbstractBasePtrList args_spec_list = {a, b};
  302. AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
  303. auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
  304. std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)};
  305. ASSERT_TRUE(ret.size() == element_list.size());
  306. for (int i = 0; i < element_list.size(); i++) {
  307. ASSERT_TRUE(*ret[i] == *element_list[i]);
  308. }
  309. }
  310. TEST_F(TestPrim, test_partial) {
  311. PrimitivePtr prim = prim::kPrimPartial;
  312. FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
  313. PrimitivePtr add = prim::kPrimScalarAdd;
  314. AbstractBasePtr abstract_add = ToAbstract(add);
  315. AbstractBasePtr abstract_v1 = FromValue(1, false);
  316. AbstractBasePtr abstract_v2 = FromValue(1, false);
  317. AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2};
  318. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  319. AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2};
  320. auto expected = std::make_shared<PartialAbstractClosure>(
  321. std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list);
  322. MS_LOG(INFO) << "result: " << res->ToString();
  323. MS_LOG(INFO) << "expected: " << expected->ToString();
  324. ASSERT_TRUE(res->ToString() == expected->ToString());
  325. }
  326. // def test_env(x, y):
  327. // return env_setitem(newenv, embed(x), y)
  328. TEST_F(TestPrim, test_env_setitem) {
  329. FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
  330. AbstractBasePtr abstract_x = FromValue(1, false);
  331. AbstractBasePtrList args_spec_list = {abstract_x};
  332. AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
  333. FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
  334. AbstractBasePtr abstract_env = ToAbstract(newenv);
  335. AbstractBasePtr abstract_y = FromValue(2, false);
  336. args_spec_list = {abstract_env, embed_x, abstract_y};
  337. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  338. AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  339. ASSERT_TRUE(*res == *exp);
  340. }
  341. // def test_env(x, y, z):
  342. // e = env_setitem(newenv, embed(x), y)
  343. // return env_getitem(e, embed(x), z)
  344. TEST_F(TestPrim, test_env_getitem) {
  345. FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
  346. AbstractBasePtr abstract_x = FromValue(1, false);
  347. AbstractBasePtrList args_spec_list = {abstract_x};
  348. AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
  349. FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
  350. AbstractBasePtr abstract_env = ToAbstract(newenv);
  351. AbstractBasePtr abstract_y = FromValue(2, false);
  352. args_spec_list = {abstract_env, embed_x, abstract_y};
  353. AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
  354. AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  355. ASSERT_TRUE(*res == *exp);
  356. FuncGraphPtr graph_getitem = MakeFuncGraph(prim::kPrimEnvGetItem, 3);
  357. AbstractBasePtr abstract_z = FromValue(3, false);
  358. args_spec_list = {res, embed_x, abstract_z};
  359. res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract();
  360. ASSERT_TRUE(*res == *abstract_x);
  361. }
  362. // def test_env(x, y, z):
  363. // e1 = env_setitem(newenv, embed(x), y)
  364. // e2 = env_setitem(newenv, embed(x), z)
  365. // return env_add(e1, e2)
  366. TEST_F(TestPrim, test_env_add) {
  367. FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
  368. AbstractBasePtr abstract_x = FromValue(1, false);
  369. AbstractBasePtrList args_spec_list = {abstract_x};
  370. AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
  371. FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
  372. AbstractBasePtr abstract_env = ToAbstract(newenv);
  373. AbstractBasePtr abstract_y = FromValue(2, false);
  374. args_spec_list = {abstract_env, embed_x, abstract_y};
  375. AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
  376. AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  377. ASSERT_TRUE(*abstract_e1 == *exp);
  378. AbstractBasePtr abstract_z = FromValue(3, false);
  379. args_spec_list = {abstract_env, embed_x, abstract_z};
  380. AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
  381. ASSERT_TRUE(*abstract_e2 == *exp);
  382. FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2);
  383. args_spec_list = {abstract_e1, abstract_e2};
  384. AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract();
  385. ASSERT_TRUE(*res == *exp);
  386. }
  387. TEST_F(TestPrim, test_relu) {
  388. PrimitivePtr relu = prim::kPrimRelu;
  389. relu->AddAttr("T", MakeValue(static_cast<int>(kNumberTypeFloat64)));
  390. FuncGraphPtr func_graph = MakeFuncGraph(relu, 1);
  391. AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW
  392. AbstractBasePtrList args_spec_list = {expected};
  393. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  394. ASSERT_TRUE(*res == *expected);
  395. }
  396. /*
  397. TEST_F(TestPrim, test_relu2) {
  398. FuncGraphPtr func_graph = getPyFun("get_relu");
  399. ASSERT_TRUE(func_graph != nullptr);
  400. draw::Draw("test_relu.dot", func_graph);
  401. auto arr = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
  402. auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
  403. AbstractBasePtrList args_spec_list = {arr};
  404. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  405. auto res = dyn_cast<AbstractTensor>(ret);
  406. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  407. }
  408. TEST_F(TestPrim, test_conv2d1) {
  409. std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
  410. py::tuple kernel_size(2);
  411. kernel_size[0] = 5;
  412. kernel_size[1] = 5;
  413. std::shared_ptr<FuncGraph> func_graph = getPyFun.CallAndParseRet("test_conv2d", 64, kernel_size, 0, 2, 1);
  414. // NCHW
  415. std::vector<int> inputs_dims = {2, 20, 32, 32};
  416. std::vector<int> weight_dims = {64, 20, 5, 5};
  417. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  418. inputs->set_data_type(kNumberTypeInt32);
  419. inputs->set_shape(inputs_dims);
  420. // Cout, Cin, kernel_size
  421. tensor::TensorPtr weight = std::make_shared<tensor::Tensor>();
  422. weight->set_data_type(kNumberTypeInt32);
  423. weight->set_shape(weight_dims);
  424. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  425. AbstractBasePtr abstract_weight = FromValue(weight, true);
  426. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_weight};
  427. AbstractBasePtr expected = abstract_inputs->Clone();
  428. // NCHW
  429. std::vector<int> shape = {2, 64, 14, 14};
  430. expected->set_shape(std::make_shared<Shape>(shape));
  431. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  432. MS_LOG(INFO) << "result: " << res->ToString();
  433. MS_LOG(INFO) << "expected: " << expected->ToString();
  434. auto res_ptr = dyn_cast<AbstractTensor>(res);
  435. auto expected_ptr = dyn_cast<AbstractTensor>(expected);
  436. ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
  437. ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
  438. }
  439. TEST_F(TestPrim, test_conv2d) {
  440. FuncGraphPtr func_graph = getPyFun("get_conv2d");
  441. ASSERT_TRUE(func_graph != nullptr);
  442. auto input = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
  443. auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3});
  444. AbstractBasePtrList args_spec_list = {input, weight};
  445. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  446. auto res = dyn_cast<AbstractTensor>(ret);
  447. auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16});
  448. MS_LOG(INFO) << "result: " << res->ToString();
  449. MS_LOG(INFO) << "expected: " << expected->ToString();
  450. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  451. }
  452. TEST_F(TestPrim, test_conv2d_native) {
  453. FuncGraphPtr func_graph = getPyFun("get_conv2d_native");
  454. ASSERT_TRUE(func_graph != nullptr);
  455. auto input = ArrayOfTensor(UTPrimUtils::kF64, {10, 32, 32, 32});
  456. auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3});
  457. AbstractBasePtrList args_spec_list = {input, weight};
  458. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  459. auto res = dyn_cast<AbstractTensor>(ret);
  460. auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16});
  461. MS_LOG(INFO) << "result: " << res->ToString();
  462. MS_LOG(INFO) << "expected: " << expected->ToString();
  463. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  464. }
  465. TEST_F(TestPrim, test_biasAdd) {
  466. FuncGraphPtr func_graph = getPyFun("get_bias_add");
  467. ASSERT_TRUE(func_graph != nullptr);
  468. auto value = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
  469. auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32});
  470. AbstractBasePtrList args_spec_list = {value, bias};
  471. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  472. auto res = dyn_cast<AbstractTensor>(ret);
  473. auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
  474. MS_LOG(INFO) << "result: " << res->ToString();
  475. MS_LOG(INFO) << "expected: " << expected->ToString();
  476. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  477. }
  478. TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) {
  479. FuncGraphPtr func_graph = getPyFun("get_softmax_cross_entropy_with_logits");
  480. ASSERT_TRUE(func_graph != nullptr);
  481. auto logits = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
  482. auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
  483. AbstractBasePtrList args_spec_list = {logits, labels};
  484. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  485. ASSERT_NE(ret, nullptr);
  486. auto res = dyn_cast<AbstractTuple>(ret);
  487. auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64});
  488. auto dLogits = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
  489. AbstractBasePtrList expected_list = {loss, dLogits};
  490. auto expected = std::make_shared<AbstractTuple>(expected_list);
  491. MS_LOG(INFO) << "result: " << res->ToString();
  492. MS_LOG(INFO) << "expected: " << expected->ToString();
  493. auto res_ptr0 = dyn_cast<AbstractTuple>(res);
  494. auto expected_ptr0 = dyn_cast<AbstractTuple>(expected);
  495. ASSERT_GT((*res_ptr0).size(), 1);
  496. auto res_ptr = dyn_cast<AbstractTensor>((*res_ptr0)[1]);
  497. ASSERT_GT((*expected_ptr0).size(), 1);
  498. auto expected_ptr = dyn_cast<AbstractTensor>((*expected_ptr0)[1]);
  499. ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
  500. ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
  501. }
  502. TEST_F(TestPrim, test_tensor_to_scalar_prim) {
  503. FuncGraphPtr func_graph = getPyFun("get_tensor_to_scalar");
  504. ASSERT_TRUE(func_graph != nullptr);
  505. draw::Draw("get_tensor_to_scalar.dot", func_graph);
  506. auto logits = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
  507. auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
  508. AbstractBasePtrList args_spec_list = {logits, labels};
  509. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  510. auto res = dyn_cast<AbstractScalar>(ret);
  511. AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
  512. expected->set_type(UTPrimUtils::kF64);
  513. MS_LOG(INFO) << "result: " << res->ToString();
  514. MS_LOG(INFO) << "expected: " << expected->ToString();
  515. ASSERT_TRUE(*res == *expected);
  516. }
  517. TEST_F(TestPrim, test_fused_batch_norm) {
  518. PrimitivePtr fused_batch_norm = prim::kPrimFusedBatchNorm;
  519. fused_batch_norm->AddAttr("epsilon", MakeValue(0.001f));
  520. fused_batch_norm->AddAttr("momentum", MakeValue(0.1f));
  521. FuncGraphPtr func_graph = MakeFuncGraph(fused_batch_norm, 5);
  522. // NCHW
  523. std::vector<int> inputs_dims = {128, 64, 32, 64};
  524. std::vector<int> scale_dims = {64};
  525. std::vector<int> offset_dims = {64};
  526. std::vector<int> mean_dims = {64};
  527. std::vector<int> variance_dims = {64};
  528. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  529. inputs->set_data_type(kNumberTypeFloat32);
  530. inputs->set_shape(inputs_dims);
  531. tensor::TensorPtr scale = std::make_shared<tensor::Tensor>();
  532. scale->set_data_type(kNumberTypeFloat32);
  533. scale->set_shape(scale_dims);
  534. tensor::TensorPtr offset = std::make_shared<tensor::Tensor>();
  535. offset->set_data_type(kNumberTypeFloat32);
  536. offset->set_shape(offset_dims);
  537. tensor::TensorPtr mean = std::make_shared<tensor::Tensor>();
  538. mean->set_data_type(kNumberTypeFloat32);
  539. mean->set_shape(mean_dims);
  540. tensor::TensorPtr variance = std::make_shared<tensor::Tensor>();
  541. variance->set_data_type(kNumberTypeFloat32);
  542. variance->set_shape(variance_dims);
  543. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  544. AbstractBasePtr abstract_scale = FromValue(scale, true);
  545. AbstractBasePtr abstract_offset = FromValue(offset, true);
  546. AbstractBasePtr abstract_mean = FromValue(mean, true);
  547. AbstractBasePtr abstract_variance = FromValue(variance, true);
  548. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_scale, abstract_offset, abstract_mean,
  549. abstract_variance};
  550. AbstractBasePtr expected0 = abstract_inputs->Clone();
  551. AbstractBasePtr expected1 = abstract_scale->Clone();
  552. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  553. MS_LOG(INFO) << "result: " << res->ToString();
  554. MS_LOG(INFO) << "expected0: " << expected0->ToString();
  555. MS_LOG(INFO) << "expected1: " << expected1->ToString();
  556. std::shared_ptr<AbstractTuple> abs_tuple = dyn_cast<AbstractTuple>(res);
  557. ASSERT_TRUE(abs_tuple != nullptr);
  558. ASSERT_TRUE(*abs_tuple->elements()[0] == *expected0);
  559. ASSERT_TRUE(*abs_tuple->elements()[1] == *expected1);
  560. ASSERT_TRUE(*abs_tuple->elements()[2] == *expected1);
  561. ASSERT_TRUE(*abs_tuple->elements()[3] == *expected1);
  562. ASSERT_TRUE(*abs_tuple->elements()[4] == *expected1);
  563. }
  564. TEST_F(TestPrim, test_pooling) {
  565. PrimitivePtr pooling = prim::kPrimPooling;
  566. pooling->AddAttr("mode", MakeValue(std::string("avg")));
  567. pooling->AddAttr("pad_mode", MakeValue(std::string("valid")));
  568. pooling->AddAttr("nan_opt", MakeValue(0));
  569. pooling->AddAttr("window", MakeValue(2));
  570. pooling->AddAttr("pad", MakeValue(1));
  571. pooling->AddAttr("stride", MakeValue(1));
  572. pooling->AddAttr("data_mode", MakeValue(1));
  573. pooling->AddAttr("ceil_mode", MakeValue(0));
  574. FuncGraphPtr func_graph = MakeFuncGraph(pooling, 1);
  575. std::vector<int> inputs_dims = {8, 64, 3, 3};
  576. auto inputs = std::make_shared<tensor::Tensor>();
  577. inputs->set_data_type(kNumberTypeFloat32);
  578. inputs->set_shape(inputs_dims);
  579. AbstractBasePtr abstract_input = FromValue(inputs, false);
  580. AbstractBasePtrList args_spec_list = {abstract_input};
  581. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  582. AbstractBasePtr expected = abstract_input->Clone()->Broaden();
  583. std::vector<int> expected_dims = {8, 64, 2, 2};
  584. expected->set_shape(std::make_shared<Shape>(expected_dims));
  585. MS_LOG(INFO) << "result: " << res->ToString();
  586. MS_LOG(INFO) << "expected: " << expected->ToString();
  587. ASSERT_TRUE(*res == *expected);
  588. }
  589. TEST_F(TestPrim, test_hastype) {
  590. AbstractBasePtrList args_spec_list;
  591. int v1 = 1;
  592. TypePtr v2 = std::make_shared<Number>();
  593. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  594. AbstractTypePtr abstract_v2 = UTPrimUtils::TypeToAbstract(v2);
  595. AbstractBasePtr expected = FromValue(true, false);
  596. args_spec_list.push_back(abstract_v1);
  597. args_spec_list.push_back(abstract_v2);
  598. auto prim = std::make_shared<Primitive>("hastype");
  599. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  600. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  601. ASSERT_TRUE(*res == *expected);
  602. }
  603. TEST_F(TestPrim, test_array_len) {
  604. AbstractBasePtrList args_spec_list;
  605. auto v1 = UTPrimUtils::ArrayFloat64Of({3, 4, 0, 2});
  606. auto expected = std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  607. args_spec_list.push_back(v1);
  608. auto prim = std::make_shared<Primitive>("array_len");
  609. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  610. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  611. ASSERT_TRUE(*res == *expected);
  612. }
  613. TEST_F(TestPrim, test_list_len) {
  614. AbstractBasePtrList args_spec_list;
  615. auto v1 = UTPrimUtils::ListShapeOf({3, 4, 0, 2});
  616. auto expected = std::make_shared<AbstractScalar>(4);
  617. args_spec_list.push_back(v1);
  618. auto prim = std::make_shared<Primitive>("list_len");
  619. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  620. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  621. ASSERT_TRUE(*res == *expected);
  622. }
  623. TEST_F(TestPrim, test_tuple_len) {
  624. AbstractBasePtrList args_spec_list;
  625. auto v1 = UTPrimUtils::ShapeOf({3, 4, 0, 2});
  626. auto expected = std::make_shared<AbstractScalar>(4);
  627. args_spec_list.push_back(v1);
  628. auto prim = std::make_shared<Primitive>("tuple_len");
  629. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  630. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  631. ASSERT_TRUE(*res == *expected);
  632. }
  633. TEST_F(TestPrim, test_tuple_reversed) {
  634. AbstractBasePtrList args_spec_list;
  635. auto v1 = UTPrimUtils::ShapeOf({0, 1, 2, 3});
  636. auto expected = UTPrimUtils::ShapeOf({3, 2, 1, 0});
  637. args_spec_list.push_back(v1);
  638. auto prim = std::make_shared<Primitive>("tuple_reversed");
  639. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  640. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  641. MS_LOG(INFO) << "expect=" << expected->ToString();
  642. ASSERT_TRUE(*res == *expected);
  643. }
  644. TEST_F(TestPrim, test_list_getitem) {
  645. AbstractBasePtrList args_spec_list;
  646. int v1 = 2;
  647. int v2 = 1;
  648. AbstractBasePtr elem = FromValue(v1, false);
  649. AbstractBasePtr elem2 = FromValue(v2, false);
  650. AbstractBasePtrList elems = {elem, elem};
  651. auto abstract_v1 = std::make_shared<AbstractList>(elems);
  652. AbstractBasePtr abstract_v2 = FromValue(v2, false);
  653. args_spec_list.push_back(abstract_v1);
  654. args_spec_list.push_back(abstract_v2);
  655. auto prim = std::make_shared<Primitive>("list_getitem");
  656. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  657. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  658. ASSERT_TRUE(*res == *elem);
  659. }
  660. TEST_F(TestPrim, test_list_setitem) {
  661. int v1 = 1;
  662. int v2 = 2;
  663. AbstractBasePtr elem1 = FromValue(v1, false);
  664. AbstractBasePtr elem2 = FromValue(v2, false);
  665. AbstractBasePtrList elems = {elem1, elem1};
  666. auto abstract_tuple = std::make_shared<AbstractList>(elems);
  667. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  668. AbstractBasePtr abstract_v3 = FromValue(v2, false);
  669. AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2, abstract_v3};
  670. auto prim = std::make_shared<Primitive>("list_setitem");
  671. FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
  672. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  673. MS_LOG(INFO) << "result: " << res->ToString();
  674. AbstractBasePtrList elems_exp = {elem1, elem2};
  675. auto expected = std::make_shared<AbstractList>(elems_exp);
  676. MS_LOG(INFO) << "expected: " << expected->ToString();
  677. auto res_list = dyn_cast<AbstractList>(res);
  678. ASSERT_TRUE(*expected == *res_list);
  679. }
  680. TEST_F(TestPrim, test_list_append) {
  681. int v1 = 1;
  682. AbstractBasePtr elem1 = FromValue(v1, false);
  683. AbstractBasePtr elem2 = FromValue(v1, false);
  684. auto abstract_tuple = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
  685. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  686. AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2};
  687. auto prim = std::make_shared<Primitive>("list_append");
  688. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  689. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  690. MS_LOG(INFO) << "result: " << res->ToString();
  691. auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
  692. MS_LOG(INFO) << "expected: " << expected->ToString();
  693. auto res_list = dyn_cast<AbstractList>(res);
  694. ASSERT_TRUE(*res_list == *expected);
  695. }
  696. TEST_F(TestPrim, test_tuple_setitem) {
  697. int v1 = 1;
  698. int v2 = 2;
  699. AbstractBasePtr elem1 = FromValue(v1, false);
  700. AbstractBasePtr elem2 = FromValue(v2, false);
  701. AbstractBasePtrList elems = {elem1, elem1};
  702. auto abstract_tuple = std::make_shared<AbstractTuple>(elems);
  703. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  704. AbstractBasePtr abstract_v3 = FromValue(v2, false);
  705. AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2, abstract_v3};
  706. auto prim = std::make_shared<Primitive>("tuple_setitem");
  707. FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
  708. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  709. MS_LOG(INFO) << "result: " << res->ToString();
  710. AbstractBasePtrList elems_exp = {elem1, elem2};
  711. auto expected = std::make_shared<AbstractTuple>(elems_exp);
  712. MS_LOG(INFO) << "expected: " << expected->ToString();
  713. auto res_tuple = dyn_cast<AbstractTuple>(res);
  714. ASSERT_TRUE(*res == *expected);
  715. }
  716. TEST_F(TestPrim, test_make_list) {
  717. AbstractBasePtrList args_spec_list;
  718. int v1 = 2;
  719. int v2 = 2;
  720. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  721. AbstractBasePtr abstract_v2 = FromValue(v2, false);
  722. auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_v2}));
  723. args_spec_list.push_back(abstract_v1);
  724. args_spec_list.push_back(abstract_v2);
  725. auto prim = std::make_shared<Primitive>("make_list");
  726. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  727. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  728. ASSERT_TRUE(*res == *expected);
  729. }
  730. TEST_F(TestPrim, test_make_range) {
  731. AbstractBasePtrList args_spec_list;
  732. int v1 = 1;
  733. int v2 = 4;
  734. AbstractBasePtr abstract_v1 = FromValue(v1);
  735. AbstractBasePtr abstract_v2 = FromValue(v2);
  736. args_spec_list.push_back(abstract_v1);
  737. args_spec_list.push_back(abstract_v2);
  738. auto prim = std::make_shared<Primitive>("make_range");
  739. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(prim, 2);
  740. AbstractBasePtr ele1 = FromValue(1);
  741. AbstractBasePtr ele2 = FromValue(2);
  742. AbstractBasePtr ele3 = FromValue(3);
  743. AbstractBasePtrList elem_list({ele1, ele2, ele3});
  744. AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list);
  745. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  746. MS_LOG(INFO) << "res=" << res->ToString();
  747. MS_LOG(INFO) << "expected=" << expected->ToString();
  748. ASSERT_TRUE(*res == *expected);
  749. }
  750. TEST_F(TestPrim, test_layernorm) {
  751. PrimitivePtr layerNorm = prim::kPrimLayerNorm;
  752. layerNorm->AddAttr("begin_norm_axis", MakeValue(1));
  753. layerNorm->AddAttr("begin_params_axis", MakeValue(1));
  754. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(layerNorm, 3);
  755. std::vector<int> inputs_dims = {128, 64, 32, 64};
  756. std::vector<int> mean_var_dims = {128, 64, 32, 1};
  757. std::vector<int> params_dims = {64, 32, 64};
  758. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  759. inputs->set_data_type(kNumberTypeFloat32);
  760. inputs->set_shape(inputs_dims);
  761. tensor::TensorPtr mean_var = std::make_shared<tensor::Tensor>();
  762. mean_var->set_data_type(kNumberTypeFloat32);
  763. mean_var->set_shape(mean_var_dims);
  764. tensor::TensorPtr gamma = std::make_shared<tensor::Tensor>();
  765. gamma->set_data_type(kNumberTypeFloat32);
  766. gamma->set_shape(params_dims);
  767. tensor::TensorPtr beta = std::make_shared<tensor::Tensor>();
  768. beta->set_data_type(kNumberTypeFloat32);
  769. beta->set_shape(params_dims);
  770. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  771. AbstractBasePtr abstract_mean_var = FromValue(mean_var, true);
  772. AbstractBasePtr abstract_gamma = FromValue(gamma, true);
  773. AbstractBasePtr abstract_beta = FromValue(beta, true);
  774. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_gamma, abstract_beta};
  775. AbstractBasePtr expected0 = abstract_inputs->Clone();
  776. AbstractBasePtr expected1 = abstract_mean_var->Clone();
  777. AbstractBasePtr expected2 = abstract_mean_var->Clone();
  778. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  779. MS_LOG(INFO) << "result: " << res->ToString();
  780. MS_LOG(INFO) << "expected0: " << expected0->ToString();
  781. MS_LOG(INFO) << "expected1: " << expected1->ToString();
  782. MS_LOG(INFO) << "expected2: " << expected2->ToString();
  783. std::shared_ptr<AbstractTuple> abs_tuple = dyn_cast<AbstractTuple>(res);
  784. ASSERT_TRUE(abs_tuple != nullptr);
  785. auto res_ptr0 = dyn_cast<AbstractTensor>(abs_tuple->elements()[0]);
  786. auto expected_ptr0 = dyn_cast<AbstractTensor>(expected0);
  787. ASSERT_TRUE(*res_ptr0->shape() == *expected_ptr0->shape());
  788. ASSERT_TRUE(*res_ptr0->element() == *expected_ptr0->element());
  789. auto res_ptr1 = dyn_cast<AbstractTensor>(abs_tuple->elements()[1]);
  790. auto expected_ptr1 = dyn_cast<AbstractTensor>(expected1);
  791. ASSERT_TRUE(*res_ptr1->shape() == *expected_ptr1->shape());
  792. ASSERT_TRUE(*res_ptr1->element() == *expected_ptr1->element());
  793. auto res_ptr2 = dyn_cast<AbstractTensor>(abs_tuple->elements()[2]);
  794. auto expected_ptr2 = dyn_cast<AbstractTensor>(expected2);
  795. ASSERT_TRUE(*res_ptr2->shape() == *expected_ptr2->shape());
  796. ASSERT_TRUE(*res_ptr2->element() == *expected_ptr2->element());
  797. }
  798. TEST_F(TestPrim, test_DropoutGenMask) {
  799. AbstractBasePtrList args_spec_list;
  800. auto arg0 = UTPrimUtils::ShapeOf({5, 5, 5, 5});
  801. std::vector<int> keep_prob_shape = {};
  802. tensor::TensorPtr keep_prob = std::make_shared<tensor::Tensor>(0.5f);
  803. keep_prob->set_data_type(kNumberTypeFloat32);
  804. keep_prob->set_shape(keep_prob_shape);
  805. AbstractBasePtr abstract_keep_prob = FromValue(keep_prob);
  806. auto prim = std::make_shared<Primitive>("DropoutGenMask");
  807. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(prim, 2);
  808. args_spec_list.push_back(arg0);
  809. args_spec_list.push_back(abstract_keep_prob);
  810. // should return a tensor with on dimension of 79 elements
  811. AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
  812. std::make_shared<Shape>(std::vector<int>{79}));
  813. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  814. MS_LOG(INFO) << "res=" << res->ToString();
  815. MS_LOG(INFO) << "expected=" << expected->ToString();
  816. ASSERT_TRUE(*res == *expected);
  817. }
  818. TEST_F(TestPrim, test_dropout) {
  819. std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
  820. std::shared_ptr<FuncGraph> func_graph = getPyFun.CallAndParseRet("test_dropout");
  821. std::vector<int> inputs_dims = {2, 20, 32, 32};
  822. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  823. inputs->set_data_type(kNumberTypeFloat32);
  824. inputs->set_shape(inputs_dims);
  825. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  826. std::vector<int> keep_prob_shape = {};
  827. tensor::TensorPtr keep_prob = std::make_shared<tensor::Tensor>(0.5f);
  828. keep_prob->set_data_type(kNumberTypeFloat32);
  829. keep_prob->set_shape(keep_prob_shape);
  830. AbstractBasePtr abstract_keep_prob = FromValue(keep_prob);
  831. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_keep_prob};
  832. AbstractBasePtr expected = abstract_inputs->Clone();
  833. // NCHW
  834. std::vector<int> shape = {2, 20, 32, 32};
  835. expected->set_shape(std::make_shared<Shape>(shape));
  836. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  837. MS_LOG(INFO) << "result: " << res->ToString();
  838. MS_LOG(INFO) << "expected: " << expected->ToString();
  839. auto res_ptr = dyn_cast<AbstractTensor>(res);
  840. auto expected_ptr = dyn_cast<AbstractTensor>(expected);
  841. ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
  842. ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
  843. }
  844. TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) {
  845. PrimitivePtr broadcatGradientArgs = prim::kPrimBroadcastGradientArgs;
  846. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(broadcatGradientArgs, 2);
  847. // broadcast shape: x: 8,5,3, y:3
  848. // output: ((),(0, 1))
  849. AbstractBasePtrList x_arg_list({abstract::FromValue(8), abstract::FromValue(5), abstract::FromValue(3)});
  850. AbstractBasePtrList y_arg_list({abstract::FromValue(3)});
  851. auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
  852. auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
  853. AbstractBasePtrList args_spec_list = {x_input, y_input};
  854. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  855. auto res = dyn_cast<AbstractTuple>(ret);
  856. AbstractBasePtrList x_idx_list;
  857. auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
  858. AbstractBasePtrList y_idx_list({abstract::FromValue(0), abstract::FromValue(1)});
  859. auto r_y = std::make_shared<AbstractTuple>(y_idx_list);
  860. AbstractBasePtrList elem_list({r_x, r_y});
  861. auto expected = std::make_shared<AbstractTuple>(elem_list);
  862. MS_LOG(INFO) << "result: " << res->ToString();
  863. MS_LOG(INFO) << "expected: " << expected->ToString();
  864. ASSERT_TRUE(*res == *expected);
  865. }
  866. TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) {
  867. PrimitivePtr broadcatGradientArgs = prim::kPrimBroadcastGradientArgs;
  868. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(broadcatGradientArgs, 2);
  869. // broadcast shape: x: 8,1,3, y:8 5 3
  870. // output: ((1),())
  871. AbstractBasePtrList x_arg_list({abstract::FromValue(8), abstract::FromValue(1), abstract::FromValue(3)});
  872. AbstractBasePtrList y_arg_list({abstract::FromValue(8), abstract::FromValue(5), abstract::FromValue(3)});
  873. auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
  874. auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
  875. AbstractBasePtrList args_spec_list = {x_input, y_input};
  876. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  877. auto res = dyn_cast<AbstractTuple>(ret);
  878. AbstractBasePtrList x_idx_list({abstract::FromValue(1)});
  879. auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
  880. AbstractBasePtrList y_idx_list;
  881. auto r_y = std::make_shared<AbstractTuple>(y_idx_list);
  882. AbstractBasePtrList elem_list({r_x, r_y});
  883. auto expected = std::make_shared<AbstractTuple>(elem_list);
  884. MS_LOG(INFO) << "result: " << res->ToString();
  885. MS_LOG(INFO) << "expected: " << expected->ToString();
  886. ASSERT_TRUE(*res == *expected);
  887. }
  888. TEST_F(TestPrim, test_DictGetItem) {
  889. PrimitivePtr dictGetItem = prim::kPrimDictGetItem;
  890. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(dictGetItem, 2);
  891. std::vector<std::pair<std::string, ValuePtr>> tensor_map = {
  892. {"x", std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int>{2, 3, 4})},
  893. {"y", std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int>{2, 1, 4})}};
  894. ValueDictionary value_dict(tensor_map);
  895. AbstractBasePtr array_dict = value_dict.ToAbstract();
  896. AbstractBasePtr key = abstract::FromValue("x");
  897. AbstractBasePtrList args_spec_list = {array_dict, key};
  898. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  899. AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
  900. AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second));
  901. ASSERT_TRUE(*tensor_ret == *expect);
  902. }
  903. TEST_F(TestPrim, test_DictGetItem2) {
  904. PrimitivePtr dictGetItem = prim::kPrimDictGetItem;
  905. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(dictGetItem, 2);
  906. AbstractBasePtr arr_x = ArrayOfTensor(UTPrimUtils::kF64, {3, 4, 5});
  907. AbstractBasePtr arr_y = ArrayOfTensor(UTPrimUtils::kF64, {1, 4, 5});
  908. AbstractBasePtr arr_z = ArrayOfTensor(UTPrimUtils::kF64, {3, 1, 5});
  909. std::vector<AbstractAttribute> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  910. AbstractDictionaryPtr array_dict = std::make_shared<AbstractDictionary>(array_map);
  911. AbstractBasePtr key = abstract::FromValue("x");
  912. AbstractBasePtrList args_spec_list = {array_dict, key};
  913. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  914. AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
  915. AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x);
  916. ASSERT_TRUE(*tensor_ret == *expect);
  917. }
  918. */
  919. } // namespace abstract
  920. } // namespace mindspore