You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

convert_test.cc 31 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <unordered_map>
  18. #include "pybind11/pybind11.h"
  19. #include "transform/transform_base_test.h"
  20. #include "common/py_func_graph_fetcher.h"
  21. #include "pipeline/jit/parse/parse.h"
  22. #include "debug/draw.h"
  23. #include "debug/anf_ir_dump.h"
  24. #include "pipeline/jit/static_analysis/prim.h"
  25. #include "frontend/operator/ops.h"
  26. #include "common/common_test.h"
  27. #define private public
  28. #include "transform/graph_ir/types.h"
  29. #include "transform/graph_ir/convert.h"
  30. #include "securec/include/securec.h"
  31. #include "utils/utils.h"
  32. using std::cout;
  33. using std::endl;
  34. using std::string;
  35. using std::unordered_map;
  36. namespace mindspore {
  37. namespace transform {
  38. using AbstractScalar = abstract::AbstractScalar;
  39. using mindspore::parse::ResolveAll;
  40. class TestConvert : public UT::Common {
  41. public:
  42. TestConvert() {}
  43. virtual void SetUp();
  44. virtual void TearDown();
  45. static const std::shared_ptr<Float> kF32;
  46. };
  47. void TestConvert::SetUp() { UT::InitPythonPath(); }
  48. void TestConvert::TearDown() {}
  49. const std::shared_ptr<Float> TestConvert::kF32 = std::make_shared<Float>(32);
  50. AnfGraphPtr createAnfGraph() { return std::make_shared<AnfGraph>(); }
  51. TEST_F(TestConvert, TestConstruct) {
  52. AnfGraphPtr func_graph = std::make_shared<AnfGraph>();
  53. DfGraphConvertor converter(func_graph);
  54. converter.ConvertAllNode().GetComputeGraph();
  55. ASSERT_NE(converter.ErrCode(), SUCCESS);
  56. }
  57. namespace {
  58. bool MakeDfGraph(PrimitivePtr prim, unsigned int nparam) {
  59. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, nparam);
  60. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  61. DfGraphConvertor converter(anf_graph);
  62. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  63. if (converter.ErrCode() != 0) {
  64. MS_LOG(ERROR) << "DfGraphConvertor convert " << prim->name() << " error, error code is: " << converter.ErrCode();
  65. return false;
  66. }
  67. if (df_graph == nullptr) {
  68. MS_LOG(ERROR) << "DfGraphConvertor get " << prim->name() << " compute func_graph failed";
  69. return false;
  70. }
  71. return true;
  72. }
  73. } // namespace
  74. TEST_F(TestConvert, TestConvertConv2d) {
  75. PrimitivePtr conv2d = prim::kPrimConv2D;
  76. conv2d->AddAttr("stride", MakeValue(static_cast<int64_t>(2)));
  77. conv2d->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  78. conv2d->AddAttr("dilation", MakeValue(static_cast<int64_t>(0)));
  79. FuncGraphPtr anf_graph = MakeFuncGraph(conv2d, 2);
  80. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  81. DfGraphConvertor converter(anf_graph);
  82. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  83. ASSERT_EQ(converter.ErrCode(), 0);
  84. ASSERT_NE(df_graph, nullptr);
  85. }
  86. TEST_F(TestConvert, TestConvertMaxpooling) {
  87. auto prim = std::make_shared<Primitive>("MaxPool");
  88. FuncGraphPtr anf_graph = MakeFuncGraph(prim, 5); // ary, ksize, stride, padding, data_format
  89. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  90. DfGraphConvertor converter(anf_graph);
  91. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  92. ASSERT_EQ(converter.ErrCode(), 0);
  93. ASSERT_NE(df_graph, nullptr);
  94. }
  95. TEST_F(TestConvert, TestReluOps) {
  96. auto prim = prim::kPrimRelu;
  97. prim->AddAttr("T", MakeValue(static_cast<int64_t>(0)));
  98. auto func_graph = MakeFuncGraph(prim, 1);
  99. ASSERT_TRUE(nullptr != func_graph);
  100. // save the func_graph to manager
  101. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  102. // call resolve
  103. bool ret_ = ResolveAll(manager);
  104. ASSERT_TRUE(ret_);
  105. // draw graph
  106. auto anfGraph = *(manager->func_graphs().begin());
  107. DfGraphConvertor converter(anfGraph);
  108. converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  109. ASSERT_EQ(converter.ErrCode(), 0);
  110. }
  111. TEST_F(TestConvert, TestConvertBatchNorm) {
  112. PrimitivePtr batch_norm = prim::kPrimBatchNorm;
  113. batch_norm->AddAttr("epsilon", MakeValue(0.001f));
  114. batch_norm->AddAttr("momentum", MakeValue(0.1f));
  115. FuncGraphPtr anf_graph = std::make_shared<FuncGraph>();
  116. std::vector<AnfNodePtr> inputs;
  117. inputs.push_back(NewValueNode(batch_norm));
  118. for (unsigned int i = 0; i < 5; i++) {
  119. inputs.push_back(anf_graph->add_parameter());
  120. }
  121. CNodePtr cnode_prim = anf_graph->NewCNode(inputs);
  122. inputs.clear();
  123. inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
  124. inputs.push_back(cnode_prim);
  125. inputs.push_back(NewValueNode(static_cast<int64_t>(2)));
  126. CNodePtr cnode_getitem = anf_graph->NewCNode(inputs);
  127. inputs.clear();
  128. inputs.push_back(NewValueNode(prim::kPrimRelu));
  129. inputs.push_back(cnode_getitem);
  130. CNodePtr cnode_relu = anf_graph->NewCNode(inputs);
  131. inputs.clear();
  132. inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
  133. inputs.push_back(cnode_relu);
  134. CNodePtr cnode_return = anf_graph->NewCNode(inputs);
  135. anf_graph->set_return(cnode_return);
  136. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  137. DfGraphConvertor converter(anf_graph);
  138. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  139. ASSERT_EQ(converter.ErrCode(), 0);
  140. ASSERT_NE(df_graph, nullptr);
  141. }
  142. TEST_F(TestConvert, TestConvertConvBackpropInput) {
  143. auto prim = prim::kPrimConv2DBackpropInput;
  144. const std::vector<int64_t> list{1,1};
  145. prim->AddAttr("stride", MakeValue(list));
  146. prim->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  147. prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  148. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  149. prim->AddAttr("group", MakeValue(static_cast<int64_t>(1)));
  150. prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
  151. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  152. auto func_graph = MakeFuncGraph(prim, 3);
  153. ASSERT_NE(func_graph, nullptr);
  154. // save the func_graph to manager
  155. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  156. // call resolve
  157. bool ret_ = ResolveAll(manager);
  158. ASSERT_TRUE(ret_);
  159. // draw graph
  160. auto anf_graph = *(manager->func_graphs().begin());
  161. DfGraphConvertor converter(anf_graph);
  162. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  163. ASSERT_EQ(converter.ErrCode(), 0);
  164. ASSERT_NE(df_graph, nullptr);
  165. }
  166. TEST_F(TestConvert, TestConvertConvBackpropFilter) {
  167. auto prim = prim::kPrimConv2DBackpropFilter;
  168. const std::vector<int64_t> list{1,1};
  169. prim->AddAttr("stride", MakeValue(list));
  170. prim->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  171. prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  172. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  173. prim->AddAttr("group", MakeValue(static_cast<int64_t>(1)));
  174. prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
  175. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  176. auto func_graph = MakeFuncGraph(prim, 3);
  177. ASSERT_NE(func_graph, nullptr);
  178. // save the func_graph to manager
  179. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  180. // call resolve
  181. bool ret_ = ResolveAll(manager);
  182. ASSERT_TRUE(ret_);
  183. // draw graph
  184. auto anf_graph = *(manager->func_graphs().begin());
  185. DfGraphConvertor converter(anf_graph);
  186. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  187. ASSERT_EQ(converter.ErrCode(), 0);
  188. ASSERT_NE(df_graph, nullptr);
  189. }
  190. TEST_F(TestConvert, TestConvertReluGrad) {
  191. auto prim = prim::kPrimReluGrad;
  192. prim->AddAttr("alpha", MakeValue(0.1f));
  193. prim->AddAttr("beta", MakeValue(0.1f));
  194. prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
  195. auto func_graph = MakeFuncGraph(prim, 2);
  196. ASSERT_NE(func_graph, nullptr);
  197. // save the func_graph to manager
  198. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  199. // call resolve
  200. bool ret_ = ResolveAll(manager);
  201. ASSERT_TRUE(ret_);
  202. // draw graph
  203. auto anf_graph = *(manager->func_graphs().begin());
  204. DfGraphConvertor converter(anf_graph);
  205. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  206. ASSERT_EQ(converter.ErrCode(), 0);
  207. ASSERT_NE(df_graph, nullptr);
  208. }
  209. TEST_F(TestConvert, TestConvertBiasAdd) {
  210. auto prim = std::make_shared<Primitive>("BiasAdd");
  211. prim->AddAttr("alpha", MakeValue(0.0f));
  212. prim->AddAttr("beta", MakeValue(1.0f));
  213. auto func_graph = MakeFuncGraph(prim, 2);
  214. ASSERT_NE(func_graph, nullptr);
  215. // save the func_graph to manager
  216. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  217. // call resolve
  218. bool ret_ = ResolveAll(manager);
  219. ASSERT_TRUE(ret_);
  220. // draw graph
  221. auto anf_graph = *(manager->func_graphs().begin());
  222. DfGraphConvertor converter(anf_graph);
  223. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  224. ASSERT_EQ(converter.ErrCode(), 0);
  225. ASSERT_NE(df_graph, nullptr);
  226. }
  227. TEST_F(TestConvert, TestConvertBiasAddGrad) {
  228. auto prim = prim::kPrimBiasAddGrad;
  229. prim->AddAttr("alpha", MakeValue(0.0f));
  230. prim->AddAttr("beta", MakeValue(1.0f));
  231. auto func_graph = MakeFuncGraph(prim, 2);
  232. ASSERT_NE(func_graph, nullptr);
  233. // save the func_graph to manager
  234. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  235. // call resolve
  236. bool ret_ = ResolveAll(manager);
  237. ASSERT_TRUE(ret_);
  238. // draw graph
  239. auto anf_graph = *(manager->func_graphs().begin());
  240. DfGraphConvertor converter(anf_graph);
  241. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  242. ASSERT_EQ(converter.ErrCode(), 0);
  243. ASSERT_NE(df_graph, nullptr);
  244. }
  245. TEST_F(TestConvert, TestConvertMaxPoolGradWithArgmax) {
  246. auto prim = std::make_shared<Primitive>("MaxPoolGradWithArgmax");
  247. prim->AddAttr("alpha", MakeValue(0.0f));
  248. prim->AddAttr("beta", MakeValue(1.0f));
  249. prim->AddAttr("window", MakeValue(static_cast<int64_t>(2)));
  250. prim->AddAttr("stride", MakeValue(static_cast<int64_t>(1)));
  251. prim->AddAttr("ceil_mode", MakeValue(static_cast<int64_t>(0)));
  252. prim->AddAttr("data_mode", MakeValue(static_cast<int64_t>(0)));
  253. prim->AddAttr("alpha", MakeValue(0.1f));
  254. prim->AddAttr("beta", MakeValue(1.0f));
  255. auto func_graph = MakeFuncGraph(prim, 2);
  256. ASSERT_NE(func_graph, nullptr);
  257. // save the func_graph to manager
  258. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  259. // call resolve
  260. bool ret_ = ResolveAll(manager);
  261. ASSERT_TRUE(ret_);
  262. // draw graph
  263. auto anf_graph = *(manager->func_graphs().begin());
  264. DfGraphConvertor converter(anf_graph);
  265. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  266. ASSERT_EQ(converter.ErrCode(), 0);
  267. ASSERT_NE(df_graph, nullptr);
  268. }
  269. TEST_F(TestConvert, TestConcat) {
  270. auto prim = prim::kPrimConcat;
  271. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  272. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  273. DfGraphConvertor converter(anf_graph);
  274. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  275. ASSERT_EQ(converter.ErrCode(), 0);
  276. ASSERT_NE(df_graph, nullptr);
  277. }
  278. TEST_F(TestConvert, TestGatherV2) {
  279. auto prim = prim::kPrimGather;
  280. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 3);
  281. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  282. DfGraphConvertor converter(anf_graph);
  283. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  284. ASSERT_EQ(converter.ErrCode(), 0);
  285. ASSERT_NE(df_graph, nullptr);
  286. }
  287. TEST_F(TestConvert, TestCast) {
  288. auto prim = prim::kPrimCast;
  289. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  290. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  291. DfGraphConvertor converter(anf_graph);
  292. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  293. ASSERT_EQ(converter.ErrCode(), 0);
  294. ASSERT_NE(df_graph, nullptr);
  295. }
  296. TEST_F(TestConvert, TestExp) {
  297. auto prim = std::make_shared<Primitive>("Exp");
  298. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  299. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  300. DfGraphConvertor converter(anf_graph);
  301. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  302. ASSERT_EQ(converter.ErrCode(), 0);
  303. ASSERT_NE(df_graph, nullptr);
  304. }
  305. TEST_F(TestConvert, TestFloor) {
  306. auto prim = std::make_shared<Primitive>("Floor");
  307. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  308. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  309. DfGraphConvertor converter(anf_graph);
  310. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  311. ASSERT_EQ(converter.ErrCode(), 0);
  312. ASSERT_NE(df_graph, nullptr);
  313. }
  314. TEST_F(TestConvert, TestGreaterEqual) {
  315. auto prim = std::make_shared<Primitive>("GreaterEqual");
  316. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  317. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  318. DfGraphConvertor converter(anf_graph);
  319. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  320. ASSERT_EQ(converter.ErrCode(), 0);
  321. ASSERT_NE(df_graph, nullptr);
  322. }
  323. TEST_F(TestConvert, TestLess) {
  324. auto prim = std::make_shared<Primitive>("Less");
  325. prim->AddAttr("T", MakeValue(kFloat32));
  326. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  327. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  328. DfGraphConvertor converter(anf_graph);
  329. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  330. ASSERT_EQ(converter.ErrCode(), 0);
  331. ASSERT_NE(df_graph, nullptr);
  332. }
  333. TEST_F(TestConvert, TestLessEqual) {
  334. auto prim = std::make_shared<Primitive>("LessEqual");
  335. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  336. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  337. DfGraphConvertor converter(anf_graph);
  338. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  339. ASSERT_EQ(converter.ErrCode(), 0);
  340. ASSERT_NE(df_graph, nullptr);
  341. }
  342. TEST_F(TestConvert, TestLogicalNot) {
  343. auto prim = std::make_shared<Primitive>("LogicalNot");
  344. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  345. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  346. DfGraphConvertor converter(anf_graph);
  347. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  348. ASSERT_EQ(converter.ErrCode(), 0);
  349. ASSERT_NE(df_graph, nullptr);
  350. }
  351. TEST_F(TestConvert, TestAssignAdd) {
  352. auto prim = prim::kPrimAssignAdd;
  353. prim->AddAttr("use_locking", MakeValue(true));
  354. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  355. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  356. DfGraphConvertor converter(anf_graph);
  357. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  358. ASSERT_EQ(converter.ErrCode(), 0);
  359. ASSERT_NE(df_graph, nullptr);
  360. }
  361. TEST_F(TestConvert, LogSoftmax) {
  362. auto prim = prim::kPrimLogSoftmax;
  363. prim->AddAttr("axis", MakeValue(static_cast<int64_t>(0)));
  364. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  365. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  366. DfGraphConvertor converter(anf_graph);
  367. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  368. ASSERT_EQ(converter.ErrCode(), 0);
  369. ASSERT_NE(df_graph, nullptr);
  370. }
  371. TEST_F(TestConvert, TestMaximumOps) {
  372. auto prim = prim::kPrimMaximum;
  373. bool ret = MakeDfGraph(prim, 2);
  374. ASSERT_TRUE(ret);
  375. }
  376. TEST_F(TestConvert, TestReduceMeanOps) {
  377. auto prim = prim::kPrimReduceMean;
  378. prim->AddAttr("keepdims", MakeValue(true));
  379. bool ret = MakeDfGraph(prim, 2);
  380. ASSERT_TRUE(ret);
  381. }
  382. TEST_F(TestConvert, TestMinimumOps) {
  383. auto prim = prim::kPrimMinimum;
  384. bool ret = MakeDfGraph(prim, 2);
  385. ASSERT_TRUE(ret);
  386. }
  387. TEST_F(TestConvert, TestFusedMinOrMaxGradOps) {
  388. // Add infer step to this test case
  389. ASSERT_TRUE(true);
  390. }
  391. TEST_F(TestConvert, TestSqueezeOps) {
  392. auto prim = prim::kPrimSqueeze;
  393. bool ret = MakeDfGraph(prim, 2);
  394. ASSERT_TRUE(ret);
  395. }
  396. TEST_F(TestConvert, TestMulOps) {
  397. auto prim = prim::kPrimMul;
  398. bool ret = MakeDfGraph(prim, 2);
  399. ASSERT_TRUE(ret);
  400. }
  401. TEST_F(TestConvert, TestNegOps) {
  402. auto prim = prim::kPrimNeg;
  403. bool ret = MakeDfGraph(prim, 1);
  404. ASSERT_TRUE(ret);
  405. }
  406. TEST_F(TestConvert, TestOneHotOps) {
  407. auto prim = prim::kPrimOneHot;
  408. prim->AddAttr("axis", MakeValue(static_cast<int64_t>(0)));
  409. bool ret = MakeDfGraph(prim, 4);
  410. ASSERT_TRUE(ret);
  411. }
  412. TEST_F(TestConvert, TestPowOps) {
  413. auto prim = std::make_shared<Primitive>("Pow");
  414. bool ret = MakeDfGraph(prim, 2);
  415. ASSERT_TRUE(ret);
  416. }
  417. TEST_F(TestConvert, TestReciprocalOps) {
  418. auto prim = std::make_shared<Primitive>("Reciprocal");
  419. bool ret = MakeDfGraph(prim, 1);
  420. ASSERT_TRUE(ret);
  421. }
  422. TEST_F(TestConvert, TestSelectOps) {
  423. auto prim = prim::kPrimSelect;
  424. bool ret = MakeDfGraph(prim, 3);
  425. ASSERT_TRUE(ret);
  426. }
  427. TEST_F(TestConvert, TestSqrtOps) {
  428. auto prim = std::make_shared<Primitive>("Sqrt");
  429. bool ret = MakeDfGraph(prim, 1);
  430. ASSERT_TRUE(ret);
  431. }
  432. TEST_F(TestConvert, TestSquareOps) {
  433. auto prim = std::make_shared<Primitive>("Square");
  434. bool ret = MakeDfGraph(prim, 1);
  435. ASSERT_TRUE(ret);
  436. }
  437. #ifndef ENABLE_SECURITY
  438. TEST_F(TestConvert, TestScalarSummaryOps) {
  439. auto prim = prim::kPrimScalarSummary;
  440. // should have only 1 input.
  441. bool ret = MakeDfGraph(prim, 2);
  442. ASSERT_TRUE(ret);
  443. }
  444. TEST_F(TestConvert, TestTensorSummaryOps) {
  445. auto prim = prim::kPrimTensorSummary;
  446. bool ret = MakeDfGraph(prim, 2);
  447. ASSERT_TRUE(ret);
  448. }
  449. TEST_F(TestConvert, TestHistogramSummaryOps) {
  450. auto prim = prim::kPrimHistogramSummary;
  451. bool ret = MakeDfGraph(prim, 2);
  452. ASSERT_TRUE(ret);
  453. }
  454. #endif
  455. TEST_F(TestConvert, TestGreaterOps) {
  456. auto prim = std::make_shared<Primitive>("Greater");
  457. bool ret = MakeDfGraph(prim, 2);
  458. ASSERT_TRUE(ret);
  459. }
  460. TEST_F(TestConvert, TestEqualOps) {
  461. auto prim = std::make_shared<Primitive>("Equal");
  462. bool ret = MakeDfGraph(prim, 2);
  463. ASSERT_TRUE(ret);
  464. }
  465. TEST_F(TestConvert, TestArgMaxiOps) {
  466. auto prim = std::make_shared<Primitive>("Argmax");
  467. bool ret = MakeDfGraph(prim, 2);
  468. ASSERT_TRUE(ret);
  469. }
  470. TEST_F(TestConvert, TestResizeNearestNeighborOps) {
  471. auto prim = std::make_shared<Primitive>("ResizeNearestNeighbor");
  472. bool ret = MakeDfGraph(prim, 1);
  473. ASSERT_TRUE(ret);
  474. }
  475. TEST_F(TestConvert, TestApplyMomentumOps) {
  476. auto prim = std::make_shared<Primitive>("ApplyMomentum");
  477. bool ret = MakeDfGraph(prim, 5);
  478. ASSERT_TRUE(ret);
  479. }
  480. TEST_F(TestConvert, TestNPUGetFloatStatusOps) {
  481. auto prim = std::make_shared<Primitive>("NPUGetFloatStatus");
  482. bool ret = MakeDfGraph(prim, 1);
  483. ASSERT_TRUE(ret);
  484. }
  485. TEST_F(TestConvert, TestNPUAllocFloatStatusOps) {
  486. auto prim = std::make_shared<Primitive>("NPUAllocFloatStatus");
  487. bool ret = MakeDfGraph(prim, 0);
  488. ASSERT_TRUE(ret);
  489. }
  490. TEST_F(TestConvert, TestNPUClearFloatStatusOps) {
  491. auto prim = std::make_shared<Primitive>("NPUClearFloatStatus");
  492. bool ret = MakeDfGraph(prim, 1);
  493. ASSERT_TRUE(ret);
  494. }
  495. TEST_F(TestConvert, TestAddOps) {
  496. auto prim = std::make_shared<Primitive>("Add");
  497. auto func_graph = MakeFuncGraph(prim, 2);
  498. ASSERT_TRUE(nullptr != func_graph);
  499. // save the func_graph to manager
  500. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  501. // call resolve
  502. bool ret_ = ResolveAll(manager);
  503. ASSERT_TRUE(ret_);
  504. // draw graph
  505. auto anfGraph = *(manager->func_graphs().begin());
  506. DfGraphConvertor converter(anfGraph);
  507. converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  508. ASSERT_EQ(converter.ErrCode(), 0);
  509. }
  510. TEST_F(TestConvert, TestConvertTensor) {
  511. float data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  512. // Create a tensor with wanted data type and shape
  513. std::vector<int64_t> dims{2, 2, 3};
  514. std::vector<int64_t> ge_dims{2, 2, 3};
  515. auto type_id = kNumberTypeFloat32;
  516. MeTensor me_tensor(type_id, dims);
  517. // Get the writable data pointer of the tensor and cast it to its data type
  518. uint8_t* me_data_ptr = reinterpret_cast<uint8_t*>(me_tensor.data_c());
  519. // Copy or use the writable data pointer of the ME tensor
  520. memcpy_s(me_data_ptr, me_tensor.data().nbytes(), data, 12 * sizeof(float));
  521. auto me_tensor_ptr = std::make_shared<MeTensor>(me_tensor);
  522. auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW);
  523. ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetFormat(), GeFormat::FORMAT_NCHW);
  524. ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetDataType(), GeDataType::DT_FLOAT);
  525. // ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().array().GetDims(), ge_dims);
  526. int i = 0;
  527. for (i = 0; i < ge_dims.size(); i++) {
  528. ASSERT_EQ(ge_dims[i], ge_tensor_ptr->GetTensorDesc().GetShape().GetDims()[i]);
  529. }
  530. for (i = 0; i < ge_tensor_ptr->GetTensorDesc().GetShape().GetShapeSize(); i++) {
  531. ASSERT_EQ(data[i], (reinterpret_cast<float*>(ge_tensor_ptr->GetData()))[i]);
  532. }
  533. }
  534. TEST_F(TestConvert, TestConvertTensor0Dims) {
  535. // shape with 0 dims is also valid
  536. std::vector<int64_t> dims{};
  537. auto type_id = kNumberTypeFloat32;
  538. auto me_tensor_ptr = std::make_shared<MeTensor>(type_id, dims);
  539. ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW), nullptr);
  540. }
  541. TEST_F(TestConvert, TestConvertTensorError) {
  542. std::vector<int64_t> dims2{2, 3, 4};
  543. auto type_id_2 = kNumberTypeFloat32;
  544. auto me_tensor_ptr_2 = std::make_shared<MeTensor>(type_id_2, dims2);
  545. ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
  546. }
  547. TEST_F(TestConvert, TestUtilsConvertDataType) {
  548. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat16), GeDataType::DT_FLOAT16);
  549. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat32), GeDataType::DT_FLOAT);
  550. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat64), GeDataType::DT_DOUBLE);
  551. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt8), GeDataType::DT_INT8);
  552. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt16), GeDataType::DT_INT16);
  553. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt32), GeDataType::DT_INT32);
  554. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt64), GeDataType::DT_INT64);
  555. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeUInt32), GeDataType::DT_UINT32);
  556. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeBool), GeDataType::DT_BOOL);
  557. }
  558. TEST_F(TestConvert, TestUtilsConvertFormat) {
  559. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NCHW), GeFormat::FORMAT_NCHW);
  560. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NC1HWC0), GeFormat::FORMAT_NC1HWC0);
  561. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NHWC), GeFormat::FORMAT_NHWC);
  562. ASSERT_EQ(TransformUtil::ConvertFormat("xyz"), GeFormat::FORMAT_ND);
  563. }
  564. TEST_F(TestConvert, TestUtilsDataSize) {
  565. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat32), 4);
  566. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat16), 2);
  567. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat64), 8);
  568. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt8), 1);
  569. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt16), 2);
  570. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt32), 4);
  571. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt64), 8);
  572. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeUInt32), 4);
  573. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeBool), 1);
  574. }
  575. TEST_F(TestConvert, TestConvertGeTensor) {
  576. #define DTYPE float
  577. ge::DataType dt = ge::DataType::DT_FLOAT;
  578. std::vector<float> data1 = {1.1, 2.2, 3.3, 4.4, 6.6, 7.7, 8.8, 9.9};
  579. std::vector<DTYPE> data2 = {1, 2, 3, 4, 6, 7, 8, 9};
  580. auto data = data1;
  581. ge::Shape shape({2, 2, 2});
  582. ge::Format format = ge::Format::FORMAT_NCHW;
  583. ge::TensorDesc desc(shape, format, dt);
  584. GeTensorPtr ge_tensor_ptr =
  585. std::make_shared<GeTensor>(desc, reinterpret_cast<uint8_t*>(data.data()), data.size() * sizeof(DTYPE));
  586. GeTensor& ge_tensor = *ge_tensor_ptr;
  587. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensor.GetData());
  588. // make sure GetData()'s return is a reference
  589. assert(ge_data == reinterpret_cast<DTYPE*>(ge_tensor.GetData()));
  590. cout << "ge data size is: " << std::dec << ge_tensor.GetSize() << " bytes" << endl;
  591. for (int i = 0; i < ge_tensor.GetSize() / sizeof(DTYPE); i++) {
  592. cout << "ge data is: " << static_cast<DTYPE>(*(ge_data + i)) << endl;
  593. }
  594. MeTensorPtr me_tensor_ptr = TransformUtil::ConvertGeTensor(ge_tensor_ptr);
  595. MeTensor& me_tensor = *me_tensor_ptr;
  596. cout << "after convert ge tensor to me tensor" << endl;
  597. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_tensor.data_c());
  598. PrintMeTensor(&me_tensor);
  599. assert(ge_tensor.GetSize() == me_tensor.data().nbytes());
  600. assert(memcmp(ge_data, me_data, ge_tensor.GetSize()) == 0);
  601. }
  602. TEST_F(TestConvert, TestConvertMakeTuple) {
  603. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  604. std::vector<AnfNodePtr> inputs;
  605. inputs.push_back(NewValueNode(std::make_shared<Primitive>("MakeTuple")));
  606. for (int i = 0; i < 3; i++) {
  607. auto input = func_graph->add_parameter();
  608. input->set_name("x" + std::to_string(i));
  609. inputs.push_back(input);
  610. }
  611. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  612. inputs.clear();
  613. inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
  614. inputs.push_back(cnode_prim);
  615. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  616. func_graph->set_return(cnode_return);
  617. // save the func_graph to manager
  618. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  619. // call resolve
  620. bool ret_ = ResolveAll(manager);
  621. ASSERT_TRUE(ret_);
  622. // draw graph
  623. auto anfGraph = *(manager->func_graphs().begin());
  624. DfGraphConvertor converter(anfGraph);
  625. converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  626. ASSERT_EQ(converter.ErrCode(), 0);
  627. }
  628. TEST_F(TestConvert, TestConvertInputTensors) {
  629. #define DTYPE float
  630. std::initializer_list<int64_t> list0 = {1, 1, 4, 4};
  631. std::initializer_list<int64_t> list1 = {2, 3, 4, 5};
  632. std::initializer_list<int64_t> list2 = {9, 9, 1, 1};
  633. MeTensorPtr input_ptr1 = MakeTensor(kF32, list0);
  634. MeTensorPtr input_ptr2 = MakeTensor(kF32, list1);
  635. MeTensorPtr input_ptr3 = MakeTensor(kF32, list2);
  636. std::vector<MeTensorPtr> me_inputs;
  637. me_inputs.emplace_back(input_ptr1);
  638. me_inputs.emplace_back(input_ptr2);
  639. me_inputs.emplace_back(input_ptr3);
  640. std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(me_inputs, kOpFormat_NCHW);
  641. for (int i = 0; i < ge_tensors.size(); i++) {
  642. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_inputs[i]->data_c());
  643. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensors[i]->GetData());
  644. ASSERT_TRUE(ge_tensors[i]->GetSize() == me_inputs[i]->data().nbytes());
  645. ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
  646. ASSERT_TRUE(ge_tensors[i]->GetTensorDesc().GetShape().GetDims() ==
  647. TransformUtil::ConvertMeShape(me_inputs[i]->shape_c()).GetDims());
  648. }
  649. }
  650. TEST_F(TestConvert, TestConvertGeTensors) {
  651. #define DTYPE float
  652. ge::DataType dt = ge::DataType::DT_FLOAT;
  653. std::vector<float> data1(16);
  654. std::vector<float> data2(120);
  655. std::vector<float> data3(81);
  656. ge::Shape shape1({1, 1, 4, 4});
  657. ge::Shape shape2({2, 3, 4, 5});
  658. ge::Shape shape3({9, 9, 1, 1});
  659. ge::Format format = ge::Format::FORMAT_NCHW;
  660. ge::TensorDesc desc1(shape1, format, dt);
  661. ge::TensorDesc desc2(shape2, format, dt);
  662. ge::TensorDesc desc3(shape3, format, dt);
  663. GeTensorPtr ge_tensor_ptr1 =
  664. std::make_shared<GeTensor>(desc1, reinterpret_cast<uint8_t*>(data1.data()), data1.size() * sizeof(DTYPE));
  665. GeTensorPtr ge_tensor_ptr2 =
  666. std::make_shared<GeTensor>(desc2, reinterpret_cast<uint8_t*>(data2.data()), data2.size() * sizeof(DTYPE));
  667. GeTensorPtr ge_tensor_ptr3 =
  668. std::make_shared<GeTensor>(desc3, reinterpret_cast<uint8_t*>(data3.data()), data3.size() * sizeof(DTYPE));
  669. std::vector<GeTensorPtr> ge_tensors;
  670. ge_tensors.emplace_back(ge_tensor_ptr1);
  671. ge_tensors.emplace_back(ge_tensor_ptr2);
  672. ge_tensors.emplace_back(ge_tensor_ptr3);
  673. std::vector<std::vector<int64_t>> request_dims;
  674. std::vector<int64_t> dims1 = {1, 1, 4, 4};
  675. std::vector<int64_t> dims2 = {2, 3, 4, 5};
  676. std::vector<int64_t> dims3 = {9, 9, 1, 1};
  677. request_dims.emplace_back(dims1);
  678. request_dims.emplace_back(dims2);
  679. request_dims.emplace_back(dims3);
  680. std::vector<MeTensorPtr> me_outputs = TransformUtil::ConvertGeTensors(ge_tensors, request_dims);
  681. for (int i = 0; i < ge_tensors.size(); i++) {
  682. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_outputs[i]->data_c());
  683. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensors[i]->GetData());
  684. ASSERT_TRUE(ge_tensors[i]->GetSize() == me_outputs[i]->data().nbytes());
  685. ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
  686. ASSERT_TRUE(request_dims[i] == me_outputs[i]->shape_c());
  687. }
  688. }
  689. TEST_F(TestConvert, TestConvertGeShape1) {
  690. GeShape ge_shape({10, 1, 1, 1});
  691. std::vector<int64_t> request_dims{10};
  692. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  693. }
  694. TEST_F(TestConvert, TestConvertGeShape2) {
  695. GeShape ge_shape({10, 15, 1, 1});
  696. std::vector<int64_t> request_dims{10, 15};
  697. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  698. }
  699. TEST_F(TestConvert, TestConvertGeShape3) {
  700. GeShape ge_shape({10, 13, 18, 1});
  701. std::vector<int64_t> request_dims{10, 13, 18};
  702. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  703. }
  704. TEST_F(TestConvert, TestConvertGeShape4) {
  705. GeShape ge_shape({1, 10, 1, 1});
  706. std::vector<int64_t> request_dims{10};
  707. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  708. }
  709. TEST_F(TestConvert, TestConvertGeShape5) {
  710. GeShape ge_shape({10, 1, 1, 2});
  711. std::vector<int64_t> request_dims{10};
  712. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  713. }
  714. TEST_F(TestConvert, TestConvertGeShape6) {
  715. GeShape ge_shape({5, 2, 1, 1});
  716. std::vector<int64_t> request_dims{10};
  717. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  718. }
  719. TEST_F(TestConvert, TestConvertGeShape7) {
  720. GeShape ge_shape({10});
  721. std::vector<int64_t> request_dims{10, 1};
  722. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  723. }
  724. } // namespace transform
  725. } // namespace mindspore