You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

prim_test.cc 46 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <memory>
  18. #include "pybind11/pybind11.h"
  19. #include "common/common_test.h"
  20. #include "common/py_func_graph_fetcher.h"
  21. #include "ir/manager.h"
  22. #include "pipeline/jit/static_analysis/prim.h"
  23. #include "pipeline/static_analysis/helper.h"
  24. #include "frontend/operator/ops.h"
  25. #include "debug/draw.h"
  26. #include "ir/tensor.h"
  27. #include "utils/symbolic.h"
  28. namespace mindspore {
  29. namespace abstract {
  30. namespace py = pybind11;
  31. namespace python_adapter = mindspore::parse::python_adapter;
  32. class UTPrimUtils {
  33. public:
  34. using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
  35. using AbstractTuplePtr = std::shared_ptr<AbstractTuple>;
  36. static const std::shared_ptr<Float> kF32;
  37. static const std::shared_ptr<Float> kF64;
  38. static const std::shared_ptr<Int> kI16;
  39. static const std::shared_ptr<Int> kI64;
  40. static const std::shared_ptr<UInt> kU64;
  41. static std::shared_ptr<AbstractType> TypeToAbstract(TypePtr t) { return std::make_shared<AbstractType>(t); }
  42. static AbstractTensorPtr ArrayFloat64Of(std::initializer_list<int64_t> shp) {
  43. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
  44. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  45. }
  46. static AbstractTensorPtr ArrayFloat32Of(std::initializer_list<int64_t> shp) {
  47. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kFloat32);
  48. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  49. }
  50. static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int64_t> shp) {
  51. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt64);
  52. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  53. }
  54. static AbstractTuplePtr ShapeOf(std::initializer_list<int64_t> vals) {
  55. AbstractBasePtrList te;
  56. for (auto v : vals) {
  57. te.push_back(std::make_shared<AbstractScalar>(v));
  58. }
  59. return std::make_shared<AbstractTuple>(te);
  60. }
  61. static AbstractListPtr ListShapeOf(std::initializer_list<int64_t> vals) {
  62. AbstractBasePtrList te;
  63. for (auto v : vals) {
  64. te.push_back(std::make_shared<AbstractScalar>(v));
  65. }
  66. return std::make_shared<AbstractList>(te);
  67. }
  68. };
  69. const std::shared_ptr<Float> UTPrimUtils::kF64 = std::make_shared<Float>(64);
  70. const std::shared_ptr<Float> UTPrimUtils::kF32 = std::make_shared<Float>(32);
  71. const std::shared_ptr<Int> UTPrimUtils::kI16 = std::make_shared<Int>(16);
  72. const std::shared_ptr<Int> UTPrimUtils::kI64 = std::make_shared<Int>(64);
  73. const std::shared_ptr<UInt> UTPrimUtils::kU64 = std::make_shared<UInt>(64);
  74. namespace {
  75. /* skip ut test cases temporarily
  76. AbstractBasePtr ArrayOfTensor(const TypePtr &t, std::initializer_list<int64_t> shp) {
  77. auto shape = std::vector<int64_t>(shp);
  78. auto tensor = std::make_shared<tensor::Tensor>(t->type_id(), shape);
  79. return ToAbstract(tensor);
  80. }
  81. */
  82. } // namespace
  83. class TestPrim : public UT::Common {
  84. public:
  85. TestPrim() : getPyFun("gtest_input.pipeline.infer", true) {}
  86. void SetUp();
  87. void TearDown();
  88. AnalysisEnginePtr engine_;
  89. UT::PyFuncGraphFetcher getPyFun;
  90. };
  91. void TestPrim::SetUp() { engine_ = SetupAnalysisEngine(); }
  92. void TestPrim::TearDown() {
  93. // destroy resource
  94. }
  95. static FuncGraphPtr MakeFuncGraph(const PrimitivePtr prim, uint64_t nparam) {
  96. // build the func_graph manually, eg:
  97. // MakeFuncGraph(std::make_shared<Primitive>("scalar_add"), 2) means:
  98. /* python source code:
  99. * @mindspore
  100. * def f(x, y):
  101. * return x + y
  102. */
  103. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  104. std::vector<AnfNodePtr> inputs;
  105. inputs.push_back(NewValueNode(prim));
  106. for (uint64_t i = 0; i < nparam; i++) {
  107. inputs.push_back(func_graph->add_parameter());
  108. }
  109. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  110. inputs.clear();
  111. inputs.push_back(NewValueNode(prim::kPrimReturn));
  112. inputs.push_back(cnode_prim);
  113. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  114. func_graph->set_return(cnode_return);
  115. return func_graph;
  116. }
  117. TEST_F(TestPrim, test_typeof) {
  118. AbstractBasePtrList args_spec_list;
  119. int64_t v1 = 1;
  120. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  121. args_spec_list.push_back(abstract_v1);
  122. auto prim_typeof = std::make_shared<Primitive>("typeof");
  123. FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1);
  124. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  125. res->dump();
  126. TypePtr res_value = res->GetValueTrack()->cast<TypePtr>();
  127. res_value->dump();
  128. ASSERT_TRUE(*res_value == Int(64));
  129. }
  130. TEST_F(TestPrim, test_list_map) {
  131. AbstractBasePtrList args_spec_list;
  132. AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false);
  133. AbstractBasePtr abstract_u1 = FromValue(static_cast<int64_t>(1), false);
  134. auto abstract_list1 = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_u1}));
  135. AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(2), false);
  136. AbstractBasePtr abstract_u2 = FromValue(static_cast<int64_t>(2), false);
  137. auto abstract_list2 = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v2, abstract_u2}));
  138. auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
  139. AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
  140. args_spec_list.push_back(abstract_func);
  141. args_spec_list.push_back(abstract_list1);
  142. args_spec_list.push_back(abstract_list2);
  143. auto prim_list_map = std::make_shared<Primitive>("list_map");
  144. FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3);
  145. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  146. auto expected = std::make_shared<AbstractList>(
  147. AbstractBasePtrList({FromValue(static_cast<int64_t>(3), false), FromValue(static_cast<int64_t>(3), false)}));
  148. res->dump();
  149. MS_LOG(INFO) << "result res: " << res->ToString();
  150. MS_LOG(INFO) << "result expected: " << expected->ToString();
  151. ASSERT_TRUE(*res == *expected);
  152. }
  153. TEST_F(TestPrim, test_list_reduce) {
  154. AbstractBasePtrList args_spec_list;
  155. int64_t v1 = 1;
  156. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  157. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  158. auto abstract_list = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_v2}));
  159. auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
  160. AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
  161. args_spec_list.push_back(abstract_func);
  162. args_spec_list.push_back(abstract_list);
  163. args_spec_list.push_back(abstract_v1);
  164. auto prim_list_reduce = std::make_shared<Primitive>("list_reduce");
  165. FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3);
  166. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  167. res->dump();
  168. TypePtr res_type = res->GetTypeTrack();
  169. res_type->dump();
  170. ASSERT_TRUE(*res_type == Int(64));
  171. }
  172. TEST_F(TestPrim, test_scalar_to_array) {
  173. AbstractBasePtrList args_spec_list;
  174. int64_t v1 = 1;
  175. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  176. args_spec_list.push_back(abstract_v1);
  177. auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array");
  178. FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1);
  179. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  180. res->dump();
  181. TypePtr res_type = res->BuildType();
  182. res_type->dump();
  183. ASSERT_TRUE(*res_type == TensorType(std::make_shared<Int>(64)));
  184. }
  185. TEST_F(TestPrim, test_array_to_scalar) {
  186. AbstractBasePtrList args_spec_list;
  187. int64_t v1 = 1;
  188. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  189. auto abstract_a1 = std::make_shared<AbstractTensor>(abstract_v1, std::make_shared<Shape>());
  190. args_spec_list.push_back(abstract_a1);
  191. auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar");
  192. FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1);
  193. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  194. res->dump();
  195. TypePtr res_type = res->BuildType();
  196. res_type->dump();
  197. ASSERT_TRUE(*res_type == Int(64));
  198. }
  199. TEST_F(TestPrim, test_J_1) {
  200. AbstractBasePtrList args_spec_list;
  201. int64_t v1 = 1;
  202. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  203. args_spec_list.push_back(abstract_v1);
  204. auto prim_J = std::make_shared<Primitive>("J");
  205. FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1);
  206. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  207. AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res);
  208. ASSERT_TRUE(res_J != nullptr);
  209. ASSERT_TRUE(*(res_J->element()) == *abstract_v1);
  210. }
  211. TEST_F(TestPrim, test_J_2) {
  212. // def add(x):
  213. // return x + x
  214. // def f(x):
  215. // return J(add)(x)
  216. std::vector<AnfNodePtr> inputs;
  217. FuncGraphPtr func_graph1 = std::make_shared<FuncGraph>();
  218. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  219. auto x = func_graph1->add_parameter();
  220. inputs.push_back(x);
  221. inputs.push_back(x);
  222. CNodePtr cnode1 = func_graph1->NewCNode(inputs);
  223. func_graph1->set_return(cnode1);
  224. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  225. inputs.clear();
  226. auto x1 = func_graph->add_parameter();
  227. inputs.clear();
  228. inputs.push_back(NewValueNode(prim::kPrimJ));
  229. inputs.push_back(NewValueNode(func_graph1));
  230. CNodePtr jf = func_graph->NewCNode(inputs);
  231. inputs.clear();
  232. inputs.push_back(jf);
  233. inputs.push_back(x1);
  234. CNodePtr jf_jx = func_graph->NewCNode(inputs);
  235. inputs.clear();
  236. inputs.push_back(NewValueNode(prim::kPrimReturn));
  237. inputs.push_back(jf_jx);
  238. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  239. func_graph->set_return(cnode_return);
  240. draw::Draw("test_J_2.dot", func_graph);
  241. int64_t v1 = 1;
  242. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  243. AbstractBasePtrList args_spec_list = {abstract_v1};
  244. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  245. res->dump();
  246. AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res);
  247. ASSERT_TRUE(res_J != nullptr);
  248. auto res_J_0 = res_J->elements()[0];
  249. ASSERT_TRUE(res_J_0 != nullptr);
  250. ASSERT_TRUE(*res_J_0 == *(FromValue(static_cast<int64_t>(2), false)));
  251. AbstractFunctionPtr res_J_1 = dyn_cast<AbstractFunction>(res_J->elements()[1]);
  252. ASSERT_TRUE(res_J_1 != nullptr);
  253. }
  254. TEST_F(TestPrim, test_dot) {
  255. auto dot = std::make_shared<Primitive>("dot");
  256. FuncGraphPtr func_graph = MakeFuncGraph(dot, 2);
  257. auto a1 = UTPrimUtils::ArrayFloat64Of({2, 3});
  258. auto a2 = UTPrimUtils::ArrayFloat64Of({3, 4});
  259. std::vector<int64_t> expectedA = {2, 4};
  260. auto expected = UTPrimUtils::ArrayFloat64Of({2, 4});
  261. AbstractBasePtrList args_spec_list = {a1, a2};
  262. AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
  263. ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack())));
  264. }
  265. // tail half
  266. TEST_F(TestPrim, test_switch1) {
  267. PrimitivePtr switch_ = std::make_shared<Primitive>("switch");
  268. FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3);
  269. AbstractBasePtr arg0 = FromValue(true, false);
  270. AbstractBasePtr arg1 = FromValue(static_cast<int64_t>(1), false);
  271. AbstractBasePtr arg2 = FromValue(static_cast<int64_t>(2), false);
  272. AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
  273. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  274. ASSERT_TRUE(*res == *arg1);
  275. }
  276. TEST_F(TestPrim, test_switch2) {
  277. PrimitivePtr switch_ = std::make_shared<Primitive>("switch");
  278. FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3);
  279. AbstractBasePtr arg0 = FromValue(false, false);
  280. AbstractBasePtr arg1 = FromValue(static_cast<int64_t>(1), false);
  281. AbstractBasePtr arg2 = FromValue(static_cast<int64_t>(2), false);
  282. AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
  283. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  284. MS_LOG(INFO) << "make result res: " << res->ToString();
  285. MS_LOG(INFO) << "make result arg2: " << arg2->ToString();
  286. ASSERT_TRUE(*res == *arg2);
  287. }
  288. TEST_F(TestPrim, test_identity) {
  289. PrimitivePtr identity = std::make_shared<Primitive>("identity");
  290. FuncGraphPtr func_graph = MakeFuncGraph(identity, 1);
  291. AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false);
  292. AbstractBasePtrList args_spec_list = {abstract_v1};
  293. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  294. ASSERT_TRUE(*res == *abstract_v1);
  295. }
  296. TEST_F(TestPrim, test_broadcast_shape) {
  297. PrimitivePtr broadcast_shape = std::make_shared<Primitive>("broadcast_shape");
  298. FuncGraphPtr func_graph = MakeFuncGraph(broadcast_shape, 2);
  299. auto a = UTPrimUtils::ShapeOf({Shape::SHP_ANY, Shape::SHP_ANY});
  300. auto b = UTPrimUtils::ShapeOf({Shape::SHP_ANY});
  301. std::vector<Any> expected{Shape::SHP_ANY, Shape::SHP_ANY};
  302. AbstractBasePtrList args_spec_list = {a, b};
  303. AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
  304. auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
  305. std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)};
  306. ASSERT_TRUE(ret.size() == element_list.size());
  307. for (int64_t i = 0; i < element_list.size(); i++) {
  308. ASSERT_TRUE(*ret[i] == *element_list[i]);
  309. }
  310. }
  311. TEST_F(TestPrim, test_partial) {
  312. PrimitivePtr prim = prim::kPrimPartial;
  313. FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
  314. PrimitivePtr add = prim::kPrimScalarAdd;
  315. AbstractBasePtr abstract_add = ToAbstract(add);
  316. AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false);
  317. AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(1), false);
  318. AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2};
  319. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  320. AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2};
  321. auto expected = std::make_shared<PartialAbstractClosure>(
  322. std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list);
  323. MS_LOG(INFO) << "result: " << res->ToString();
  324. MS_LOG(INFO) << "expected: " << expected->ToString();
  325. ASSERT_TRUE(res->ToString() == expected->ToString());
  326. }
  327. // def test_env(x, y):
  328. // return env_setitem(newenv, embed(x), y)
  329. TEST_F(TestPrim, test_env_setitem) {
  330. FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
  331. AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
  332. AbstractBasePtrList args_spec_list = {abstract_x};
  333. AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
  334. FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
  335. AbstractBasePtr abstract_env = ToAbstract(newenv);
  336. AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
  337. args_spec_list = {abstract_env, embed_x, abstract_y};
  338. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  339. AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  340. ASSERT_TRUE(*res == *exp);
  341. }
  342. // def test_env(x, y, z):
  343. // e = env_setitem(newenv, embed(x), y)
  344. // return env_getitem(e, embed(x), z)
  345. TEST_F(TestPrim, test_env_getitem) {
  346. FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
  347. AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
  348. AbstractBasePtrList args_spec_list = {abstract_x};
  349. AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
  350. FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
  351. AbstractBasePtr abstract_env = ToAbstract(newenv);
  352. AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
  353. args_spec_list = {abstract_env, embed_x, abstract_y};
  354. AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
  355. AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  356. ASSERT_TRUE(*res == *exp);
  357. FuncGraphPtr graph_getitem = MakeFuncGraph(prim::kPrimEnvGetItem, 3);
  358. AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
  359. args_spec_list = {res, embed_x, abstract_z};
  360. res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract();
  361. ASSERT_TRUE(*res == *abstract_x);
  362. }
  363. // def test_env(x, y, z):
  364. // e1 = env_setitem(newenv, embed(x), y)
  365. // e2 = env_setitem(newenv, embed(x), z)
  366. // return env_add(e1, e2)
  367. TEST_F(TestPrim, test_env_add) {
  368. FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
  369. AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
  370. AbstractBasePtrList args_spec_list = {abstract_x};
  371. AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
  372. FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
  373. AbstractBasePtr abstract_env = ToAbstract(newenv);
  374. AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
  375. args_spec_list = {abstract_env, embed_x, abstract_y};
  376. AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
  377. AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  378. ASSERT_TRUE(*abstract_e1 == *exp);
  379. AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
  380. args_spec_list = {abstract_env, embed_x, abstract_z};
  381. AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
  382. ASSERT_TRUE(*abstract_e2 == *exp);
  383. FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2);
  384. args_spec_list = {abstract_e1, abstract_e2};
  385. AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract();
  386. ASSERT_TRUE(*res == *exp);
  387. }
  388. TEST_F(TestPrim, test_relu) {
  389. PrimitivePtr relu = prim::kPrimRelu;
  390. relu->AddAttr("T", MakeValue(static_cast<int64_t>(kNumberTypeFloat64)));
  391. FuncGraphPtr func_graph = MakeFuncGraph(relu, 1);
  392. AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW
  393. AbstractBasePtrList args_spec_list = {expected};
  394. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  395. ASSERT_TRUE(*res == *expected);
  396. }
  397. /*
  398. TEST_F(TestPrim, test_relu2) {
  399. FuncGraphPtr func_graph = getPyFun("get_relu");
  400. ASSERT_TRUE(func_graph != nullptr);
  401. draw::Draw("test_relu.dot", func_graph);
  402. auto arr = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
  403. auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
  404. AbstractBasePtrList args_spec_list = {arr};
  405. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  406. auto res = dyn_cast<AbstractTensor>(ret);
  407. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  408. }
  409. TEST_F(TestPrim, test_conv2d1) {
  410. std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
  411. py::tuple kernel_size(2);
  412. kernel_size[0] = 5;
  413. kernel_size[1] = 5;
  414. std::shared_ptr<FuncGraph> func_graph = getPyFun.CallAndParseRet("test_conv2d", 64, kernel_size, 0, 2, 1);
  415. // NCHW
  416. std::vector<int64_t> inputs_dims = {2, 20, 32, 32};
  417. std::vector<int64_t> weight_dims = {64, 20, 5, 5};
  418. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  419. inputs->set_data_type(kNumberTypeInt32);
  420. inputs->set_shape(inputs_dims);
  421. // Cout, Cin, kernel_size
  422. tensor::TensorPtr weight = std::make_shared<tensor::Tensor>();
  423. weight->set_data_type(kNumberTypeInt32);
  424. weight->set_shape(weight_dims);
  425. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  426. AbstractBasePtr abstract_weight = FromValue(weight, true);
  427. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_weight};
  428. AbstractBasePtr expected = abstract_inputs->Clone();
  429. // NCHW
  430. std::vector<int64_t> shape = {2, 64, 14, 14};
  431. expected->set_shape(std::make_shared<Shape>(shape));
  432. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  433. MS_LOG(INFO) << "result: " << res->ToString();
  434. MS_LOG(INFO) << "expected: " << expected->ToString();
  435. auto res_ptr = dyn_cast<AbstractTensor>(res);
  436. auto expected_ptr = dyn_cast<AbstractTensor>(expected);
  437. ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
  438. ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
  439. }
  440. TEST_F(TestPrim, test_conv2d) {
  441. FuncGraphPtr func_graph = getPyFun("get_conv2d");
  442. ASSERT_TRUE(func_graph != nullptr);
  443. auto input = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
  444. auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3});
  445. AbstractBasePtrList args_spec_list = {input, weight};
  446. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  447. auto res = dyn_cast<AbstractTensor>(ret);
  448. auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16});
  449. MS_LOG(INFO) << "result: " << res->ToString();
  450. MS_LOG(INFO) << "expected: " << expected->ToString();
  451. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  452. }
  453. TEST_F(TestPrim, test_conv2d_native) {
  454. FuncGraphPtr func_graph = getPyFun("get_conv2d_native");
  455. ASSERT_TRUE(func_graph != nullptr);
  456. auto input = ArrayOfTensor(UTPrimUtils::kF64, {10, 32, 32, 32});
  457. auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3});
  458. AbstractBasePtrList args_spec_list = {input, weight};
  459. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  460. auto res = dyn_cast<AbstractTensor>(ret);
  461. auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16});
  462. MS_LOG(INFO) << "result: " << res->ToString();
  463. MS_LOG(INFO) << "expected: " << expected->ToString();
  464. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  465. }
  466. TEST_F(TestPrim, test_biasAdd) {
  467. FuncGraphPtr func_graph = getPyFun("get_bias_add");
  468. ASSERT_TRUE(func_graph != nullptr);
  469. auto value = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
  470. auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32});
  471. AbstractBasePtrList args_spec_list = {value, bias};
  472. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  473. auto res = dyn_cast<AbstractTensor>(ret);
  474. auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
  475. MS_LOG(INFO) << "result: " << res->ToString();
  476. MS_LOG(INFO) << "expected: " << expected->ToString();
  477. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  478. }
  479. TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) {
  480. FuncGraphPtr func_graph = getPyFun("get_softmax_cross_entropy_with_logits");
  481. ASSERT_TRUE(func_graph != nullptr);
  482. auto logits = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
  483. auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
  484. AbstractBasePtrList args_spec_list = {logits, labels};
  485. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  486. ASSERT_NE(ret, nullptr);
  487. auto res = dyn_cast<AbstractTuple>(ret);
  488. auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64});
  489. auto dLogits = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
  490. AbstractBasePtrList expected_list = {loss, dLogits};
  491. auto expected = std::make_shared<AbstractTuple>(expected_list);
  492. MS_LOG(INFO) << "result: " << res->ToString();
  493. MS_LOG(INFO) << "expected: " << expected->ToString();
  494. auto res_ptr0 = dyn_cast<AbstractTuple>(res);
  495. auto expected_ptr0 = dyn_cast<AbstractTuple>(expected);
  496. ASSERT_GT((*res_ptr0).size(), 1);
  497. auto res_ptr = dyn_cast<AbstractTensor>((*res_ptr0)[1]);
  498. ASSERT_GT((*expected_ptr0).size(), 1);
  499. auto expected_ptr = dyn_cast<AbstractTensor>((*expected_ptr0)[1]);
  500. ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
  501. ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
  502. }
  503. TEST_F(TestPrim, test_tensor_to_scalar_prim) {
  504. FuncGraphPtr func_graph = getPyFun("get_tensor_to_scalar");
  505. ASSERT_TRUE(func_graph != nullptr);
  506. draw::Draw("get_tensor_to_scalar.dot", func_graph);
  507. auto logits = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
  508. auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
  509. AbstractBasePtrList args_spec_list = {logits, labels};
  510. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  511. auto res = dyn_cast<AbstractScalar>(ret);
  512. AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
  513. expected->set_type(UTPrimUtils::kF64);
  514. MS_LOG(INFO) << "result: " << res->ToString();
  515. MS_LOG(INFO) << "expected: " << expected->ToString();
  516. ASSERT_TRUE(*res == *expected);
  517. }
  518. TEST_F(TestPrim, test_fused_batch_norm) {
  519. PrimitivePtr fused_batch_norm = prim::kPrimFusedBatchNorm;
  520. fused_batch_norm->AddAttr("epsilon", MakeValue(0.001f));
  521. fused_batch_norm->AddAttr("momentum", MakeValue(0.1f));
  522. FuncGraphPtr func_graph = MakeFuncGraph(fused_batch_norm, 5);
  523. // NCHW
  524. std::vector<int64_t> inputs_dims = {128, 64, 32, 64};
  525. std::vector<int64_t> scale_dims = {64};
  526. std::vector<int64_t> offset_dims = {64};
  527. std::vector<int64_t> mean_dims = {64};
  528. std::vector<int64_t> variance_dims = {64};
  529. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  530. inputs->set_data_type(kNumberTypeFloat32);
  531. inputs->set_shape(inputs_dims);
  532. tensor::TensorPtr scale = std::make_shared<tensor::Tensor>();
  533. scale->set_data_type(kNumberTypeFloat32);
  534. scale->set_shape(scale_dims);
  535. tensor::TensorPtr offset = std::make_shared<tensor::Tensor>();
  536. offset->set_data_type(kNumberTypeFloat32);
  537. offset->set_shape(offset_dims);
  538. tensor::TensorPtr mean = std::make_shared<tensor::Tensor>();
  539. mean->set_data_type(kNumberTypeFloat32);
  540. mean->set_shape(mean_dims);
  541. tensor::TensorPtr variance = std::make_shared<tensor::Tensor>();
  542. variance->set_data_type(kNumberTypeFloat32);
  543. variance->set_shape(variance_dims);
  544. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  545. AbstractBasePtr abstract_scale = FromValue(scale, true);
  546. AbstractBasePtr abstract_offset = FromValue(offset, true);
  547. AbstractBasePtr abstract_mean = FromValue(mean, true);
  548. AbstractBasePtr abstract_variance = FromValue(variance, true);
  549. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_scale, abstract_offset, abstract_mean,
  550. abstract_variance};
  551. AbstractBasePtr expected0 = abstract_inputs->Clone();
  552. AbstractBasePtr expected1 = abstract_scale->Clone();
  553. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  554. MS_LOG(INFO) << "result: " << res->ToString();
  555. MS_LOG(INFO) << "expected0: " << expected0->ToString();
  556. MS_LOG(INFO) << "expected1: " << expected1->ToString();
  557. std::shared_ptr<AbstractTuple> abs_tuple = dyn_cast<AbstractTuple>(res);
  558. ASSERT_TRUE(abs_tuple != nullptr);
  559. ASSERT_TRUE(*abs_tuple->elements()[0] == *expected0);
  560. ASSERT_TRUE(*abs_tuple->elements()[1] == *expected1);
  561. ASSERT_TRUE(*abs_tuple->elements()[2] == *expected1);
  562. ASSERT_TRUE(*abs_tuple->elements()[3] == *expected1);
  563. ASSERT_TRUE(*abs_tuple->elements()[4] == *expected1);
  564. }
  565. TEST_F(TestPrim, test_pooling) {
  566. PrimitivePtr pooling = prim::kPrimPooling;
  567. pooling->AddAttr("mode", MakeValue(std::string("avg")));
  568. pooling->AddAttr("pad_mode", MakeValue(std::string("valid")));
  569. pooling->AddAttr("nan_opt", MakeValue(0));
  570. pooling->AddAttr("window", MakeValue(2));
  571. pooling->AddAttr("pad", MakeValue(1));
  572. pooling->AddAttr("stride", MakeValue(1));
  573. pooling->AddAttr("data_mode", MakeValue(1));
  574. pooling->AddAttr("ceil_mode", MakeValue(0));
  575. FuncGraphPtr func_graph = MakeFuncGraph(pooling, 1);
  576. std::vector<int64_t> inputs_dims = {8, 64, 3, 3};
  577. auto inputs = std::make_shared<tensor::Tensor>();
  578. inputs->set_data_type(kNumberTypeFloat32);
  579. inputs->set_shape(inputs_dims);
  580. AbstractBasePtr abstract_input = FromValue(inputs, false);
  581. AbstractBasePtrList args_spec_list = {abstract_input};
  582. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  583. AbstractBasePtr expected = abstract_input->Clone()->Broaden();
  584. std::vector<int64_t> expected_dims = {8, 64, 2, 2};
  585. expected->set_shape(std::make_shared<Shape>(expected_dims));
  586. MS_LOG(INFO) << "result: " << res->ToString();
  587. MS_LOG(INFO) << "expected: " << expected->ToString();
  588. ASSERT_TRUE(*res == *expected);
  589. }
  590. TEST_F(TestPrim, test_hastype) {
  591. AbstractBasePtrList args_spec_list;
  592. int64_t v1 = 1;
  593. TypePtr v2 = std::make_shared<Number>();
  594. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  595. AbstractTypePtr abstract_v2 = UTPrimUtils::TypeToAbstract(v2);
  596. AbstractBasePtr expected = FromValue(true, false);
  597. args_spec_list.push_back(abstract_v1);
  598. args_spec_list.push_back(abstract_v2);
  599. auto prim = std::make_shared<Primitive>("hastype");
  600. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  601. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  602. ASSERT_TRUE(*res == *expected);
  603. }
  604. TEST_F(TestPrim, test_array_len) {
  605. AbstractBasePtrList args_spec_list;
  606. auto v1 = UTPrimUtils::ArrayFloat64Of({3, 4, 0, 2});
  607. auto expected = std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  608. args_spec_list.push_back(v1);
  609. auto prim = std::make_shared<Primitive>("array_len");
  610. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  611. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  612. ASSERT_TRUE(*res == *expected);
  613. }
  614. TEST_F(TestPrim, test_list_len) {
  615. AbstractBasePtrList args_spec_list;
  616. auto v1 = UTPrimUtils::ListShapeOf({3, 4, 0, 2});
  617. auto expected = std::make_shared<AbstractScalar>(4);
  618. args_spec_list.push_back(v1);
  619. auto prim = std::make_shared<Primitive>("list_len");
  620. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  621. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  622. ASSERT_TRUE(*res == *expected);
  623. }
  624. TEST_F(TestPrim, test_tuple_len) {
  625. AbstractBasePtrList args_spec_list;
  626. auto v1 = UTPrimUtils::ShapeOf({3, 4, 0, 2});
  627. auto expected = std::make_shared<AbstractScalar>(4);
  628. args_spec_list.push_back(v1);
  629. auto prim = std::make_shared<Primitive>("tuple_len");
  630. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  631. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  632. ASSERT_TRUE(*res == *expected);
  633. }
  634. TEST_F(TestPrim, test_tuple_reversed) {
  635. AbstractBasePtrList args_spec_list;
  636. auto v1 = UTPrimUtils::ShapeOf({0, 1, 2, 3});
  637. auto expected = UTPrimUtils::ShapeOf({3, 2, 1, 0});
  638. args_spec_list.push_back(v1);
  639. auto prim = std::make_shared<Primitive>("tuple_reversed");
  640. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  641. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  642. MS_LOG(INFO) << "expect=" << expected->ToString();
  643. ASSERT_TRUE(*res == *expected);
  644. }
  645. TEST_F(TestPrim, test_list_getitem) {
  646. AbstractBasePtrList args_spec_list;
  647. int64_t v1 = 2;
  648. int64_t v2 = 1;
  649. AbstractBasePtr elem = FromValue(v1, false);
  650. AbstractBasePtr elem2 = FromValue(v2, false);
  651. AbstractBasePtrList elems = {elem, elem};
  652. auto abstract_v1 = std::make_shared<AbstractList>(elems);
  653. AbstractBasePtr abstract_v2 = FromValue(v2, false);
  654. args_spec_list.push_back(abstract_v1);
  655. args_spec_list.push_back(abstract_v2);
  656. auto prim = std::make_shared<Primitive>("list_getitem");
  657. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  658. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  659. ASSERT_TRUE(*res == *elem);
  660. }
  661. TEST_F(TestPrim, test_list_setitem) {
  662. int64_t v1 = 1;
  663. int64_t v2 = 2;
  664. AbstractBasePtr elem1 = FromValue(v1, false);
  665. AbstractBasePtr elem2 = FromValue(v2, false);
  666. AbstractBasePtrList elems = {elem1, elem1};
  667. auto abstract_tuple = std::make_shared<AbstractList>(elems);
  668. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  669. AbstractBasePtr abstract_v3 = FromValue(v2, false);
  670. AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2, abstract_v3};
  671. auto prim = std::make_shared<Primitive>("list_setitem");
  672. FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
  673. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  674. MS_LOG(INFO) << "result: " << res->ToString();
  675. AbstractBasePtrList elems_exp = {elem1, elem2};
  676. auto expected = std::make_shared<AbstractList>(elems_exp);
  677. MS_LOG(INFO) << "expected: " << expected->ToString();
  678. auto res_list = dyn_cast<AbstractList>(res);
  679. ASSERT_TRUE(*expected == *res_list);
  680. }
  681. TEST_F(TestPrim, test_list_append) {
  682. int64_t v1 = 1;
  683. AbstractBasePtr elem1 = FromValue(v1, false);
  684. AbstractBasePtr elem2 = FromValue(v1, false);
  685. auto abstract_tuple = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
  686. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  687. AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2};
  688. auto prim = std::make_shared<Primitive>("list_append");
  689. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  690. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  691. MS_LOG(INFO) << "result: " << res->ToString();
  692. auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
  693. MS_LOG(INFO) << "expected: " << expected->ToString();
  694. auto res_list = dyn_cast<AbstractList>(res);
  695. ASSERT_TRUE(*res_list == *expected);
  696. }
  697. TEST_F(TestPrim, test_tuple_setitem) {
  698. int64_t v1 = 1;
  699. int64_t v2 = 2;
  700. AbstractBasePtr elem1 = FromValue(v1, false);
  701. AbstractBasePtr elem2 = FromValue(v2, false);
  702. AbstractBasePtrList elems = {elem1, elem1};
  703. auto abstract_tuple = std::make_shared<AbstractTuple>(elems);
  704. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  705. AbstractBasePtr abstract_v3 = FromValue(v2, false);
  706. AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2, abstract_v3};
  707. auto prim = std::make_shared<Primitive>("tuple_setitem");
  708. FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
  709. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  710. MS_LOG(INFO) << "result: " << res->ToString();
  711. AbstractBasePtrList elems_exp = {elem1, elem2};
  712. auto expected = std::make_shared<AbstractTuple>(elems_exp);
  713. MS_LOG(INFO) << "expected: " << expected->ToString();
  714. auto res_tuple = dyn_cast<AbstractTuple>(res);
  715. ASSERT_TRUE(*res == *expected);
  716. }
  717. TEST_F(TestPrim, test_make_list) {
  718. AbstractBasePtrList args_spec_list;
  719. int64_t v1 = 2;
  720. int64_t v2 = 2;
  721. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  722. AbstractBasePtr abstract_v2 = FromValue(v2, false);
  723. auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_v2}));
  724. args_spec_list.push_back(abstract_v1);
  725. args_spec_list.push_back(abstract_v2);
  726. auto prim = std::make_shared<Primitive>("make_list");
  727. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  728. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  729. ASSERT_TRUE(*res == *expected);
  730. }
  731. TEST_F(TestPrim, test_make_range) {
  732. AbstractBasePtrList args_spec_list;
  733. int64_t v1 = 1;
  734. int64_t v2 = 4;
  735. AbstractBasePtr abstract_v1 = FromValue(v1);
  736. AbstractBasePtr abstract_v2 = FromValue(v2);
  737. args_spec_list.push_back(abstract_v1);
  738. args_spec_list.push_back(abstract_v2);
  739. auto prim = std::make_shared<Primitive>("make_range");
  740. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(prim, 2);
  741. AbstractBasePtr ele1 = FromValue(1);
  742. AbstractBasePtr ele2 = FromValue(2);
  743. AbstractBasePtr ele3 = FromValue(3);
  744. AbstractBasePtrList elem_list({ele1, ele2, ele3});
  745. AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list);
  746. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  747. MS_LOG(INFO) << "res=" << res->ToString();
  748. MS_LOG(INFO) << "expected=" << expected->ToString();
  749. ASSERT_TRUE(*res == *expected);
  750. }
  751. TEST_F(TestPrim, test_layernorm) {
  752. PrimitivePtr layerNorm = prim::kPrimLayerNorm;
  753. layerNorm->AddAttr("begin_norm_axis", MakeValue(1));
  754. layerNorm->AddAttr("begin_params_axis", MakeValue(1));
  755. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(layerNorm, 3);
  756. std::vector<int64_t> inputs_dims = {128, 64, 32, 64};
  757. std::vector<int64_t> mean_var_dims = {128, 64, 32, 1};
  758. std::vector<int64_t> params_dims = {64, 32, 64};
  759. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  760. inputs->set_data_type(kNumberTypeFloat32);
  761. inputs->set_shape(inputs_dims);
  762. tensor::TensorPtr mean_var = std::make_shared<tensor::Tensor>();
  763. mean_var->set_data_type(kNumberTypeFloat32);
  764. mean_var->set_shape(mean_var_dims);
  765. tensor::TensorPtr gamma = std::make_shared<tensor::Tensor>();
  766. gamma->set_data_type(kNumberTypeFloat32);
  767. gamma->set_shape(params_dims);
  768. tensor::TensorPtr beta = std::make_shared<tensor::Tensor>();
  769. beta->set_data_type(kNumberTypeFloat32);
  770. beta->set_shape(params_dims);
  771. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  772. AbstractBasePtr abstract_mean_var = FromValue(mean_var, true);
  773. AbstractBasePtr abstract_gamma = FromValue(gamma, true);
  774. AbstractBasePtr abstract_beta = FromValue(beta, true);
  775. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_gamma, abstract_beta};
  776. AbstractBasePtr expected0 = abstract_inputs->Clone();
  777. AbstractBasePtr expected1 = abstract_mean_var->Clone();
  778. AbstractBasePtr expected2 = abstract_mean_var->Clone();
  779. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  780. MS_LOG(INFO) << "result: " << res->ToString();
  781. MS_LOG(INFO) << "expected0: " << expected0->ToString();
  782. MS_LOG(INFO) << "expected1: " << expected1->ToString();
  783. MS_LOG(INFO) << "expected2: " << expected2->ToString();
  784. std::shared_ptr<AbstractTuple> abs_tuple = dyn_cast<AbstractTuple>(res);
  785. ASSERT_TRUE(abs_tuple != nullptr);
  786. auto res_ptr0 = dyn_cast<AbstractTensor>(abs_tuple->elements()[0]);
  787. auto expected_ptr0 = dyn_cast<AbstractTensor>(expected0);
  788. ASSERT_TRUE(*res_ptr0->shape() == *expected_ptr0->shape());
  789. ASSERT_TRUE(*res_ptr0->element() == *expected_ptr0->element());
  790. auto res_ptr1 = dyn_cast<AbstractTensor>(abs_tuple->elements()[1]);
  791. auto expected_ptr1 = dyn_cast<AbstractTensor>(expected1);
  792. ASSERT_TRUE(*res_ptr1->shape() == *expected_ptr1->shape());
  793. ASSERT_TRUE(*res_ptr1->element() == *expected_ptr1->element());
  794. auto res_ptr2 = dyn_cast<AbstractTensor>(abs_tuple->elements()[2]);
  795. auto expected_ptr2 = dyn_cast<AbstractTensor>(expected2);
  796. ASSERT_TRUE(*res_ptr2->shape() == *expected_ptr2->shape());
  797. ASSERT_TRUE(*res_ptr2->element() == *expected_ptr2->element());
  798. }
  799. TEST_F(TestPrim, test_DropoutGenMask) {
  800. AbstractBasePtrList args_spec_list;
  801. auto arg0 = UTPrimUtils::ShapeOf({5, 5, 5, 5});
  802. std::vector<int64_t> keep_prob_shape = {};
  803. tensor::TensorPtr keep_prob = std::make_shared<tensor::Tensor>(0.5f);
  804. keep_prob->set_data_type(kNumberTypeFloat32);
  805. keep_prob->set_shape(keep_prob_shape);
  806. AbstractBasePtr abstract_keep_prob = FromValue(keep_prob);
  807. auto prim = std::make_shared<Primitive>("DropoutGenMask");
  808. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(prim, 2);
  809. args_spec_list.push_back(arg0);
  810. args_spec_list.push_back(abstract_keep_prob);
  811. // should return a tensor with on dimension of 79 elements
  812. AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
  813. std::make_shared<Shape>(std::vector<int64_t>{79}));
  814. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  815. MS_LOG(INFO) << "res=" << res->ToString();
  816. MS_LOG(INFO) << "expected=" << expected->ToString();
  817. ASSERT_TRUE(*res == *expected);
  818. }
  819. TEST_F(TestPrim, test_dropout) {
  820. std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
  821. std::shared_ptr<FuncGraph> func_graph = getPyFun.CallAndParseRet("test_dropout");
  822. std::vector<int64_t> inputs_dims = {2, 20, 32, 32};
  823. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  824. inputs->set_data_type(kNumberTypeFloat32);
  825. inputs->set_shape(inputs_dims);
  826. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  827. std::vector<int64_t> keep_prob_shape = {};
  828. tensor::TensorPtr keep_prob = std::make_shared<tensor::Tensor>(0.5f);
  829. keep_prob->set_data_type(kNumberTypeFloat32);
  830. keep_prob->set_shape(keep_prob_shape);
  831. AbstractBasePtr abstract_keep_prob = FromValue(keep_prob);
  832. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_keep_prob};
  833. AbstractBasePtr expected = abstract_inputs->Clone();
  834. // NCHW
  835. std::vector<int64_t> shape = {2, 20, 32, 32};
  836. expected->set_shape(std::make_shared<Shape>(shape));
  837. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  838. MS_LOG(INFO) << "result: " << res->ToString();
  839. MS_LOG(INFO) << "expected: " << expected->ToString();
  840. auto res_ptr = dyn_cast<AbstractTensor>(res);
  841. auto expected_ptr = dyn_cast<AbstractTensor>(expected);
  842. ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
  843. ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
  844. }
  845. TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) {
  846. PrimitivePtr broadcatGradientArgs = prim::kPrimBroadcastGradientArgs;
  847. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(broadcatGradientArgs, 2);
  848. // broadcast shape: x: 8,5,3, y:3
  849. // output: ((),(0, 1))
  850. AbstractBasePtrList x_arg_list({abstract::FromValue(8), abstract::FromValue(5), abstract::FromValue(3)});
  851. AbstractBasePtrList y_arg_list({abstract::FromValue(3)});
  852. auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
  853. auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
  854. AbstractBasePtrList args_spec_list = {x_input, y_input};
  855. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  856. auto res = dyn_cast<AbstractTuple>(ret);
  857. AbstractBasePtrList x_idx_list;
  858. auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
  859. AbstractBasePtrList y_idx_list({abstract::FromValue(0), abstract::FromValue(1)});
  860. auto r_y = std::make_shared<AbstractTuple>(y_idx_list);
  861. AbstractBasePtrList elem_list({r_x, r_y});
  862. auto expected = std::make_shared<AbstractTuple>(elem_list);
  863. MS_LOG(INFO) << "result: " << res->ToString();
  864. MS_LOG(INFO) << "expected: " << expected->ToString();
  865. ASSERT_TRUE(*res == *expected);
  866. }
  867. TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) {
  868. PrimitivePtr broadcatGradientArgs = prim::kPrimBroadcastGradientArgs;
  869. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(broadcatGradientArgs, 2);
  870. // broadcast shape: x: 8,1,3, y:8 5 3
  871. // output: ((1),())
  872. AbstractBasePtrList x_arg_list({abstract::FromValue(8), abstract::FromValue(1), abstract::FromValue(3)});
  873. AbstractBasePtrList y_arg_list({abstract::FromValue(8), abstract::FromValue(5), abstract::FromValue(3)});
  874. auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
  875. auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
  876. AbstractBasePtrList args_spec_list = {x_input, y_input};
  877. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  878. auto res = dyn_cast<AbstractTuple>(ret);
  879. AbstractBasePtrList x_idx_list({abstract::FromValue(1)});
  880. auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
  881. AbstractBasePtrList y_idx_list;
  882. auto r_y = std::make_shared<AbstractTuple>(y_idx_list);
  883. AbstractBasePtrList elem_list({r_x, r_y});
  884. auto expected = std::make_shared<AbstractTuple>(elem_list);
  885. MS_LOG(INFO) << "result: " << res->ToString();
  886. MS_LOG(INFO) << "expected: " << expected->ToString();
  887. ASSERT_TRUE(*res == *expected);
  888. }
  889. TEST_F(TestPrim, test_DictGetItem) {
  890. PrimitivePtr dictGetItem = prim::kPrimDictGetItem;
  891. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(dictGetItem, 2);
  892. std::vector<std::pair<std::string, ValuePtr>> tensor_map = {
  893. {"x", std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int64_t>{2, 3, 4})},
  894. {"y", std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int64_t>{2, 1, 4})}};
  895. ValueDictionary value_dict(tensor_map);
  896. AbstractBasePtr array_dict = value_dict.ToAbstract();
  897. AbstractBasePtr key = abstract::FromValue("x");
  898. AbstractBasePtrList args_spec_list = {array_dict, key};
  899. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  900. AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
  901. AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second));
  902. ASSERT_TRUE(*tensor_ret == *expect);
  903. }
  904. TEST_F(TestPrim, test_DictGetItem2) {
  905. PrimitivePtr dictGetItem = prim::kPrimDictGetItem;
  906. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(dictGetItem, 2);
  907. AbstractBasePtr arr_x = ArrayOfTensor(UTPrimUtils::kF64, {3, 4, 5});
  908. AbstractBasePtr arr_y = ArrayOfTensor(UTPrimUtils::kF64, {1, 4, 5});
  909. AbstractBasePtr arr_z = ArrayOfTensor(UTPrimUtils::kF64, {3, 1, 5});
  910. std::vector<AbstractAttribute> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  911. AbstractDictionaryPtr array_dict = std::make_shared<AbstractDictionary>(array_map);
  912. AbstractBasePtr key = abstract::FromValue("x");
  913. AbstractBasePtrList args_spec_list = {array_dict, key};
  914. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  915. AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
  916. AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x);
  917. ASSERT_TRUE(*tensor_ret == *expect);
  918. }
  919. */
  920. } // namespace abstract
  921. } // namespace mindspore