You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_test.cc 31 kB

5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <unordered_map>
  18. #include "pybind11/pybind11.h"
  19. #include "transform/transform_base_test.h"
  20. #include "common/py_func_graph_fetcher.h"
  21. #include "pipeline/jit/parse/parse.h"
  22. #include "debug/draw.h"
  23. #include "debug/anf_ir_dump.h"
  24. #include "pipeline/jit/static_analysis/prim.h"
  25. #include "frontend/operator/ops.h"
  26. #include "common/common_test.h"
  27. #define private public
  28. #include "transform/graph_ir/types.h"
  29. #include "transform/graph_ir/convert.h"
  30. #include "securec/include/securec.h"
  31. #include "utils/utils.h"
  32. using std::cout;
  33. using std::endl;
  34. using std::string;
  35. using std::unordered_map;
  36. namespace mindspore {
  37. namespace transform {
  38. using AbstractScalar = abstract::AbstractScalar;
  39. using mindspore::parse::ResolveAll;
  40. class TestConvert : public UT::Common {
  41. public:
  42. TestConvert() {}
  43. virtual void SetUp();
  44. virtual void TearDown();
  45. static const std::shared_ptr<Float> kF32;
  46. };
  47. void TestConvert::SetUp() { UT::InitPythonPath(); }
  48. void TestConvert::TearDown() {}
  49. const std::shared_ptr<Float> TestConvert::kF32 = std::make_shared<Float>(32);
  50. AnfGraphPtr createAnfGraph() { return std::make_shared<AnfGraph>(); }
  51. TEST_F(TestConvert, TestConstruct) {
  52. AnfGraphPtr func_graph = std::make_shared<AnfGraph>();
  53. DfGraphConvertor converter(func_graph);
  54. converter.ConvertAllNode().GetComputeGraph();
  55. ASSERT_NE(converter.ErrCode(), SUCCESS);
  56. }
  57. #if (!defined ENABLE_GE)
  58. namespace {
  59. bool MakeDfGraph(PrimitivePtr prim, unsigned int nparam) {
  60. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, nparam);
  61. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  62. DfGraphConvertor converter(anf_graph);
  63. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  64. if (converter.ErrCode() != 0) {
  65. MS_LOG(ERROR) << "DfGraphConvertor convert " << prim->name() << " error, error code is: " << converter.ErrCode();
  66. return false;
  67. }
  68. if (df_graph == nullptr) {
  69. MS_LOG(ERROR) << "DfGraphConvertor get " << prim->name() << " compute func_graph failed";
  70. return false;
  71. }
  72. return true;
  73. }
  74. } // namespace
  75. TEST_F(TestConvert, TestConvertConv2d) {
  76. PrimitivePtr conv2d = prim::kPrimConv2D;
  77. conv2d->AddAttr("stride", MakeValue(static_cast<int64_t>(2)));
  78. conv2d->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  79. conv2d->AddAttr("dilation", MakeValue(static_cast<int64_t>(0)));
  80. FuncGraphPtr anf_graph = MakeFuncGraph(conv2d, 2);
  81. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  82. DfGraphConvertor converter(anf_graph);
  83. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  84. ASSERT_EQ(converter.ErrCode(), 0);
  85. ASSERT_NE(df_graph, nullptr);
  86. }
  87. TEST_F(TestConvert, TestConvertMaxpooling) {
  88. auto prim = std::make_shared<Primitive>("MaxPool");
  89. FuncGraphPtr anf_graph = MakeFuncGraph(prim, 5); // ary, ksize, stride, padding, data_format
  90. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  91. DfGraphConvertor converter(anf_graph);
  92. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  93. ASSERT_EQ(converter.ErrCode(), 0);
  94. ASSERT_NE(df_graph, nullptr);
  95. }
  96. TEST_F(TestConvert, TestReluOps) {
  97. auto prim = prim::kPrimRelu;
  98. prim->AddAttr("T", MakeValue(static_cast<int64_t>(0)));
  99. auto func_graph = MakeFuncGraph(prim, 1);
  100. ASSERT_TRUE(nullptr != func_graph);
  101. // save the func_graph to manager
  102. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  103. // call resolve
  104. bool ret_ = ResolveAll(manager);
  105. ASSERT_TRUE(ret_);
  106. // draw graph
  107. auto anfGraph = *(manager->func_graphs().begin());
  108. DfGraphConvertor converter(anfGraph);
  109. converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  110. ASSERT_EQ(converter.ErrCode(), 0);
  111. }
  112. TEST_F(TestConvert, TestConvertBatchNorm) {
  113. PrimitivePtr batch_norm = prim::kPrimBatchNorm;
  114. batch_norm->AddAttr("epsilon", MakeValue(0.001f));
  115. batch_norm->AddAttr("momentum", MakeValue(0.1f));
  116. FuncGraphPtr anf_graph = std::make_shared<FuncGraph>();
  117. std::vector<AnfNodePtr> inputs;
  118. inputs.push_back(NewValueNode(batch_norm));
  119. for (unsigned int i = 0; i < 5; i++) {
  120. inputs.push_back(anf_graph->add_parameter());
  121. }
  122. CNodePtr cnode_prim = anf_graph->NewCNode(inputs);
  123. inputs.clear();
  124. inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
  125. inputs.push_back(cnode_prim);
  126. inputs.push_back(NewValueNode(static_cast<int64_t>(2)));
  127. CNodePtr cnode_getitem = anf_graph->NewCNode(inputs);
  128. inputs.clear();
  129. inputs.push_back(NewValueNode(prim::kPrimRelu));
  130. inputs.push_back(cnode_getitem);
  131. CNodePtr cnode_relu = anf_graph->NewCNode(inputs);
  132. inputs.clear();
  133. inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
  134. inputs.push_back(cnode_relu);
  135. CNodePtr cnode_return = anf_graph->NewCNode(inputs);
  136. anf_graph->set_return(cnode_return);
  137. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  138. DfGraphConvertor converter(anf_graph);
  139. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  140. ASSERT_EQ(converter.ErrCode(), 0);
  141. ASSERT_NE(df_graph, nullptr);
  142. }
  143. TEST_F(TestConvert, TestConvertConvBackpropInput) {
  144. auto prim = prim::kPrimConv2DBackpropInput;
  145. const std::vector<int64_t> list{1,1};
  146. prim->AddAttr("stride", MakeValue(list));
  147. prim->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  148. prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  149. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  150. prim->AddAttr("group", MakeValue(static_cast<int64_t>(1)));
  151. prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
  152. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  153. auto func_graph = MakeFuncGraph(prim, 3);
  154. ASSERT_NE(func_graph, nullptr);
  155. // save the func_graph to manager
  156. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  157. // call resolve
  158. bool ret_ = ResolveAll(manager);
  159. ASSERT_TRUE(ret_);
  160. // draw graph
  161. auto anf_graph = *(manager->func_graphs().begin());
  162. DfGraphConvertor converter(anf_graph);
  163. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  164. ASSERT_EQ(converter.ErrCode(), 0);
  165. ASSERT_NE(df_graph, nullptr);
  166. }
  167. TEST_F(TestConvert, TestConvertConvBackpropFilter) {
  168. auto prim = prim::kPrimConv2DBackpropFilter;
  169. const std::vector<int64_t> list{1,1};
  170. prim->AddAttr("stride", MakeValue(list));
  171. prim->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  172. prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  173. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  174. prim->AddAttr("group", MakeValue(static_cast<int64_t>(1)));
  175. prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
  176. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  177. auto func_graph = MakeFuncGraph(prim, 3);
  178. ASSERT_NE(func_graph, nullptr);
  179. // save the func_graph to manager
  180. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  181. // call resolve
  182. bool ret_ = ResolveAll(manager);
  183. ASSERT_TRUE(ret_);
  184. // draw graph
  185. auto anf_graph = *(manager->func_graphs().begin());
  186. DfGraphConvertor converter(anf_graph);
  187. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  188. ASSERT_EQ(converter.ErrCode(), 0);
  189. ASSERT_NE(df_graph, nullptr);
  190. }
  191. TEST_F(TestConvert, TestConvertReluGrad) {
  192. auto prim = prim::kPrimReluGrad;
  193. prim->AddAttr("alpha", MakeValue(0.1f));
  194. prim->AddAttr("beta", MakeValue(0.1f));
  195. prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
  196. auto func_graph = MakeFuncGraph(prim, 2);
  197. ASSERT_NE(func_graph, nullptr);
  198. // save the func_graph to manager
  199. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  200. // call resolve
  201. bool ret_ = ResolveAll(manager);
  202. ASSERT_TRUE(ret_);
  203. // draw graph
  204. auto anf_graph = *(manager->func_graphs().begin());
  205. DfGraphConvertor converter(anf_graph);
  206. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  207. ASSERT_EQ(converter.ErrCode(), 0);
  208. ASSERT_NE(df_graph, nullptr);
  209. }
  210. TEST_F(TestConvert, TestConvertBiasAdd) {
  211. auto prim = std::make_shared<Primitive>("BiasAdd");
  212. prim->AddAttr("alpha", MakeValue(0.0f));
  213. prim->AddAttr("beta", MakeValue(1.0f));
  214. auto func_graph = MakeFuncGraph(prim, 2);
  215. ASSERT_NE(func_graph, nullptr);
  216. // save the func_graph to manager
  217. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  218. // call resolve
  219. bool ret_ = ResolveAll(manager);
  220. ASSERT_TRUE(ret_);
  221. // draw graph
  222. auto anf_graph = *(manager->func_graphs().begin());
  223. DfGraphConvertor converter(anf_graph);
  224. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  225. ASSERT_EQ(converter.ErrCode(), 0);
  226. ASSERT_NE(df_graph, nullptr);
  227. }
  228. TEST_F(TestConvert, TestConvertBiasAddGrad) {
  229. auto prim = prim::kPrimBiasAddGrad;
  230. prim->AddAttr("alpha", MakeValue(0.0f));
  231. prim->AddAttr("beta", MakeValue(1.0f));
  232. auto func_graph = MakeFuncGraph(prim, 2);
  233. ASSERT_NE(func_graph, nullptr);
  234. // save the func_graph to manager
  235. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  236. // call resolve
  237. bool ret_ = ResolveAll(manager);
  238. ASSERT_TRUE(ret_);
  239. // draw graph
  240. auto anf_graph = *(manager->func_graphs().begin());
  241. DfGraphConvertor converter(anf_graph);
  242. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  243. ASSERT_EQ(converter.ErrCode(), 0);
  244. ASSERT_NE(df_graph, nullptr);
  245. }
  246. TEST_F(TestConvert, TestConvertMaxPoolGradWithArgmax) {
  247. auto prim = std::make_shared<Primitive>("MaxPoolGradWithArgmax");
  248. prim->AddAttr("alpha", MakeValue(0.0f));
  249. prim->AddAttr("beta", MakeValue(1.0f));
  250. prim->AddAttr("window", MakeValue(static_cast<int64_t>(2)));
  251. prim->AddAttr("stride", MakeValue(static_cast<int64_t>(1)));
  252. prim->AddAttr("ceil_mode", MakeValue(static_cast<int64_t>(0)));
  253. prim->AddAttr("data_mode", MakeValue(static_cast<int64_t>(0)));
  254. prim->AddAttr("alpha", MakeValue(0.1f));
  255. prim->AddAttr("beta", MakeValue(1.0f));
  256. auto func_graph = MakeFuncGraph(prim, 2);
  257. ASSERT_NE(func_graph, nullptr);
  258. // save the func_graph to manager
  259. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  260. // call resolve
  261. bool ret_ = ResolveAll(manager);
  262. ASSERT_TRUE(ret_);
  263. // draw graph
  264. auto anf_graph = *(manager->func_graphs().begin());
  265. DfGraphConvertor converter(anf_graph);
  266. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  267. ASSERT_EQ(converter.ErrCode(), 0);
  268. ASSERT_NE(df_graph, nullptr);
  269. }
  270. TEST_F(TestConvert, TestConcat) {
  271. auto prim = prim::kPrimConcat;
  272. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  273. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  274. DfGraphConvertor converter(anf_graph);
  275. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  276. ASSERT_EQ(converter.ErrCode(), 0);
  277. ASSERT_NE(df_graph, nullptr);
  278. }
  279. TEST_F(TestConvert, TestGatherV2) {
  280. auto prim = prim::kPrimGather;
  281. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 3);
  282. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  283. DfGraphConvertor converter(anf_graph);
  284. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  285. ASSERT_EQ(converter.ErrCode(), 0);
  286. ASSERT_NE(df_graph, nullptr);
  287. }
  288. TEST_F(TestConvert, TestCast) {
  289. auto prim = prim::kPrimCast;
  290. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  291. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  292. DfGraphConvertor converter(anf_graph);
  293. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  294. ASSERT_EQ(converter.ErrCode(), 0);
  295. ASSERT_NE(df_graph, nullptr);
  296. }
  297. TEST_F(TestConvert, TestExp) {
  298. auto prim = std::make_shared<Primitive>("Exp");
  299. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  300. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  301. DfGraphConvertor converter(anf_graph);
  302. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  303. ASSERT_EQ(converter.ErrCode(), 0);
  304. ASSERT_NE(df_graph, nullptr);
  305. }
  306. TEST_F(TestConvert, TestFloor) {
  307. auto prim = std::make_shared<Primitive>("Floor");
  308. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  309. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  310. DfGraphConvertor converter(anf_graph);
  311. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  312. ASSERT_EQ(converter.ErrCode(), 0);
  313. ASSERT_NE(df_graph, nullptr);
  314. }
  315. TEST_F(TestConvert, TestGreaterEqual) {
  316. auto prim = std::make_shared<Primitive>("GreaterEqual");
  317. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  318. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  319. DfGraphConvertor converter(anf_graph);
  320. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  321. ASSERT_EQ(converter.ErrCode(), 0);
  322. ASSERT_NE(df_graph, nullptr);
  323. }
  324. TEST_F(TestConvert, TestLess) {
  325. auto prim = std::make_shared<Primitive>("Less");
  326. prim->AddAttr("T", MakeValue(kFloat32));
  327. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  328. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  329. DfGraphConvertor converter(anf_graph);
  330. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  331. ASSERT_EQ(converter.ErrCode(), 0);
  332. ASSERT_NE(df_graph, nullptr);
  333. }
  334. TEST_F(TestConvert, TestLessEqual) {
  335. auto prim = std::make_shared<Primitive>("LessEqual");
  336. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  337. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  338. DfGraphConvertor converter(anf_graph);
  339. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  340. ASSERT_EQ(converter.ErrCode(), 0);
  341. ASSERT_NE(df_graph, nullptr);
  342. }
  343. TEST_F(TestConvert, TestLogicalNot) {
  344. auto prim = std::make_shared<Primitive>("LogicalNot");
  345. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  346. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  347. DfGraphConvertor converter(anf_graph);
  348. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  349. ASSERT_EQ(converter.ErrCode(), 0);
  350. ASSERT_NE(df_graph, nullptr);
  351. }
  352. TEST_F(TestConvert, TestAssignAdd) {
  353. auto prim = prim::kPrimAssignAdd;
  354. prim->AddAttr("use_locking", MakeValue(true));
  355. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  356. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  357. DfGraphConvertor converter(anf_graph);
  358. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  359. ASSERT_EQ(converter.ErrCode(), 0);
  360. ASSERT_NE(df_graph, nullptr);
  361. }
  362. TEST_F(TestConvert, LogSoftmax) {
  363. auto prim = prim::kPrimLogSoftmax;
  364. prim->AddAttr("axis", MakeValue(static_cast<int64_t>(0)));
  365. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  366. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  367. DfGraphConvertor converter(anf_graph);
  368. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  369. ASSERT_EQ(converter.ErrCode(), 0);
  370. ASSERT_NE(df_graph, nullptr);
  371. }
  372. TEST_F(TestConvert, TestMaximumOps) {
  373. auto prim = prim::kPrimMaximum;
  374. bool ret = MakeDfGraph(prim, 2);
  375. ASSERT_TRUE(ret);
  376. }
  377. TEST_F(TestConvert, TestReduceMeanOps) {
  378. auto prim = prim::kPrimReduceMean;
  379. prim->AddAttr("keepdims", MakeValue(true));
  380. bool ret = MakeDfGraph(prim, 2);
  381. ASSERT_TRUE(ret);
  382. }
  383. TEST_F(TestConvert, TestMinimumOps) {
  384. auto prim = prim::kPrimMinimum;
  385. bool ret = MakeDfGraph(prim, 2);
  386. ASSERT_TRUE(ret);
  387. }
  388. TEST_F(TestConvert, TestFusedMinOrMaxGradOps) {
  389. // Add infer step to this test case
  390. ASSERT_TRUE(true);
  391. }
  392. TEST_F(TestConvert, TestSqueezeOps) {
  393. auto prim = prim::kPrimSqueeze;
  394. bool ret = MakeDfGraph(prim, 2);
  395. ASSERT_TRUE(ret);
  396. }
  397. TEST_F(TestConvert, TestMulOps) {
  398. auto prim = prim::kPrimMul;
  399. bool ret = MakeDfGraph(prim, 2);
  400. ASSERT_TRUE(ret);
  401. }
  402. TEST_F(TestConvert, TestNegOps) {
  403. auto prim = prim::kPrimNeg;
  404. bool ret = MakeDfGraph(prim, 1);
  405. ASSERT_TRUE(ret);
  406. }
  407. TEST_F(TestConvert, TestOneHotOps) {
  408. auto prim = prim::kPrimOneHot;
  409. prim->AddAttr("axis", MakeValue(static_cast<int64_t>(0)));
  410. bool ret = MakeDfGraph(prim, 4);
  411. ASSERT_TRUE(ret);
  412. }
  413. TEST_F(TestConvert, TestPowOps) {
  414. auto prim = std::make_shared<Primitive>("Pow");
  415. bool ret = MakeDfGraph(prim, 2);
  416. ASSERT_TRUE(ret);
  417. }
  418. TEST_F(TestConvert, TestReciprocalOps) {
  419. auto prim = std::make_shared<Primitive>("Reciprocal");
  420. bool ret = MakeDfGraph(prim, 1);
  421. ASSERT_TRUE(ret);
  422. }
  423. TEST_F(TestConvert, TestSelectOps) {
  424. auto prim = prim::kPrimSelect;
  425. bool ret = MakeDfGraph(prim, 3);
  426. ASSERT_TRUE(ret);
  427. }
  428. TEST_F(TestConvert, TestSqrtOps) {
  429. auto prim = std::make_shared<Primitive>("Sqrt");
  430. bool ret = MakeDfGraph(prim, 1);
  431. ASSERT_TRUE(ret);
  432. }
  433. TEST_F(TestConvert, TestSquareOps) {
  434. auto prim = std::make_shared<Primitive>("Square");
  435. bool ret = MakeDfGraph(prim, 1);
  436. ASSERT_TRUE(ret);
  437. }
  438. #ifndef ENABLE_SECURITY
  439. TEST_F(TestConvert, TestScalarSummaryOps) {
  440. auto prim = prim::kPrimScalarSummary;
  441. // should have only 1 input.
  442. bool ret = MakeDfGraph(prim, 2);
  443. ASSERT_TRUE(ret);
  444. }
  445. TEST_F(TestConvert, TestTensorSummaryOps) {
  446. auto prim = prim::kPrimTensorSummary;
  447. bool ret = MakeDfGraph(prim, 2);
  448. ASSERT_TRUE(ret);
  449. }
  450. TEST_F(TestConvert, TestHistogramSummaryOps) {
  451. auto prim = prim::kPrimHistogramSummary;
  452. bool ret = MakeDfGraph(prim, 2);
  453. ASSERT_TRUE(ret);
  454. }
  455. #endif
  456. TEST_F(TestConvert, TestGreaterOps) {
  457. auto prim = std::make_shared<Primitive>("Greater");
  458. bool ret = MakeDfGraph(prim, 2);
  459. ASSERT_TRUE(ret);
  460. }
  461. TEST_F(TestConvert, TestEqualOps) {
  462. auto prim = std::make_shared<Primitive>("Equal");
  463. bool ret = MakeDfGraph(prim, 2);
  464. ASSERT_TRUE(ret);
  465. }
  466. TEST_F(TestConvert, TestArgMaxiOps) {
  467. auto prim = std::make_shared<Primitive>("Argmax");
  468. bool ret = MakeDfGraph(prim, 2);
  469. ASSERT_TRUE(ret);
  470. }
  471. TEST_F(TestConvert, TestResizeNearestNeighborOps) {
  472. auto prim = std::make_shared<Primitive>("ResizeNearestNeighbor");
  473. bool ret = MakeDfGraph(prim, 1);
  474. ASSERT_TRUE(ret);
  475. }
  476. TEST_F(TestConvert, TestApplyMomentumOps) {
  477. auto prim = std::make_shared<Primitive>("ApplyMomentum");
  478. bool ret = MakeDfGraph(prim, 5);
  479. ASSERT_TRUE(ret);
  480. }
  481. TEST_F(TestConvert, TestNPUGetFloatStatusOps) {
  482. auto prim = std::make_shared<Primitive>("NPUGetFloatStatus");
  483. bool ret = MakeDfGraph(prim, 1);
  484. ASSERT_TRUE(ret);
  485. }
  486. TEST_F(TestConvert, TestNPUAllocFloatStatusOps) {
  487. auto prim = std::make_shared<Primitive>("NPUAllocFloatStatus");
  488. bool ret = MakeDfGraph(prim, 0);
  489. ASSERT_TRUE(ret);
  490. }
  491. TEST_F(TestConvert, TestNPUClearFloatStatusOps) {
  492. auto prim = std::make_shared<Primitive>("NPUClearFloatStatus");
  493. bool ret = MakeDfGraph(prim, 1);
  494. ASSERT_TRUE(ret);
  495. }
  496. #endif
  497. TEST_F(TestConvert, TestAddOps) {
  498. auto prim = std::make_shared<Primitive>("Add");
  499. auto func_graph = MakeFuncGraph(prim, 2);
  500. ASSERT_TRUE(nullptr != func_graph);
  501. // save the func_graph to manager
  502. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  503. // call resolve
  504. bool ret_ = ResolveAll(manager);
  505. ASSERT_TRUE(ret_);
  506. // draw graph
  507. auto anfGraph = *(manager->func_graphs().begin());
  508. DfGraphConvertor converter(anfGraph);
  509. converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  510. ASSERT_EQ(converter.ErrCode(), 0);
  511. }
  512. TEST_F(TestConvert, TestConvertTensor) {
  513. float data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  514. // Create a tensor with wanted data type and shape
  515. std::vector<int64_t> dims{2, 2, 3};
  516. std::vector<int64_t> ge_dims{2, 2, 3};
  517. auto type_id = kNumberTypeFloat32;
  518. MeTensor me_tensor(type_id, dims);
  519. // Get the writable data pointer of the tensor and cast it to its data type
  520. uint8_t* me_data_ptr = reinterpret_cast<uint8_t*>(me_tensor.data_c());
  521. // Copy or use the writable data pointer of the ME tensor
  522. memcpy_s(me_data_ptr, me_tensor.data().nbytes(), data, 12 * sizeof(float));
  523. auto me_tensor_ptr = std::make_shared<MeTensor>(me_tensor);
  524. auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW);
  525. ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetFormat(), GeFormat::FORMAT_NCHW);
  526. ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetDataType(), GeDataType::DT_FLOAT);
  527. // ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().array().GetDims(), ge_dims);
  528. int i = 0;
  529. for (i = 0; i < ge_dims.size(); i++) {
  530. ASSERT_EQ(ge_dims[i], ge_tensor_ptr->GetTensorDesc().GetShape().GetDims()[i]);
  531. }
  532. for (i = 0; i < ge_tensor_ptr->GetTensorDesc().GetShape().GetShapeSize(); i++) {
  533. ASSERT_EQ(data[i], (reinterpret_cast<float*>(ge_tensor_ptr->GetData()))[i]);
  534. }
  535. }
  536. TEST_F(TestConvert, TestConvertTensor0Dims) {
  537. // shape with 0 dims is also valid
  538. std::vector<int64_t> dims{};
  539. auto type_id = kNumberTypeFloat32;
  540. auto me_tensor_ptr = std::make_shared<MeTensor>(type_id, dims);
  541. ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW), nullptr);
  542. }
  543. TEST_F(TestConvert, TestConvertTensorError) {
  544. std::vector<int64_t> dims2{2, 3, 4};
  545. auto type_id_2 = kNumberTypeFloat32;
  546. auto me_tensor_ptr_2 = std::make_shared<MeTensor>(type_id_2, dims2);
  547. ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
  548. }
  549. TEST_F(TestConvert, TestUtilsConvertDataType) {
  550. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat16), GeDataType::DT_FLOAT16);
  551. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat32), GeDataType::DT_FLOAT);
  552. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat64), GeDataType::DT_DOUBLE);
  553. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt8), GeDataType::DT_INT8);
  554. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt16), GeDataType::DT_INT16);
  555. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt32), GeDataType::DT_INT32);
  556. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt64), GeDataType::DT_INT64);
  557. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeUInt32), GeDataType::DT_UINT32);
  558. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeBool), GeDataType::DT_BOOL);
  559. }
  560. TEST_F(TestConvert, TestUtilsConvertFormat) {
  561. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NCHW), GeFormat::FORMAT_NCHW);
  562. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NC1HWC0), GeFormat::FORMAT_NC1HWC0);
  563. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NHWC), GeFormat::FORMAT_NHWC);
  564. ASSERT_EQ(TransformUtil::ConvertFormat("xyz"), GeFormat::FORMAT_ND);
  565. }
  566. TEST_F(TestConvert, TestUtilsDataSize) {
  567. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat32), 4);
  568. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat16), 2);
  569. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat64), 8);
  570. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt8), 1);
  571. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt16), 2);
  572. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt32), 4);
  573. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt64), 8);
  574. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeUInt32), 4);
  575. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeBool), 1);
  576. }
  577. TEST_F(TestConvert, TestConvertGeTensor) {
  578. #define DTYPE float
  579. ge::DataType dt = ge::DataType::DT_FLOAT;
  580. std::vector<float> data1 = {1.1, 2.2, 3.3, 4.4, 6.6, 7.7, 8.8, 9.9};
  581. std::vector<DTYPE> data2 = {1, 2, 3, 4, 6, 7, 8, 9};
  582. auto data = data1;
  583. ge::Shape shape({2, 2, 2});
  584. ge::Format format = ge::Format::FORMAT_NCHW;
  585. ge::TensorDesc desc(shape, format, dt);
  586. GeTensorPtr ge_tensor_ptr =
  587. std::make_shared<GeTensor>(desc, reinterpret_cast<uint8_t*>(data.data()), data.size() * sizeof(DTYPE));
  588. GeTensor& ge_tensor = *ge_tensor_ptr;
  589. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensor.GetData());
  590. // make sure GetData()'s return is a reference
  591. assert(ge_data == reinterpret_cast<DTYPE*>(ge_tensor.GetData()));
  592. cout << "ge data size is: " << std::dec << ge_tensor.GetSize() << " bytes" << endl;
  593. for (int i = 0; i < ge_tensor.GetSize() / sizeof(DTYPE); i++) {
  594. cout << "ge data is: " << static_cast<DTYPE>(*(ge_data + i)) << endl;
  595. }
  596. MeTensorPtr me_tensor_ptr = TransformUtil::ConvertGeTensor(ge_tensor_ptr);
  597. MeTensor& me_tensor = *me_tensor_ptr;
  598. cout << "after convert ge tensor to me tensor" << endl;
  599. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_tensor.data_c());
  600. PrintMeTensor(&me_tensor);
  601. assert(ge_tensor.GetSize() == me_tensor.data().nbytes());
  602. assert(memcmp(ge_data, me_data, ge_tensor.GetSize()) == 0);
  603. }
  604. TEST_F(TestConvert, TestConvertMakeTuple) {
  605. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  606. std::vector<AnfNodePtr> inputs;
  607. inputs.push_back(NewValueNode(std::make_shared<Primitive>("MakeTuple")));
  608. for (int i = 0; i < 3; i++) {
  609. auto input = func_graph->add_parameter();
  610. input->set_name("x" + std::to_string(i));
  611. inputs.push_back(input);
  612. }
  613. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  614. inputs.clear();
  615. inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
  616. inputs.push_back(cnode_prim);
  617. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  618. func_graph->set_return(cnode_return);
  619. // save the func_graph to manager
  620. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  621. // call resolve
  622. bool ret_ = ResolveAll(manager);
  623. ASSERT_TRUE(ret_);
  624. // draw graph
  625. auto anfGraph = *(manager->func_graphs().begin());
  626. DfGraphConvertor converter(anfGraph);
  627. converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  628. ASSERT_EQ(converter.ErrCode(), 0);
  629. }
  630. TEST_F(TestConvert, TestConvertInputTensors) {
  631. #define DTYPE float
  632. std::initializer_list<int64_t> list0 = {1, 1, 4, 4};
  633. std::initializer_list<int64_t> list1 = {2, 3, 4, 5};
  634. std::initializer_list<int64_t> list2 = {9, 9, 1, 1};
  635. MeTensorPtr input_ptr1 = MakeTensor(kF32, list0);
  636. MeTensorPtr input_ptr2 = MakeTensor(kF32, list1);
  637. MeTensorPtr input_ptr3 = MakeTensor(kF32, list2);
  638. std::vector<MeTensorPtr> me_inputs;
  639. me_inputs.emplace_back(input_ptr1);
  640. me_inputs.emplace_back(input_ptr2);
  641. me_inputs.emplace_back(input_ptr3);
  642. std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(me_inputs, kOpFormat_NCHW);
  643. for (int i = 0; i < ge_tensors.size(); i++) {
  644. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_inputs[i]->data_c());
  645. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensors[i]->GetData());
  646. ASSERT_TRUE(ge_tensors[i]->GetSize() == me_inputs[i]->data().nbytes());
  647. ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
  648. ASSERT_TRUE(ge_tensors[i]->GetTensorDesc().GetShape().GetDims() ==
  649. TransformUtil::ConvertMeShape(me_inputs[i]->shape_c()).GetDims());
  650. }
  651. }
  652. TEST_F(TestConvert, TestConvertGeTensors) {
  653. #define DTYPE float
  654. ge::DataType dt = ge::DataType::DT_FLOAT;
  655. std::vector<float> data1(16);
  656. std::vector<float> data2(120);
  657. std::vector<float> data3(81);
  658. ge::Shape shape1({1, 1, 4, 4});
  659. ge::Shape shape2({2, 3, 4, 5});
  660. ge::Shape shape3({9, 9, 1, 1});
  661. ge::Format format = ge::Format::FORMAT_NCHW;
  662. ge::TensorDesc desc1(shape1, format, dt);
  663. ge::TensorDesc desc2(shape2, format, dt);
  664. ge::TensorDesc desc3(shape3, format, dt);
  665. GeTensorPtr ge_tensor_ptr1 =
  666. std::make_shared<GeTensor>(desc1, reinterpret_cast<uint8_t*>(data1.data()), data1.size() * sizeof(DTYPE));
  667. GeTensorPtr ge_tensor_ptr2 =
  668. std::make_shared<GeTensor>(desc2, reinterpret_cast<uint8_t*>(data2.data()), data2.size() * sizeof(DTYPE));
  669. GeTensorPtr ge_tensor_ptr3 =
  670. std::make_shared<GeTensor>(desc3, reinterpret_cast<uint8_t*>(data3.data()), data3.size() * sizeof(DTYPE));
  671. std::vector<GeTensorPtr> ge_tensors;
  672. ge_tensors.emplace_back(ge_tensor_ptr1);
  673. ge_tensors.emplace_back(ge_tensor_ptr2);
  674. ge_tensors.emplace_back(ge_tensor_ptr3);
  675. std::vector<std::vector<int64_t>> request_dims;
  676. std::vector<int64_t> dims1 = {1, 1, 4, 4};
  677. std::vector<int64_t> dims2 = {2, 3, 4, 5};
  678. std::vector<int64_t> dims3 = {9, 9, 1, 1};
  679. request_dims.emplace_back(dims1);
  680. request_dims.emplace_back(dims2);
  681. request_dims.emplace_back(dims3);
  682. std::vector<MeTensorPtr> me_outputs = TransformUtil::ConvertGeTensors(ge_tensors, request_dims);
  683. for (int i = 0; i < ge_tensors.size(); i++) {
  684. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_outputs[i]->data_c());
  685. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensors[i]->GetData());
  686. ASSERT_TRUE(ge_tensors[i]->GetSize() == me_outputs[i]->data().nbytes());
  687. ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
  688. ASSERT_TRUE(request_dims[i] == me_outputs[i]->shape_c());
  689. }
  690. }
  691. TEST_F(TestConvert, TestConvertGeShape1) {
  692. GeShape ge_shape({10, 1, 1, 1});
  693. std::vector<int64_t> request_dims{10};
  694. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  695. }
  696. TEST_F(TestConvert, TestConvertGeShape2) {
  697. GeShape ge_shape({10, 15, 1, 1});
  698. std::vector<int64_t> request_dims{10, 15};
  699. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  700. }
  701. TEST_F(TestConvert, TestConvertGeShape3) {
  702. GeShape ge_shape({10, 13, 18, 1});
  703. std::vector<int64_t> request_dims{10, 13, 18};
  704. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  705. }
  706. TEST_F(TestConvert, TestConvertGeShape4) {
  707. GeShape ge_shape({1, 10, 1, 1});
  708. std::vector<int64_t> request_dims{10};
  709. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  710. }
  711. TEST_F(TestConvert, TestConvertGeShape5) {
  712. GeShape ge_shape({10, 1, 1, 2});
  713. std::vector<int64_t> request_dims{10};
  714. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  715. }
  716. TEST_F(TestConvert, TestConvertGeShape6) {
  717. GeShape ge_shape({5, 2, 1, 1});
  718. std::vector<int64_t> request_dims{10};
  719. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  720. }
  721. TEST_F(TestConvert, TestConvertGeShape7) {
  722. GeShape ge_shape({10});
  723. std::vector<int64_t> request_dims{10, 1};
  724. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  725. }
  726. } // namespace transform
  727. } // namespace mindspore