You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim_test.cc 43 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <memory>
  18. #include "pybind11/pybind11.h"
  19. #include "common/common_test.h"
  20. #include "common/py_func_graph_fetcher.h"
  21. #include "ir/manager.h"
  22. #include "pipeline/jit/static_analysis/prim.h"
  23. #include "pipeline/static_analysis/helper.h"
  24. #include "frontend/operator/ops.h"
  25. #include "debug/draw.h"
  26. #include "ir/tensor.h"
  27. #include "utils/symbolic.h"
  28. #include "base/core_ops.h"
  29. namespace mindspore {
  30. namespace abstract {
  31. namespace py = pybind11;
  32. namespace python_adapter = mindspore::parse::python_adapter;
  33. class UTPrimUtils {
  34. public:
  35. using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
  36. using AbstractTuplePtr = std::shared_ptr<AbstractTuple>;
  37. static const std::shared_ptr<Float> kF32;
  38. static const std::shared_ptr<Float> kF64;
  39. static const std::shared_ptr<Int> kI16;
  40. static const std::shared_ptr<Int> kI64;
  41. static const std::shared_ptr<UInt> kU64;
  42. static std::shared_ptr<AbstractType> TypeToAbstract(TypePtr t) { return std::make_shared<AbstractType>(t); }
  43. static AbstractTensorPtr ArrayFloat64Of(std::initializer_list<int64_t> shp) {
  44. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
  45. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  46. }
  47. static AbstractTensorPtr ArrayFloat32Of(std::initializer_list<int64_t> shp) {
  48. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kFloat32);
  49. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  50. }
  51. static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int64_t> shp) {
  52. auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt64);
  53. return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
  54. }
  55. static AbstractTuplePtr ShapeOf(std::initializer_list<int64_t> vals) {
  56. AbstractBasePtrList te;
  57. for (auto v : vals) {
  58. te.push_back(std::make_shared<AbstractScalar>(v));
  59. }
  60. return std::make_shared<AbstractTuple>(te);
  61. }
  62. static AbstractListPtr ListShapeOf(std::initializer_list<int64_t> vals) {
  63. AbstractBasePtrList te;
  64. for (auto v : vals) {
  65. te.push_back(std::make_shared<AbstractScalar>(v));
  66. }
  67. return std::make_shared<AbstractList>(te);
  68. }
  69. };
  70. const std::shared_ptr<Float> UTPrimUtils::kF64 = std::make_shared<Float>(64);
  71. const std::shared_ptr<Float> UTPrimUtils::kF32 = std::make_shared<Float>(32);
  72. const std::shared_ptr<Int> UTPrimUtils::kI16 = std::make_shared<Int>(16);
  73. const std::shared_ptr<Int> UTPrimUtils::kI64 = std::make_shared<Int>(64);
  74. const std::shared_ptr<UInt> UTPrimUtils::kU64 = std::make_shared<UInt>(64);
  namespace {
  // Disabled ("skip ut test cases temporarily"): helper used only by the commented-out test cases later in
  // this file. It turned a real tensor of element type `t` and shape `shp` into its abstract value:
  //
  //   AbstractBasePtr ArrayOfTensor(const TypePtr &t, std::initializer_list<int64_t> shp) {
  //     auto shape = std::vector<int64_t>(shp);
  //     auto tensor = std::make_shared<tensor::Tensor>(t->type_id(), shape);
  //     return ToAbstract(tensor);
  //   }
  }  // namespace
  84. class TestPrim : public UT::Common {
  85. public:
  86. TestPrim() : getPyFun("gtest_input.pipeline.infer", true) {}
  87. void SetUp();
  88. void TearDown();
  89. AnalysisEnginePtr engine_;
  90. UT::PyFuncGraphFetcher getPyFun;
  91. };
  92. void TestPrim::SetUp() { engine_ = SetupAnalysisEngine(); }
  93. void TestPrim::TearDown() {
  94. // destroy resource
  95. }
  96. static FuncGraphPtr MakeFuncGraph(const PrimitivePtr prim, uint64_t nparam) {
  97. // build the func_graph manually, eg:
  98. // MakeFuncGraph(std::make_shared<Primitive>("scalar_add"), 2) means:
  99. /* python source code:
  100. * @mindspore
  101. * def f(x, y):
  102. * return x + y
  103. */
  104. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  105. std::vector<AnfNodePtr> inputs;
  106. inputs.push_back(NewValueNode(prim));
  107. for (uint64_t i = 0; i < nparam; i++) {
  108. inputs.push_back(func_graph->add_parameter());
  109. }
  110. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  111. inputs.clear();
  112. inputs.push_back(NewValueNode(prim::kPrimReturn));
  113. inputs.push_back(cnode_prim);
  114. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  115. func_graph->set_return(cnode_return);
  116. return func_graph;
  117. }
  118. TEST_F(TestPrim, test_typeof) {
  119. AbstractBasePtrList args_spec_list;
  120. int64_t v1 = 1;
  121. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  122. args_spec_list.push_back(abstract_v1);
  123. auto prim_typeof = std::make_shared<Primitive>("typeof");
  124. FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1);
  125. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  126. res->dump();
  127. TypePtr res_value = res->GetValueTrack()->cast<TypePtr>();
  128. res_value->dump();
  129. ASSERT_TRUE(*res_value == Int(64));
  130. }
  131. TEST_F(TestPrim, test_list_map) {
  132. AbstractBasePtrList args_spec_list;
  133. AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false);
  134. AbstractBasePtr abstract_u1 = FromValue(static_cast<int64_t>(1), false);
  135. auto abstract_list1 = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_u1}));
  136. AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(2), false);
  137. AbstractBasePtr abstract_u2 = FromValue(static_cast<int64_t>(2), false);
  138. auto abstract_list2 = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v2, abstract_u2}));
  139. auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
  140. AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
  141. args_spec_list.push_back(abstract_func);
  142. args_spec_list.push_back(abstract_list1);
  143. args_spec_list.push_back(abstract_list2);
  144. auto prim_list_map = std::make_shared<Primitive>("list_map");
  145. FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3);
  146. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  147. auto expected = std::make_shared<AbstractList>(
  148. AbstractBasePtrList({FromValue(static_cast<int64_t>(3), false), FromValue(static_cast<int64_t>(3), false)}));
  149. res->dump();
  150. MS_LOG(INFO) << "result res: " << res->ToString();
  151. MS_LOG(INFO) << "result expected: " << expected->ToString();
  152. ASSERT_TRUE(*res == *expected);
  153. }
  154. TEST_F(TestPrim, test_list_reduce) {
  155. AbstractBasePtrList args_spec_list;
  156. int64_t v1 = 1;
  157. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  158. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  159. auto abstract_list = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_v2}));
  160. auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
  161. AbstractBasePtr abstract_func = ToAbstract(prim_scalar_add);
  162. args_spec_list.push_back(abstract_func);
  163. args_spec_list.push_back(abstract_list);
  164. args_spec_list.push_back(abstract_v1);
  165. auto prim_list_reduce = std::make_shared<Primitive>("list_reduce");
  166. FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3);
  167. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  168. res->dump();
  169. TypePtr res_type = res->GetTypeTrack();
  170. res_type->dump();
  171. ASSERT_TRUE(*res_type == Int(64));
  172. }
  173. TEST_F(TestPrim, test_scalar_to_array) {
  174. AbstractBasePtrList args_spec_list;
  175. int64_t v1 = 1;
  176. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  177. args_spec_list.push_back(abstract_v1);
  178. auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array");
  179. FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1);
  180. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  181. res->dump();
  182. TypePtr res_type = res->BuildType();
  183. res_type->dump();
  184. ASSERT_TRUE(*res_type == TensorType(std::make_shared<Int>(64)));
  185. }
  186. TEST_F(TestPrim, test_array_to_scalar) {
  187. AbstractBasePtrList args_spec_list;
  188. int64_t v1 = 1;
  189. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  190. auto abstract_a1 = std::make_shared<AbstractTensor>(abstract_v1, std::make_shared<Shape>());
  191. args_spec_list.push_back(abstract_a1);
  192. auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar");
  193. FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1);
  194. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  195. res->dump();
  196. TypePtr res_type = res->BuildType();
  197. res_type->dump();
  198. ASSERT_TRUE(*res_type == Int(64));
  199. }
  200. TEST_F(TestPrim, test_J_1) {
  201. AbstractBasePtrList args_spec_list;
  202. int64_t v1 = 1;
  203. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  204. args_spec_list.push_back(abstract_v1);
  205. auto prim_J = std::make_shared<Primitive>("J");
  206. FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1);
  207. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  208. AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res);
  209. ASSERT_TRUE(res_J != nullptr);
  210. ASSERT_TRUE(*(res_J->element()) == *abstract_v1);
  211. }
  212. TEST_F(TestPrim, test_J_2) {
  213. // def add(x):
  214. // return x + x
  215. // def f(x):
  216. // return J(add)(x)
  217. std::vector<AnfNodePtr> inputs;
  218. FuncGraphPtr func_graph1 = std::make_shared<FuncGraph>();
  219. inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  220. auto x = func_graph1->add_parameter();
  221. inputs.push_back(x);
  222. inputs.push_back(x);
  223. CNodePtr cnode1 = func_graph1->NewCNode(inputs);
  224. func_graph1->set_return(cnode1);
  225. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  226. inputs.clear();
  227. auto x1 = func_graph->add_parameter();
  228. inputs.clear();
  229. inputs.push_back(NewValueNode(prim::kPrimJ));
  230. inputs.push_back(NewValueNode(func_graph1));
  231. CNodePtr jf = func_graph->NewCNode(inputs);
  232. inputs.clear();
  233. inputs.push_back(jf);
  234. inputs.push_back(x1);
  235. CNodePtr jf_jx = func_graph->NewCNode(inputs);
  236. inputs.clear();
  237. inputs.push_back(NewValueNode(prim::kPrimReturn));
  238. inputs.push_back(jf_jx);
  239. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  240. func_graph->set_return(cnode_return);
  241. int64_t v1 = 1;
  242. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  243. AbstractBasePtrList args_spec_list = {abstract_v1};
  244. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  245. res->dump();
  246. AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res);
  247. ASSERT_TRUE(res_J != nullptr);
  248. auto res_J_0 = res_J->elements()[0];
  249. ASSERT_TRUE(res_J_0 != nullptr);
  250. ASSERT_TRUE(*res_J_0 == *(FromValue(static_cast<int64_t>(2), false)));
  251. AbstractFunctionPtr res_J_1 = dyn_cast<AbstractFunction>(res_J->elements()[1]);
  252. ASSERT_TRUE(res_J_1 != nullptr);
  253. }
  254. // tail half
  255. TEST_F(TestPrim, test_switch1) {
  256. PrimitivePtr switch_ = std::make_shared<Primitive>("Switch");
  257. FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3);
  258. AbstractBasePtr arg0 = FromValue(true, false);
  259. AbstractBasePtr arg1 = FromValue(static_cast<int64_t>(1), false);
  260. AbstractBasePtr arg2 = FromValue(static_cast<int64_t>(2), false);
  261. AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
  262. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  263. ASSERT_TRUE(*res == *arg1);
  264. }
  265. TEST_F(TestPrim, test_switch2) {
  266. PrimitivePtr switch_ = std::make_shared<Primitive>("Switch");
  267. FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3);
  268. AbstractBasePtr arg0 = FromValue(false, false);
  269. AbstractBasePtr arg1 = FromValue(static_cast<int64_t>(1), false);
  270. AbstractBasePtr arg2 = FromValue(static_cast<int64_t>(2), false);
  271. AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
  272. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  273. MS_LOG(INFO) << "make result res: " << res->ToString();
  274. MS_LOG(INFO) << "make result arg2: " << arg2->ToString();
  275. ASSERT_TRUE(*res == *arg2);
  276. }
  277. TEST_F(TestPrim, test_identity) {
  278. PrimitivePtr identity = std::make_shared<Primitive>("identity");
  279. FuncGraphPtr func_graph = MakeFuncGraph(identity, 1);
  280. AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false);
  281. AbstractBasePtrList args_spec_list = {abstract_v1};
  282. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  283. ASSERT_TRUE(*res == *abstract_v1);
  284. }
  285. TEST_F(TestPrim, test_broadcast_shape) {
  286. PrimitivePtr broadcast_shape = std::make_shared<Primitive>("broadcast_shape");
  287. FuncGraphPtr func_graph = MakeFuncGraph(broadcast_shape, 2);
  288. auto a = UTPrimUtils::ShapeOf({Shape::SHP_ANY, Shape::SHP_ANY});
  289. auto b = UTPrimUtils::ShapeOf({Shape::SHP_ANY});
  290. std::vector<Any> expected{Shape::SHP_ANY, Shape::SHP_ANY};
  291. AbstractBasePtrList args_spec_list = {a, b};
  292. AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).eval_result->abstract());
  293. auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
  294. std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)};
  295. ASSERT_TRUE(ret.size() == element_list.size());
  296. for (int64_t i = 0; i < element_list.size(); i++) {
  297. ASSERT_TRUE(*ret[i] == *element_list[i]);
  298. }
  299. }
  300. TEST_F(TestPrim, test_partial) {
  301. PrimitivePtr prim = prim::kPrimPartial;
  302. FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
  303. PrimitivePtr add = prim::kPrimScalarAdd;
  304. AbstractBasePtr abstract_add = ToAbstract(add);
  305. AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false);
  306. AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(1), false);
  307. AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2};
  308. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  309. AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2};
  310. auto expected = std::make_shared<PartialAbstractClosure>(
  311. std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list);
  312. MS_LOG(INFO) << "result: " << res->ToString();
  313. MS_LOG(INFO) << "expected: " << expected->ToString();
  314. ASSERT_TRUE(res->ToString() == expected->ToString());
  315. }
  316. /// Feature: Check inference result of primitive by build a FuncGraph with single primitive.
  317. /// Description: Check inference result of primitive EnvironSet.
  318. /// Expectation: Equal
  319. TEST_F(TestPrim, test_environ_set) {
  320. FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
  321. AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
  322. AbstractBasePtrList args_spec_list = {abstract_x};
  323. AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).eval_result->abstract();
  324. FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvironSet, 3);
  325. AbstractBasePtr abstract_environ = MakeEnvironAbstract();
  326. AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
  327. args_spec_list = {abstract_environ, embed_x, abstract_y};
  328. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  329. AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  330. ASSERT_TRUE(*res == *exp);
  331. }
  332. /// Feature: Check inference result of primitive by build a FuncGraph with single primitive.
  333. /// Description: Check inference result of primitive EnvironGet.
  334. /// Expectation: Equal
  335. TEST_F(TestPrim, test_environ_get) {
  336. FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
  337. AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
  338. AbstractBasePtrList args_spec_list = {abstract_x};
  339. AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).eval_result->abstract();
  340. FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3);
  341. AbstractBasePtr abstract_environ = MakeEnvironAbstract();
  342. AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
  343. args_spec_list = {abstract_environ, embed_x, abstract_y};
  344. AbstractBasePtr res = engine_->Run(graph_environ_set, args_spec_list).eval_result->abstract();
  345. AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  346. ASSERT_TRUE(*res == *exp);
  347. FuncGraphPtr graph_environ_get = MakeFuncGraph(prim::kPrimEnvironGet, 3);
  348. AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
  349. args_spec_list = {res, embed_x, abstract_z};
  350. res = engine_->Run(graph_environ_get, args_spec_list).eval_result->abstract();
  351. ASSERT_TRUE(*res == *abstract_x);
  352. }
  353. /// Feature: Check inference result of primitive by build a FuncGraph with single primitive.
  354. /// Description: Check inference result of primitive EnvironAdd.
  355. /// Expectation: Equal
  356. TEST_F(TestPrim, test_environ_add) {
  357. FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
  358. AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
  359. AbstractBasePtrList args_spec_list = {abstract_x};
  360. AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).eval_result->abstract();
  361. FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3);
  362. AbstractBasePtr abstract_environ = MakeEnvironAbstract();
  363. AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
  364. args_spec_list = {abstract_environ, embed_x, abstract_y};
  365. AbstractBasePtr abstract_e1 = engine_->Run(graph_environ_set, args_spec_list).eval_result->abstract();
  366. AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  367. ASSERT_TRUE(*abstract_e1 == *exp);
  368. AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
  369. args_spec_list = {abstract_environ, embed_x, abstract_z};
  370. AbstractBasePtr abstract_e2 = engine_->Run(graph_environ_set, args_spec_list).eval_result->abstract();
  371. ASSERT_TRUE(*abstract_e2 == *exp);
  372. FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvironAdd, 2);
  373. args_spec_list = {abstract_e1, abstract_e2};
  374. AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).eval_result->abstract();
  375. ASSERT_TRUE(*res == *exp);
  376. }
  377. TEST_F(TestPrim, test_relu) {
  378. PrimitivePtr relu = prim::kPrimRelu;
  379. relu->AddAttr("T", MakeValue(static_cast<int64_t>(kNumberTypeFloat64)));
  380. FuncGraphPtr func_graph = MakeFuncGraph(relu, 1);
  381. AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW
  382. AbstractBasePtrList args_spec_list = {expected};
  383. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  384. ASSERT_TRUE(*res == *expected);
  385. }
  386. /*
  387. TEST_F(TestPrim, test_relu2) {
  388. FuncGraphPtr func_graph = getPyFun("get_relu");
  389. ASSERT_TRUE(func_graph != nullptr);
  390. auto arr = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
  391. auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
  392. AbstractBasePtrList args_spec_list = {arr};
  393. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  394. auto res = dyn_cast<AbstractTensor>(ret);
  395. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  396. }
  397. TEST_F(TestPrim, test_conv2d1) {
  398. std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
  399. py::tuple kernel_size(2);
  400. kernel_size[0] = 5;
  401. kernel_size[1] = 5;
  402. std::shared_ptr<FuncGraph> func_graph = getPyFun.CallAndParseRet("test_conv2d", 64, kernel_size, 0, 2, 1);
  403. // NCHW
  404. std::vector<int64_t> inputs_dims = {2, 20, 32, 32};
  405. std::vector<int64_t> weight_dims = {64, 20, 5, 5};
  406. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  407. inputs->set_data_type(kNumberTypeInt32);
  408. inputs->set_shape(inputs_dims);
  409. // Cout, Cin, kernel_size
  410. tensor::TensorPtr weight = std::make_shared<tensor::Tensor>();
  411. weight->set_data_type(kNumberTypeInt32);
  412. weight->set_shape(weight_dims);
  413. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  414. AbstractBasePtr abstract_weight = FromValue(weight, true);
  415. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_weight};
  416. AbstractBasePtr expected = abstract_inputs->Clone();
  417. // NCHW
  418. std::vector<int64_t> shape = {2, 64, 14, 14};
  419. expected->set_shape(std::make_shared<Shape>(shape));
  420. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  421. MS_LOG(INFO) << "result: " << res->ToString();
  422. MS_LOG(INFO) << "expected: " << expected->ToString();
  423. auto res_ptr = dyn_cast<AbstractTensor>(res);
  424. auto expected_ptr = dyn_cast<AbstractTensor>(expected);
  425. ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
  426. ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
  427. }
  428. TEST_F(TestPrim, test_conv2d) {
  429. FuncGraphPtr func_graph = getPyFun("get_conv2d");
  430. ASSERT_TRUE(func_graph != nullptr);
  431. auto input = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
  432. auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3});
  433. AbstractBasePtrList args_spec_list = {input, weight};
  434. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  435. auto res = dyn_cast<AbstractTensor>(ret);
  436. auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16});
  437. MS_LOG(INFO) << "result: " << res->ToString();
  438. MS_LOG(INFO) << "expected: " << expected->ToString();
  439. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  440. }
  441. TEST_F(TestPrim, test_conv2d_native) {
  442. FuncGraphPtr func_graph = getPyFun("get_conv2d_native");
  443. ASSERT_TRUE(func_graph != nullptr);
  444. auto input = ArrayOfTensor(UTPrimUtils::kF64, {10, 32, 32, 32});
  445. auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3});
  446. AbstractBasePtrList args_spec_list = {input, weight};
  447. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  448. auto res = dyn_cast<AbstractTensor>(ret);
  449. auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16});
  450. MS_LOG(INFO) << "result: " << res->ToString();
  451. MS_LOG(INFO) << "expected: " << expected->ToString();
  452. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  453. }
  454. TEST_F(TestPrim, test_biasAdd) {
  455. FuncGraphPtr func_graph = getPyFun("get_bias_add");
  456. ASSERT_TRUE(func_graph != nullptr);
  457. auto value = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
  458. auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32});
  459. AbstractBasePtrList args_spec_list = {value, bias};
  460. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  461. auto res = dyn_cast<AbstractTensor>(ret);
  462. auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
  463. MS_LOG(INFO) << "result: " << res->ToString();
  464. MS_LOG(INFO) << "expected: " << expected->ToString();
  465. ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
  466. }
  467. TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) {
  468. FuncGraphPtr func_graph = getPyFun("get_softmax_cross_entropy_with_logits");
  469. ASSERT_TRUE(func_graph != nullptr);
  470. auto logits = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
  471. auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
  472. AbstractBasePtrList args_spec_list = {logits, labels};
  473. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  474. ASSERT_NE(ret, nullptr);
  475. auto res = dyn_cast<AbstractTuple>(ret);
  476. auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64});
  477. auto dLogits = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
  478. AbstractBasePtrList expected_list = {loss, dLogits};
  479. auto expected = std::make_shared<AbstractTuple>(expected_list);
  480. MS_LOG(INFO) << "result: " << res->ToString();
  481. MS_LOG(INFO) << "expected: " << expected->ToString();
  482. auto res_ptr0 = dyn_cast<AbstractTuple>(res);
  483. auto expected_ptr0 = dyn_cast<AbstractTuple>(expected);
  484. ASSERT_GT((*res_ptr0).size(), 1);
  485. auto res_ptr = dyn_cast<AbstractTensor>((*res_ptr0)[1]);
  486. ASSERT_GT((*expected_ptr0).size(), 1);
  487. auto expected_ptr = dyn_cast<AbstractTensor>((*expected_ptr0)[1]);
  488. ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
  489. ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
  490. }
  491. TEST_F(TestPrim, test_tensor_to_scalar_prim) {
  492. FuncGraphPtr func_graph = getPyFun("get_tensor_to_scalar");
  493. ASSERT_TRUE(func_graph != nullptr);
  494. auto logits = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
  495. auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
  496. AbstractBasePtrList args_spec_list = {logits, labels};
  497. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  498. auto res = dyn_cast<AbstractScalar>(ret);
  499. AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
  500. expected->set_type(UTPrimUtils::kF64);
  501. MS_LOG(INFO) << "result: " << res->ToString();
  502. MS_LOG(INFO) << "expected: " << expected->ToString();
  503. ASSERT_TRUE(*res == *expected);
  504. }
  505. TEST_F(TestPrim, test_pooling) {
  506. PrimitivePtr pooling = prim::kPrimPooling;
  507. pooling->AddAttr("mode", MakeValue(std::string("avg")));
  508. pooling->AddAttr("pad_mode", MakeValue(std::string("valid")));
  509. pooling->AddAttr("nan_opt", MakeValue(0));
  510. pooling->AddAttr("window", MakeValue(2));
  511. pooling->AddAttr("pad", MakeValue(1));
  512. pooling->AddAttr("stride", MakeValue(1));
  513. pooling->AddAttr("data_mode", MakeValue(1));
  514. pooling->AddAttr("ceil_mode", MakeValue(0));
  515. FuncGraphPtr func_graph = MakeFuncGraph(pooling, 1);
  516. std::vector<int64_t> inputs_dims = {8, 64, 3, 3};
  517. auto inputs = std::make_shared<tensor::Tensor>();
  518. inputs->set_data_type(kNumberTypeFloat32);
  519. inputs->set_shape(inputs_dims);
  520. AbstractBasePtr abstract_input = FromValue(inputs, false);
  521. AbstractBasePtrList args_spec_list = {abstract_input};
  522. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  523. AbstractBasePtr expected = abstract_input->Clone()->Broaden();
  524. std::vector<int64_t> expected_dims = {8, 64, 2, 2};
  525. expected->set_shape(std::make_shared<Shape>(expected_dims));
  526. MS_LOG(INFO) << "result: " << res->ToString();
  527. MS_LOG(INFO) << "expected: " << expected->ToString();
  528. ASSERT_TRUE(*res == *expected);
  529. }
  530. TEST_F(TestPrim, test_hastype) {
  531. AbstractBasePtrList args_spec_list;
  532. int64_t v1 = 1;
  533. TypePtr v2 = std::make_shared<Number>();
  534. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  535. AbstractTypePtr abstract_v2 = UTPrimUtils::TypeToAbstract(v2);
  536. AbstractBasePtr expected = FromValue(true, false);
  537. args_spec_list.push_back(abstract_v1);
  538. args_spec_list.push_back(abstract_v2);
  539. auto prim = std::make_shared<Primitive>("hastype");
  540. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  541. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  542. ASSERT_TRUE(*res == *expected);
  543. }
  544. TEST_F(TestPrim, test_array_len) {
  545. AbstractBasePtrList args_spec_list;
  546. auto v1 = UTPrimUtils::ArrayFloat64Of({3, 4, 0, 2});
  547. auto expected = std::make_shared<AbstractScalar>(kAnyValue, kInt32);
  548. args_spec_list.push_back(v1);
  549. auto prim = std::make_shared<Primitive>("array_len");
  550. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  551. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  552. ASSERT_TRUE(*res == *expected);
  553. }
  554. TEST_F(TestPrim, test_list_len) {
  555. AbstractBasePtrList args_spec_list;
  556. auto v1 = UTPrimUtils::ListShapeOf({3, 4, 0, 2});
  557. auto expected = std::make_shared<AbstractScalar>(4);
  558. args_spec_list.push_back(v1);
  559. auto prim = std::make_shared<Primitive>("list_len");
  560. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  561. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  562. ASSERT_TRUE(*res == *expected);
  563. }
  564. TEST_F(TestPrim, test_tuple_len) {
  565. AbstractBasePtrList args_spec_list;
  566. auto v1 = UTPrimUtils::ShapeOf({3, 4, 0, 2});
  567. auto expected = std::make_shared<AbstractScalar>(4);
  568. args_spec_list.push_back(v1);
  569. auto prim = std::make_shared<Primitive>("tuple_len");
  570. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  571. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  572. ASSERT_TRUE(*res == *expected);
  573. }
  574. TEST_F(TestPrim, test_tuple_reversed) {
  575. AbstractBasePtrList args_spec_list;
  576. auto v1 = UTPrimUtils::ShapeOf({0, 1, 2, 3});
  577. auto expected = UTPrimUtils::ShapeOf({3, 2, 1, 0});
  578. args_spec_list.push_back(v1);
  579. auto prim = std::make_shared<Primitive>("tuple_reversed");
  580. FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
  581. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  582. MS_LOG(INFO) << "expect=" << expected->ToString();
  583. ASSERT_TRUE(*res == *expected);
  584. }
  585. TEST_F(TestPrim, test_list_getitem) {
  586. AbstractBasePtrList args_spec_list;
  587. int64_t v1 = 2;
  588. int64_t v2 = 1;
  589. AbstractBasePtr elem = FromValue(v1, false);
  590. AbstractBasePtr elem2 = FromValue(v2, false);
  591. AbstractBasePtrList elems = {elem, elem};
  592. auto abstract_v1 = std::make_shared<AbstractList>(elems);
  593. AbstractBasePtr abstract_v2 = FromValue(v2, false);
  594. args_spec_list.push_back(abstract_v1);
  595. args_spec_list.push_back(abstract_v2);
  596. auto prim = std::make_shared<Primitive>("list_getitem");
  597. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  598. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  599. ASSERT_TRUE(*res == *elem);
  600. }
  601. TEST_F(TestPrim, test_list_setitem) {
  602. int64_t v1 = 1;
  603. int64_t v2 = 2;
  604. AbstractBasePtr elem1 = FromValue(v1, false);
  605. AbstractBasePtr elem2 = FromValue(v2, false);
  606. AbstractBasePtrList elems = {elem1, elem1};
  607. auto abstract_tuple = std::make_shared<AbstractList>(elems);
  608. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  609. AbstractBasePtr abstract_v3 = FromValue(v2, false);
  610. AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2, abstract_v3};
  611. auto prim = std::make_shared<Primitive>("list_setitem");
  612. FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
  613. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  614. MS_LOG(INFO) << "result: " << res->ToString();
  615. AbstractBasePtrList elems_exp = {elem1, elem2};
  616. auto expected = std::make_shared<AbstractList>(elems_exp);
  617. MS_LOG(INFO) << "expected: " << expected->ToString();
  618. auto res_list = dyn_cast<AbstractList>(res);
  619. ASSERT_TRUE(*expected == *res_list);
  620. }
  621. TEST_F(TestPrim, test_list_append) {
  622. int64_t v1 = 1;
  623. AbstractBasePtr elem1 = FromValue(v1, false);
  624. AbstractBasePtr elem2 = FromValue(v1, false);
  625. auto abstract_tuple = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
  626. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  627. AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2};
  628. auto prim = std::make_shared<Primitive>("list_append");
  629. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  630. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  631. MS_LOG(INFO) << "result: " << res->ToString();
  632. auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
  633. MS_LOG(INFO) << "expected: " << expected->ToString();
  634. auto res_list = dyn_cast<AbstractList>(res);
  635. ASSERT_TRUE(*res_list == *expected);
  636. }
  637. TEST_F(TestPrim, test_tuple_setitem) {
  638. int64_t v1 = 1;
  639. int64_t v2 = 2;
  640. AbstractBasePtr elem1 = FromValue(v1, false);
  641. AbstractBasePtr elem2 = FromValue(v2, false);
  642. AbstractBasePtrList elems = {elem1, elem1};
  643. auto abstract_tuple = std::make_shared<AbstractTuple>(elems);
  644. AbstractBasePtr abstract_v2 = FromValue(v1, false);
  645. AbstractBasePtr abstract_v3 = FromValue(v2, false);
  646. AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2, abstract_v3};
  647. auto prim = std::make_shared<Primitive>("tuple_setitem");
  648. FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
  649. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  650. MS_LOG(INFO) << "result: " << res->ToString();
  651. AbstractBasePtrList elems_exp = {elem1, elem2};
  652. auto expected = std::make_shared<AbstractTuple>(elems_exp);
  653. MS_LOG(INFO) << "expected: " << expected->ToString();
  654. auto res_tuple = dyn_cast<AbstractTuple>(res);
  655. ASSERT_TRUE(*res == *expected);
  656. }
  657. TEST_F(TestPrim, test_make_list) {
  658. AbstractBasePtrList args_spec_list;
  659. int64_t v1 = 2;
  660. int64_t v2 = 2;
  661. AbstractBasePtr abstract_v1 = FromValue(v1, false);
  662. AbstractBasePtr abstract_v2 = FromValue(v2, false);
  663. auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_v2}));
  664. args_spec_list.push_back(abstract_v1);
  665. args_spec_list.push_back(abstract_v2);
  666. auto prim = std::make_shared<Primitive>("make_list");
  667. FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
  668. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  669. ASSERT_TRUE(*res == *expected);
  670. }
  671. TEST_F(TestPrim, test_make_range) {
  672. AbstractBasePtrList args_spec_list;
  673. int64_t v1 = 1;
  674. int64_t v2 = 4;
  675. AbstractBasePtr abstract_v1 = FromValue(v1);
  676. AbstractBasePtr abstract_v2 = FromValue(v2);
  677. args_spec_list.push_back(abstract_v1);
  678. args_spec_list.push_back(abstract_v2);
  679. auto prim = std::make_shared<Primitive>("make_range");
  680. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(prim, 2);
  681. AbstractBasePtr ele1 = FromValue(1);
  682. AbstractBasePtr ele2 = FromValue(2);
  683. AbstractBasePtr ele3 = FromValue(3);
  684. AbstractBasePtrList elem_list({ele1, ele2, ele3});
  685. AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list);
  686. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  687. MS_LOG(INFO) << "res=" << res->ToString();
  688. MS_LOG(INFO) << "expected=" << expected->ToString();
  689. ASSERT_TRUE(*res == *expected);
  690. }
  691. TEST_F(TestPrim, test_layernorm) {
  692. PrimitivePtr layerNorm = prim::kPrimLayerNorm;
  693. layerNorm->AddAttr("begin_norm_axis", MakeValue(1));
  694. layerNorm->AddAttr("begin_params_axis", MakeValue(1));
  695. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(layerNorm, 3);
  696. std::vector<int64_t> inputs_dims = {128, 64, 32, 64};
  697. std::vector<int64_t> mean_var_dims = {128, 64, 32, 1};
  698. std::vector<int64_t> params_dims = {64, 32, 64};
  699. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  700. inputs->set_data_type(kNumberTypeFloat32);
  701. inputs->set_shape(inputs_dims);
  702. tensor::TensorPtr mean_var = std::make_shared<tensor::Tensor>();
  703. mean_var->set_data_type(kNumberTypeFloat32);
  704. mean_var->set_shape(mean_var_dims);
  705. tensor::TensorPtr gamma = std::make_shared<tensor::Tensor>();
  706. gamma->set_data_type(kNumberTypeFloat32);
  707. gamma->set_shape(params_dims);
  708. tensor::TensorPtr beta = std::make_shared<tensor::Tensor>();
  709. beta->set_data_type(kNumberTypeFloat32);
  710. beta->set_shape(params_dims);
  711. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  712. AbstractBasePtr abstract_mean_var = FromValue(mean_var, true);
  713. AbstractBasePtr abstract_gamma = FromValue(gamma, true);
  714. AbstractBasePtr abstract_beta = FromValue(beta, true);
  715. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_gamma, abstract_beta};
  716. AbstractBasePtr expected0 = abstract_inputs->Clone();
  717. AbstractBasePtr expected1 = abstract_mean_var->Clone();
  718. AbstractBasePtr expected2 = abstract_mean_var->Clone();
  719. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  720. MS_LOG(INFO) << "result: " << res->ToString();
  721. MS_LOG(INFO) << "expected0: " << expected0->ToString();
  722. MS_LOG(INFO) << "expected1: " << expected1->ToString();
  723. MS_LOG(INFO) << "expected2: " << expected2->ToString();
  724. std::shared_ptr<AbstractTuple> abs_tuple = dyn_cast<AbstractTuple>(res);
  725. ASSERT_TRUE(abs_tuple != nullptr);
  726. auto res_ptr0 = dyn_cast<AbstractTensor>(abs_tuple->elements()[0]);
  727. auto expected_ptr0 = dyn_cast<AbstractTensor>(expected0);
  728. ASSERT_TRUE(*res_ptr0->shape() == *expected_ptr0->shape());
  729. ASSERT_TRUE(*res_ptr0->element() == *expected_ptr0->element());
  730. auto res_ptr1 = dyn_cast<AbstractTensor>(abs_tuple->elements()[1]);
  731. auto expected_ptr1 = dyn_cast<AbstractTensor>(expected1);
  732. ASSERT_TRUE(*res_ptr1->shape() == *expected_ptr1->shape());
  733. ASSERT_TRUE(*res_ptr1->element() == *expected_ptr1->element());
  734. auto res_ptr2 = dyn_cast<AbstractTensor>(abs_tuple->elements()[2]);
  735. auto expected_ptr2 = dyn_cast<AbstractTensor>(expected2);
  736. ASSERT_TRUE(*res_ptr2->shape() == *expected_ptr2->shape());
  737. ASSERT_TRUE(*res_ptr2->element() == *expected_ptr2->element());
  738. }
  739. TEST_F(TestPrim, test_DropoutGenMask) {
  740. AbstractBasePtrList args_spec_list;
  741. auto arg0 = UTPrimUtils::ShapeOf({5, 5, 5, 5});
  742. std::vector<int64_t> keep_prob_shape = {};
  743. tensor::TensorPtr keep_prob = std::make_shared<tensor::Tensor>(0.5f);
  744. keep_prob->set_data_type(kNumberTypeFloat32);
  745. keep_prob->set_shape(keep_prob_shape);
  746. AbstractBasePtr abstract_keep_prob = FromValue(keep_prob);
  747. auto prim = std::make_shared<Primitive>("DropoutGenMask");
  748. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(prim, 2);
  749. args_spec_list.push_back(arg0);
  750. args_spec_list.push_back(abstract_keep_prob);
  751. // should return a tensor with on dimension of 79 elements
  752. AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
  753. std::make_shared<Shape>(std::vector<int64_t>{79}));
  754. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  755. MS_LOG(INFO) << "res=" << res->ToString();
  756. MS_LOG(INFO) << "expected=" << expected->ToString();
  757. ASSERT_TRUE(*res == *expected);
  758. }
  759. TEST_F(TestPrim, test_dropout) {
  760. std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
  761. std::shared_ptr<FuncGraph> func_graph = getPyFun.CallAndParseRet("test_dropout");
  762. std::vector<int64_t> inputs_dims = {2, 20, 32, 32};
  763. tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
  764. inputs->set_data_type(kNumberTypeFloat32);
  765. inputs->set_shape(inputs_dims);
  766. AbstractBasePtr abstract_inputs = FromValue(inputs, true);
  767. std::vector<int64_t> keep_prob_shape = {};
  768. tensor::TensorPtr keep_prob = std::make_shared<tensor::Tensor>(0.5f);
  769. keep_prob->set_data_type(kNumberTypeFloat32);
  770. keep_prob->set_shape(keep_prob_shape);
  771. AbstractBasePtr abstract_keep_prob = FromValue(keep_prob);
  772. AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_keep_prob};
  773. AbstractBasePtr expected = abstract_inputs->Clone();
  774. // NCHW
  775. std::vector<int64_t> shape = {2, 20, 32, 32};
  776. expected->set_shape(std::make_shared<Shape>(shape));
  777. AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  778. MS_LOG(INFO) << "result: " << res->ToString();
  779. MS_LOG(INFO) << "expected: " << expected->ToString();
  780. auto res_ptr = dyn_cast<AbstractTensor>(res);
  781. auto expected_ptr = dyn_cast<AbstractTensor>(expected);
  782. ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
  783. ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
  784. }
  785. TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) {
  786. PrimitivePtr broadcatGradientArgs = prim::kPrimBroadcastGradientArgs;
  787. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(broadcatGradientArgs, 2);
  788. // broadcast shape: x: 8,5,3, y:3
  789. // output: ((),(0, 1))
  790. AbstractBasePtrList x_arg_list({abstract::FromValue(8), abstract::FromValue(5), abstract::FromValue(3)});
  791. AbstractBasePtrList y_arg_list({abstract::FromValue(3)});
  792. auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
  793. auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
  794. AbstractBasePtrList args_spec_list = {x_input, y_input};
  795. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  796. auto res = dyn_cast<AbstractTuple>(ret);
  797. AbstractBasePtrList x_idx_list;
  798. auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
  799. AbstractBasePtrList y_idx_list({abstract::FromValue(0), abstract::FromValue(1)});
  800. auto r_y = std::make_shared<AbstractTuple>(y_idx_list);
  801. AbstractBasePtrList elem_list({r_x, r_y});
  802. auto expected = std::make_shared<AbstractTuple>(elem_list);
  803. MS_LOG(INFO) << "result: " << res->ToString();
  804. MS_LOG(INFO) << "expected: " << expected->ToString();
  805. ASSERT_TRUE(*res == *expected);
  806. }
  807. TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) {
  808. PrimitivePtr broadcatGradientArgs = prim::kPrimBroadcastGradientArgs;
  809. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(broadcatGradientArgs, 2);
  810. // broadcast shape: x: 8,1,3, y:8 5 3
  811. // output: ((1),())
  812. AbstractBasePtrList x_arg_list({abstract::FromValue(8), abstract::FromValue(1), abstract::FromValue(3)});
  813. AbstractBasePtrList y_arg_list({abstract::FromValue(8), abstract::FromValue(5), abstract::FromValue(3)});
  814. auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
  815. auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
  816. AbstractBasePtrList args_spec_list = {x_input, y_input};
  817. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  818. auto res = dyn_cast<AbstractTuple>(ret);
  819. AbstractBasePtrList x_idx_list({abstract::FromValue(1)});
  820. auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
  821. AbstractBasePtrList y_idx_list;
  822. auto r_y = std::make_shared<AbstractTuple>(y_idx_list);
  823. AbstractBasePtrList elem_list({r_x, r_y});
  824. auto expected = std::make_shared<AbstractTuple>(elem_list);
  825. MS_LOG(INFO) << "result: " << res->ToString();
  826. MS_LOG(INFO) << "expected: " << expected->ToString();
  827. ASSERT_TRUE(*res == *expected);
  828. }
  829. TEST_F(TestPrim, test_DictGetItem) {
  830. PrimitivePtr dictGetItem = prim::kPrimDictGetItem;
  831. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(dictGetItem, 2);
  832. std::vector<std::pair<std::string, ValuePtr>> tensor_map = {
  833. {"x", std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int64_t>{2, 3, 4})},
  834. {"y", std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int64_t>{2, 1, 4})}};
  835. ValueDictionary value_dict(tensor_map);
  836. AbstractBasePtr array_dict = value_dict.ToAbstract();
  837. AbstractBasePtr key = abstract::FromValue("x");
  838. AbstractBasePtrList args_spec_list = {array_dict, key};
  839. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  840. AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
  841. AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second));
  842. ASSERT_TRUE(*tensor_ret == *expect);
  843. }
  844. TEST_F(TestPrim, test_DictGetItem2) {
  845. PrimitivePtr dictGetItem = prim::kPrimDictGetItem;
  846. std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(dictGetItem, 2);
  847. AbstractBasePtr arr_x = ArrayOfTensor(UTPrimUtils::kF64, {3, 4, 5});
  848. AbstractBasePtr arr_y = ArrayOfTensor(UTPrimUtils::kF64, {1, 4, 5});
  849. AbstractBasePtr arr_z = ArrayOfTensor(UTPrimUtils::kF64, {3, 1, 5});
  850. std::vector<AbstractAttribute> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
  851. AbstractDictionaryPtr array_dict = std::make_shared<AbstractDictionary>(array_map);
  852. AbstractBasePtr key = abstract::FromValue("x");
  853. AbstractBasePtrList args_spec_list = {array_dict, key};
  854. AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
  855. AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
  856. AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x);
  857. ASSERT_TRUE(*tensor_ret == *expect);
  858. }
  859. */
  860. } // namespace abstract
  861. } // namespace mindspore