You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_test.cc 33 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <unordered_map>
  18. #include "pybind11/pybind11.h"
  19. #include "transform/transform_base_test.h"
  20. #include "common/py_func_graph_fetcher.h"
  21. #include "pipeline/jit/parse/parse.h"
  22. #include "debug/draw.h"
  23. #include "debug/anf_ir_dump.h"
  24. #include "pipeline/jit/static_analysis/prim.h"
  25. #include "frontend/operator/ops.h"
  26. #include "common/common_test.h"
  27. #define private public
  28. #include "transform/graph_ir/types.h"
  29. #include "transform/graph_ir/convert.h"
  30. #include "securec/include/securec.h"
  31. #include "utils/utils.h"
  32. using std::cout;
  33. using std::endl;
  34. using std::string;
  35. using std::unordered_map;
  36. namespace mindspore {
  37. namespace transform {
  38. using AbstractScalar = abstract::AbstractScalar;
  39. using mindspore::parse::ResolveAll;
  40. class TestConvert : public UT::Common {
  41. public:
  42. TestConvert() {}
  43. virtual void SetUp();
  44. virtual void TearDown();
  45. static const std::shared_ptr<Float> kF32;
  46. };
  47. void TestConvert::SetUp() { UT::InitPythonPath(); }
  48. void TestConvert::TearDown() {}
  49. const std::shared_ptr<Float> TestConvert::kF32 = std::make_shared<Float>(32);
  50. AnfGraphPtr createAnfGraph() { return std::make_shared<AnfGraph>(); }
  51. TEST_F(TestConvert, TestConstruct) {
  52. AnfGraphPtr func_graph = std::make_shared<AnfGraph>();
  53. DfGraphConvertor convertor(func_graph);
  54. convertor.ConvertAllNode().GetComputeGraph();
  55. ASSERT_NE(convertor.ErrCode(), SUCCESS);
  56. }
  57. #if (!defined ENABLE_GE)
  58. namespace {
  59. bool MakeDfGraph(PrimitivePtr prim, unsigned int nparam) {
  60. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, nparam);
  61. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  62. draw::Draw("ut_prim_" + prim->name() + ".dot", anf_graph);
  63. DumpIR("ut_prim_" + prim->name() + ".ir", anf_graph);
  64. DfGraphConvertor convertor(anf_graph);
  65. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  66. convertor.DrawComputeGraph(prim->name() + ".dot");
  67. if (convertor.ErrCode() != 0) {
  68. MS_LOG(ERROR) << "DfGraphConvertor convert " << prim->name() << " error, error code is: " << convertor.ErrCode();
  69. return false;
  70. }
  71. if (df_graph == nullptr) {
  72. MS_LOG(ERROR) << "DfGraphConvertor get " << prim->name() << " compute func_graph failed";
  73. return false;
  74. }
  75. return true;
  76. }
  77. } // namespace
  78. TEST_F(TestConvert, TestConvertConv2d) {
  79. PrimitivePtr conv2d = prim::kPrimConv2D;
  80. conv2d->AddAttr("stride", MakeValue(2));
  81. conv2d->AddAttr("pad", MakeValue(0));
  82. conv2d->AddAttr("dilation", MakeValue(0));
  83. FuncGraphPtr anf_graph = MakeFuncGraph(conv2d, 2);
  84. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  85. draw::Draw("ut_prim_conv2d1.dot", anf_graph);
  86. DumpIR("ut_prim_conv2d1.ir", anf_graph);
  87. DfGraphConvertor convertor(anf_graph);
  88. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  89. convertor.DrawComputeGraph("conv2d.dot");
  90. ASSERT_EQ(convertor.ErrCode(), 0);
  91. ASSERT_NE(df_graph, nullptr);
  92. }
  93. TEST_F(TestConvert, TestConvertMaxpooling) {
  94. auto prim = std::make_shared<Primitive>("MaxPool");
  95. FuncGraphPtr anf_graph = MakeFuncGraph(prim, 5); // ary, ksize, stride, padding, data_format
  96. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  97. draw::Draw("ut_prim_maxpooling.dot", anf_graph);
  98. DumpIR("ut_prim_maxpooling.ir", anf_graph);
  99. DfGraphConvertor convertor(anf_graph);
  100. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  101. convertor.DrawComputeGraph("maxpooling.dot");
  102. ASSERT_EQ(convertor.ErrCode(), 0);
  103. ASSERT_NE(df_graph, nullptr);
  104. }
  105. TEST_F(TestConvert, TestReluOps) {
  106. auto prim = prim::kPrimRelu;
  107. prim->AddAttr("T", MakeValue(0));
  108. auto func_graph = MakeFuncGraph(prim, 1);
  109. ASSERT_TRUE(nullptr != func_graph);
  110. // save the func_graph to manager
  111. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  112. // call resolve
  113. bool ret_ = ResolveAll(manager);
  114. ASSERT_TRUE(ret_);
  115. // draw graph
  116. auto anfGraph = *(manager->func_graphs().begin());
  117. DfGraphConvertor convertor(anfGraph);
  118. convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  119. ASSERT_EQ(convertor.ErrCode(), 0);
  120. }
  121. TEST_F(TestConvert, TestConvertBatchNorm) {
  122. PrimitivePtr batch_norm = prim::kPrimBatchNorm;
  123. batch_norm->AddAttr("epsilon", MakeValue(0.001f));
  124. batch_norm->AddAttr("momentum", MakeValue(0.1f));
  125. FuncGraphPtr anf_graph = std::make_shared<FuncGraph>();
  126. std::vector<AnfNodePtr> inputs;
  127. inputs.push_back(NewValueNode(batch_norm));
  128. for (unsigned int i = 0; i < 5; i++) {
  129. inputs.push_back(anf_graph->add_parameter());
  130. }
  131. CNodePtr cnode_prim = anf_graph->NewCNode(inputs);
  132. inputs.clear();
  133. inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
  134. inputs.push_back(cnode_prim);
  135. inputs.push_back(NewValueNode(2));
  136. CNodePtr cnode_getitem = anf_graph->NewCNode(inputs);
  137. inputs.clear();
  138. inputs.push_back(NewValueNode(prim::kPrimRelu));
  139. inputs.push_back(cnode_getitem);
  140. CNodePtr cnode_relu = anf_graph->NewCNode(inputs);
  141. inputs.clear();
  142. inputs.push_back(NewValueNode(std::make_shared<Primitive>("return")));
  143. inputs.push_back(cnode_relu);
  144. CNodePtr cnode_return = anf_graph->NewCNode(inputs);
  145. anf_graph->set_return(cnode_return);
  146. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  147. draw::Draw("ut_prim_batchnorm.dot", anf_graph);
  148. DumpIR("ut_prim_batchnorm.ir", anf_graph);
  149. DfGraphConvertor convertor(anf_graph);
  150. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  151. convertor.DrawComputeGraph("batchnrom.dot");
  152. ASSERT_EQ(convertor.ErrCode(), 0);
  153. ASSERT_NE(df_graph, nullptr);
  154. }
  155. TEST_F(TestConvert, TestConvertConvBackpropInput) {
  156. auto prim = prim::kPrimConv2DBackpropInput;
  157. const std::vector<int> list{1,1};
  158. prim->AddAttr("stride", MakeValue(list));
  159. prim->AddAttr("pad", MakeValue(0));
  160. prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  161. prim->AddAttr("dilation", MakeValue(1));
  162. prim->AddAttr("group", MakeValue(1));
  163. prim->AddAttr("mode", MakeValue(1));
  164. prim->AddAttr("dilation", MakeValue(1));
  165. auto func_graph = MakeFuncGraph(prim, 3);
  166. ASSERT_NE(func_graph, nullptr);
  167. // save the func_graph to manager
  168. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  169. // call resolve
  170. bool ret_ = ResolveAll(manager);
  171. ASSERT_TRUE(ret_);
  172. // draw graph
  173. auto anf_graph = *(manager->func_graphs().begin());
  174. DfGraphConvertor convertor(anf_graph);
  175. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  176. convertor.DrawComputeGraph("Conv2DBackpropInput.dot");
  177. ASSERT_EQ(convertor.ErrCode(), 0);
  178. ASSERT_NE(df_graph, nullptr);
  179. }
  180. TEST_F(TestConvert, TestConvertConvBackpropFilter) {
  181. auto prim = prim::kPrimConv2DBackpropFilter;
  182. const std::vector<int> list{1,1};
  183. prim->AddAttr("stride", MakeValue(list));
  184. prim->AddAttr("pad", MakeValue(0));
  185. prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  186. prim->AddAttr("dilation", MakeValue(1));
  187. prim->AddAttr("group", MakeValue(1));
  188. prim->AddAttr("mode", MakeValue(1));
  189. prim->AddAttr("dilation", MakeValue(1));
  190. auto func_graph = MakeFuncGraph(prim, 3);
  191. ASSERT_NE(func_graph, nullptr);
  192. // save the func_graph to manager
  193. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  194. // call resolve
  195. bool ret_ = ResolveAll(manager);
  196. ASSERT_TRUE(ret_);
  197. // draw graph
  198. auto anf_graph = *(manager->func_graphs().begin());
  199. DfGraphConvertor convertor(anf_graph);
  200. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  201. convertor.DrawComputeGraph("Conv2DBackpropFilter.dot");
  202. ASSERT_EQ(convertor.ErrCode(), 0);
  203. ASSERT_NE(df_graph, nullptr);
  204. }
  205. TEST_F(TestConvert, TestConvertReluGrad) {
  206. auto prim = prim::kPrimReluGrad;
  207. prim->AddAttr("alpha", MakeValue(0.1f));
  208. prim->AddAttr("beta", MakeValue(0.1f));
  209. prim->AddAttr("mode", MakeValue(1));
  210. auto func_graph = MakeFuncGraph(prim, 2);
  211. ASSERT_NE(func_graph, nullptr);
  212. // save the func_graph to manager
  213. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  214. // call resolve
  215. bool ret_ = ResolveAll(manager);
  216. ASSERT_TRUE(ret_);
  217. // draw graph
  218. auto anf_graph = *(manager->func_graphs().begin());
  219. DfGraphConvertor convertor(anf_graph);
  220. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  221. convertor.DrawComputeGraph("ReluGrad.dot");
  222. ASSERT_EQ(convertor.ErrCode(), 0);
  223. ASSERT_NE(df_graph, nullptr);
  224. }
  225. TEST_F(TestConvert, TestConvertBiasAdd) {
  226. auto prim = std::make_shared<Primitive>("BiasAdd");
  227. prim->AddAttr("alpha", MakeValue(0.0f));
  228. prim->AddAttr("beta", MakeValue(1.0f));
  229. prim->AddAttr("format", MakeValue(1));
  230. auto func_graph = MakeFuncGraph(prim, 2);
  231. ASSERT_NE(func_graph, nullptr);
  232. // save the func_graph to manager
  233. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  234. // call resolve
  235. bool ret_ = ResolveAll(manager);
  236. ASSERT_TRUE(ret_);
  237. // draw graph
  238. auto anf_graph = *(manager->func_graphs().begin());
  239. DfGraphConvertor convertor(anf_graph);
  240. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  241. convertor.DrawComputeGraph("BiasAdd.dot");
  242. ASSERT_EQ(convertor.ErrCode(), 0);
  243. ASSERT_NE(df_graph, nullptr);
  244. }
  245. TEST_F(TestConvert, TestConvertBiasAddGrad) {
  246. auto prim = prim::kPrimBiasAddGrad;
  247. prim->AddAttr("alpha", MakeValue(0.0f));
  248. prim->AddAttr("beta", MakeValue(1.0f));
  249. prim->AddAttr("format", MakeValue(1));
  250. auto func_graph = MakeFuncGraph(prim, 2);
  251. ASSERT_NE(func_graph, nullptr);
  252. // save the func_graph to manager
  253. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  254. // call resolve
  255. bool ret_ = ResolveAll(manager);
  256. ASSERT_TRUE(ret_);
  257. // draw graph
  258. auto anf_graph = *(manager->func_graphs().begin());
  259. DfGraphConvertor convertor(anf_graph);
  260. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  261. convertor.DrawComputeGraph("BiasAddGrad.dot");
  262. ASSERT_EQ(convertor.ErrCode(), 0);
  263. ASSERT_NE(df_graph, nullptr);
  264. }
  265. TEST_F(TestConvert, TestConvertMaxPoolGradWithArgmax) {
  266. auto prim = std::make_shared<Primitive>("MaxPoolGradWithArgmax");
  267. prim->AddAttr("alpha", MakeValue(0.0f));
  268. prim->AddAttr("beta", MakeValue(1.0f));
  269. prim->AddAttr("window", MakeValue(2));
  270. prim->AddAttr("stride", MakeValue(1));
  271. prim->AddAttr("ceil_mode", MakeValue(0));
  272. prim->AddAttr("data_mode", MakeValue(0));
  273. prim->AddAttr("alpha", MakeValue(0.1f));
  274. prim->AddAttr("beta", MakeValue(1.0f));
  275. auto func_graph = MakeFuncGraph(prim, 2);
  276. ASSERT_NE(func_graph, nullptr);
  277. // save the func_graph to manager
  278. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  279. // call resolve
  280. bool ret_ = ResolveAll(manager);
  281. ASSERT_TRUE(ret_);
  282. // draw graph
  283. auto anf_graph = *(manager->func_graphs().begin());
  284. DfGraphConvertor convertor(anf_graph);
  285. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  286. convertor.DrawComputeGraph("MaxPoolGradWithArgmax.dot");
  287. ASSERT_EQ(convertor.ErrCode(), 0);
  288. ASSERT_NE(df_graph, nullptr);
  289. }
  290. TEST_F(TestConvert, TestConcat) {
  291. auto prim = prim::kPrimConcat;
  292. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  293. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  294. draw::Draw("ut_prim_concat.dot", anf_graph);
  295. DumpIR("ut_prim_concat.ir", anf_graph);
  296. DfGraphConvertor convertor(anf_graph);
  297. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  298. convertor.DrawComputeGraph("concat.dot");
  299. ASSERT_EQ(convertor.ErrCode(), 0);
  300. ASSERT_NE(df_graph, nullptr);
  301. }
  302. TEST_F(TestConvert, TestGatherV2) {
  303. auto prim = prim::kPrimGatherV2;
  304. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 3);
  305. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  306. draw::Draw("ut_prim_gatherv2.dot", anf_graph);
  307. DumpIR("ut_prim_gatherv2.ir", anf_graph);
  308. DfGraphConvertor convertor(anf_graph);
  309. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  310. convertor.DrawComputeGraph("gatherv2.dot");
  311. ASSERT_EQ(convertor.ErrCode(), 0);
  312. ASSERT_NE(df_graph, nullptr);
  313. }
  314. TEST_F(TestConvert, TestCast) {
  315. auto prim = prim::kPrimCast;
  316. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  317. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  318. draw::Draw("ut_prim_cast.dot", anf_graph);
  319. DumpIR("ut_prim_cast.ir", anf_graph);
  320. DfGraphConvertor convertor(anf_graph);
  321. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  322. convertor.DrawComputeGraph("cast.dot");
  323. ASSERT_EQ(convertor.ErrCode(), 0);
  324. ASSERT_NE(df_graph, nullptr);
  325. }
  326. TEST_F(TestConvert, TestExp) {
  327. auto prim = std::make_shared<Primitive>("Exp");
  328. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  329. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  330. draw::Draw("ut_prim_exp.dot", anf_graph);
  331. DumpIR("ut_prim_exp.ir", anf_graph);
  332. DfGraphConvertor convertor(anf_graph);
  333. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  334. convertor.DrawComputeGraph("exp.dot");
  335. ASSERT_EQ(convertor.ErrCode(), 0);
  336. ASSERT_NE(df_graph, nullptr);
  337. }
  338. TEST_F(TestConvert, TestFloor) {
  339. auto prim = std::make_shared<Primitive>("Floor");
  340. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  341. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  342. draw::Draw("ut_prim_floor.dot", anf_graph);
  343. DumpIR("ut_prim_floor.ir", anf_graph);
  344. DfGraphConvertor convertor(anf_graph);
  345. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  346. convertor.DrawComputeGraph("floor.dot");
  347. ASSERT_EQ(convertor.ErrCode(), 0);
  348. ASSERT_NE(df_graph, nullptr);
  349. }
  350. TEST_F(TestConvert, TestGreaterEqual) {
  351. auto prim = std::make_shared<Primitive>("GreaterEqual");
  352. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  353. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  354. draw::Draw("ut_prim_greater_equal.dot", anf_graph);
  355. DumpIR("ut_prim_greater_equal.ir", anf_graph);
  356. DfGraphConvertor convertor(anf_graph);
  357. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  358. convertor.DrawComputeGraph("greater_equal.dot");
  359. ASSERT_EQ(convertor.ErrCode(), 0);
  360. ASSERT_NE(df_graph, nullptr);
  361. }
  362. TEST_F(TestConvert, TestLess) {
  363. auto prim = std::make_shared<Primitive>("Less");
  364. prim->AddAttr("T", MakeValue(kFloat32));
  365. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  366. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  367. draw::Draw("ut_prim_less.dot", anf_graph);
  368. DumpIR("ut_prim_less.ir", anf_graph);
  369. DfGraphConvertor convertor(anf_graph);
  370. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  371. convertor.DrawComputeGraph("less.dot");
  372. ASSERT_EQ(convertor.ErrCode(), 0);
  373. ASSERT_NE(df_graph, nullptr);
  374. }
  375. TEST_F(TestConvert, TestLessEqual) {
  376. auto prim = std::make_shared<Primitive>("LessEqual");
  377. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  378. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  379. draw::Draw("ut_prim_less_equal.dot", anf_graph);
  380. DumpIR("ut_prim_less_equal.ir", anf_graph);
  381. DfGraphConvertor convertor(anf_graph);
  382. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  383. convertor.DrawComputeGraph("less_equal.dot");
  384. ASSERT_EQ(convertor.ErrCode(), 0);
  385. ASSERT_NE(df_graph, nullptr);
  386. }
  387. TEST_F(TestConvert, TestLogicalNot) {
  388. auto prim = std::make_shared<Primitive>("LogicalNot");
  389. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  390. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  391. draw::Draw("ut_prim_logical_not.dot", anf_graph);
  392. DumpIR("ut_prim_logical_not.ir", anf_graph);
  393. DfGraphConvertor convertor(anf_graph);
  394. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  395. convertor.DrawComputeGraph("logical_not.dot");
  396. ASSERT_EQ(convertor.ErrCode(), 0);
  397. ASSERT_NE(df_graph, nullptr);
  398. }
  399. TEST_F(TestConvert, TestAssignAdd) {
  400. auto prim = prim::kPrimAssignAdd;
  401. prim->AddAttr("use_locking", MakeValue(true));
  402. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  403. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  404. draw::Draw("ut_prim_assign_add.dot", anf_graph);
  405. DumpIR("ut_prim_assign_add.ir", anf_graph);
  406. DfGraphConvertor convertor(anf_graph);
  407. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  408. convertor.DrawComputeGraph("assign_add.dot");
  409. ASSERT_EQ(convertor.ErrCode(), 0);
  410. ASSERT_NE(df_graph, nullptr);
  411. }
  412. TEST_F(TestConvert, LogSoftmax) {
  413. auto prim = prim::kPrimLogSoftmax;
  414. prim->AddAttr("axis", MakeValue(0));
  415. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  416. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  417. draw::Draw("ut_prim_log_softmax.dot", anf_graph);
  418. DumpIR("ut_prim_log_softmax.ir", anf_graph);
  419. DfGraphConvertor convertor(anf_graph);
  420. auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  421. convertor.DrawComputeGraph("log_softmax.dot");
  422. ASSERT_EQ(convertor.ErrCode(), 0);
  423. ASSERT_NE(df_graph, nullptr);
  424. }
  425. TEST_F(TestConvert, TestMaximumOps) {
  426. auto prim = prim::kPrimMaximum;
  427. bool ret = MakeDfGraph(prim, 2);
  428. ASSERT_TRUE(ret);
  429. }
  430. TEST_F(TestConvert, TestReduceMeanOps) {
  431. auto prim = prim::kPrimReduceMean;
  432. prim->AddAttr("keepdims", MakeValue(true));
  433. bool ret = MakeDfGraph(prim, 2);
  434. ASSERT_TRUE(ret);
  435. }
  436. TEST_F(TestConvert, TestMinimumOps) {
  437. auto prim = prim::kPrimMinimum;
  438. bool ret = MakeDfGraph(prim, 2);
  439. ASSERT_TRUE(ret);
  440. }
  441. TEST_F(TestConvert, TestFusedMinOrMaxGradOps) {
  442. // Add infer step to this test case
  443. ASSERT_TRUE(true);
  444. }
  445. TEST_F(TestConvert, TestSqueezeOps) {
  446. auto prim = prim::kPrimSqueeze;
  447. bool ret = MakeDfGraph(prim, 2);
  448. ASSERT_TRUE(ret);
  449. }
  450. TEST_F(TestConvert, TestMulOps) {
  451. auto prim = prim::kPrimMul;
  452. bool ret = MakeDfGraph(prim, 2);
  453. ASSERT_TRUE(ret);
  454. }
  455. TEST_F(TestConvert, TestNegOps) {
  456. auto prim = prim::kPrimNeg;
  457. bool ret = MakeDfGraph(prim, 1);
  458. ASSERT_TRUE(ret);
  459. }
  460. TEST_F(TestConvert, TestOneHotOps) {
  461. auto prim = prim::kPrimOneHot;
  462. prim->AddAttr("axis", MakeValue(0));
  463. bool ret = MakeDfGraph(prim, 4);
  464. ASSERT_TRUE(ret);
  465. }
  466. TEST_F(TestConvert, TestPowOps) {
  467. auto prim = std::make_shared<Primitive>("Pow");
  468. bool ret = MakeDfGraph(prim, 2);
  469. ASSERT_TRUE(ret);
  470. }
  471. TEST_F(TestConvert, TestReciprocalOps) {
  472. auto prim = std::make_shared<Primitive>("Reciprocal");
  473. bool ret = MakeDfGraph(prim, 1);
  474. ASSERT_TRUE(ret);
  475. }
  476. TEST_F(TestConvert, TestSelectOps) {
  477. auto prim = prim::kPrimSelect;
  478. bool ret = MakeDfGraph(prim, 3);
  479. ASSERT_TRUE(ret);
  480. }
  481. TEST_F(TestConvert, TestSqrtOps) {
  482. auto prim = std::make_shared<Primitive>("Sqrt");
  483. bool ret = MakeDfGraph(prim, 1);
  484. ASSERT_TRUE(ret);
  485. }
  486. TEST_F(TestConvert, TestSquareOps) {
  487. auto prim = std::make_shared<Primitive>("Square");
  488. bool ret = MakeDfGraph(prim, 1);
  489. ASSERT_TRUE(ret);
  490. }
  491. TEST_F(TestConvert, TestScalarSummaryOps) {
  492. auto prim = prim::kPrimScalarSummary;
  493. // should have only 1 input.
  494. bool ret = MakeDfGraph(prim, 2);
  495. ASSERT_TRUE(ret);
  496. }
  497. TEST_F(TestConvert, TestTensorSummaryOps) {
  498. auto prim = prim::kPrimTensorSummary;
  499. bool ret = MakeDfGraph(prim, 2);
  500. ASSERT_TRUE(ret);
  501. }
  502. TEST_F(TestConvert, TestHistogramSummaryOps) {
  503. auto prim = prim::kPrimHistogramSummary;
  504. bool ret = MakeDfGraph(prim, 2);
  505. ASSERT_TRUE(ret);
  506. }
  507. TEST_F(TestConvert, TestGreaterOps) {
  508. auto prim = std::make_shared<Primitive>("Greater");
  509. bool ret = MakeDfGraph(prim, 2);
  510. ASSERT_TRUE(ret);
  511. }
  512. TEST_F(TestConvert, TestEqualOps) {
  513. auto prim = std::make_shared<Primitive>("Equal");
  514. bool ret = MakeDfGraph(prim, 2);
  515. ASSERT_TRUE(ret);
  516. }
  517. TEST_F(TestConvert, TestArgMaxiOps) {
  518. auto prim = std::make_shared<Primitive>("Argmax");
  519. bool ret = MakeDfGraph(prim, 2);
  520. ASSERT_TRUE(ret);
  521. }
  522. TEST_F(TestConvert, TestResizeNearestNeighborOps) {
  523. auto prim = std::make_shared<Primitive>("ResizeNearestNeighbor");
  524. bool ret = MakeDfGraph(prim, 1);
  525. ASSERT_TRUE(ret);
  526. }
  527. TEST_F(TestConvert, TestApplyMomentumOps) {
  528. auto prim = std::make_shared<Primitive>("ApplyMomentum");
  529. bool ret = MakeDfGraph(prim, 5);
  530. ASSERT_TRUE(ret);
  531. }
  532. TEST_F(TestConvert, TestNPUGetFloatStatusOps) {
  533. auto prim = std::make_shared<Primitive>("NPUGetFloatStatus");
  534. bool ret = MakeDfGraph(prim, 1);
  535. ASSERT_TRUE(ret);
  536. }
  537. TEST_F(TestConvert, TestNPUAllocFloatStatusOps) {
  538. auto prim = std::make_shared<Primitive>("NPUAllocFloatStatus");
  539. bool ret = MakeDfGraph(prim, 0);
  540. ASSERT_TRUE(ret);
  541. }
  542. TEST_F(TestConvert, TestNPUClearFloatStatusOps) {
  543. auto prim = std::make_shared<Primitive>("NPUClearFloatStatus");
  544. bool ret = MakeDfGraph(prim, 1);
  545. ASSERT_TRUE(ret);
  546. }
  547. #endif
  548. TEST_F(TestConvert, TestAddOps) {
  549. auto prim = std::make_shared<Primitive>("TensorAdd");
  550. auto func_graph = MakeFuncGraph(prim, 2);
  551. ASSERT_TRUE(nullptr != func_graph);
  552. // save the func_graph to manager
  553. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  554. // call resolve
  555. bool ret_ = ResolveAll(manager);
  556. ASSERT_TRUE(ret_);
  557. // draw graph
  558. auto anfGraph = *(manager->func_graphs().begin());
  559. DfGraphConvertor convertor(anfGraph);
  560. convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  561. ASSERT_EQ(convertor.ErrCode(), 0);
  562. }
  563. TEST_F(TestConvert, TestConvertTensor) {
  564. float data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  565. // Create a tensor with wanted data type and shape
  566. std::vector<int> dims{2, 2, 3};
  567. std::vector<int64_t> ge_dims{2, 2, 3};
  568. auto type_id = kNumberTypeFloat32;
  569. MeTensor me_tensor(type_id, dims);
  570. // Get the writable data pointer of the tensor and cast it to its data type
  571. uint8_t* me_data_ptr = reinterpret_cast<uint8_t*>(me_tensor.data_c());
  572. // Copy or use the writable data pointer of the ME tensor
  573. memcpy_s(me_data_ptr, me_tensor.data().nbytes(), data, 12 * sizeof(float));
  574. auto me_tensor_ptr = std::make_shared<MeTensor>(me_tensor);
  575. auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW);
  576. ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetFormat(), GeFormat::FORMAT_NCHW);
  577. ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetDataType(), GeDataType::DT_FLOAT);
  578. // ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().array().GetDims(), ge_dims);
  579. int i = 0;
  580. for (i = 0; i < ge_dims.size(); i++) {
  581. ASSERT_EQ(ge_dims[i], ge_tensor_ptr->GetTensorDesc().GetShape().GetDims()[i]);
  582. }
  583. for (i = 0; i < ge_tensor_ptr->GetTensorDesc().GetShape().GetShapeSize(); i++) {
  584. ASSERT_EQ(data[i], (reinterpret_cast<float*>(ge_tensor_ptr->GetData()))[i]);
  585. }
  586. }
  587. TEST_F(TestConvert, TestConvertTensor0Dims) {
  588. // shape with 0 dims is also valid
  589. std::vector<int> dims{};
  590. auto type_id = kNumberTypeFloat32;
  591. auto me_tensor_ptr = std::make_shared<MeTensor>(type_id, dims);
  592. ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW), nullptr);
  593. }
  594. TEST_F(TestConvert, TestConvertTensorError) {
  595. std::vector<int> dims2{2, 3, 4};
  596. auto type_id_2 = kNumberTypeFloat32;
  597. auto me_tensor_ptr_2 = std::make_shared<MeTensor>(type_id_2, dims2);
  598. ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
  599. }
  600. TEST_F(TestConvert, TestUtilsConvertDataType) {
  601. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat16), GeDataType::DT_FLOAT16);
  602. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat32), GeDataType::DT_FLOAT);
  603. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat64), GeDataType::DT_DOUBLE);
  604. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt8), GeDataType::DT_INT8);
  605. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt16), GeDataType::DT_INT16);
  606. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt32), GeDataType::DT_INT32);
  607. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt64), GeDataType::DT_INT64);
  608. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeUInt32), GeDataType::DT_UINT32);
  609. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeBool), GeDataType::DT_BOOL);
  610. }
  611. TEST_F(TestConvert, TestUtilsConvertFormat) {
  612. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NCHW), GeFormat::FORMAT_NCHW);
  613. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NC1HWC0), GeFormat::FORMAT_NC1HWC0);
  614. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NHWC), GeFormat::FORMAT_NHWC);
  615. ASSERT_EQ(TransformUtil::ConvertFormat("xyz"), GeFormat::FORMAT_ND);
  616. }
  617. TEST_F(TestConvert, TestUtilsDataSize) {
  618. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat32), 4);
  619. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat16), 2);
  620. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat64), 8);
  621. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt8), 1);
  622. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt16), 2);
  623. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt32), 4);
  624. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt64), 8);
  625. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeUInt32), 4);
  626. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeBool), 1);
  627. }
  628. TEST_F(TestConvert, TestConvertGeTensor) {
  629. #define DTYPE float
  630. ge::DataType dt = ge::DataType::DT_FLOAT;
  631. std::vector<float> data1 = {1.1, 2.2, 3.3, 4.4, 6.6, 7.7, 8.8, 9.9};
  632. std::vector<DTYPE> data2 = {1, 2, 3, 4, 6, 7, 8, 9};
  633. auto data = data1;
  634. ge::Shape shape({2, 2, 2});
  635. ge::Format format = ge::Format::FORMAT_NCHW;
  636. ge::TensorDesc desc(shape, format, dt);
  637. GeTensorPtr ge_tensor_ptr =
  638. std::make_shared<GeTensor>(desc, reinterpret_cast<uint8_t*>(data.data()), data.size() * sizeof(DTYPE));
  639. GeTensor& ge_tensor = *ge_tensor_ptr;
  640. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensor.GetData());
  641. // make sure GetData()'s return is a reference
  642. assert(ge_data == reinterpret_cast<DTYPE*>(ge_tensor.GetData()));
  643. cout << "ge data size is: " << std::dec << ge_tensor.GetSize() << " bytes" << endl;
  644. for (int i = 0; i < ge_tensor.GetSize() / sizeof(DTYPE); i++) {
  645. cout << "ge data is: " << static_cast<DTYPE>(*(ge_data + i)) << endl;
  646. }
  647. MeTensorPtr me_tensor_ptr = TransformUtil::ConvertGeTensor(ge_tensor_ptr);
  648. MeTensor& me_tensor = *me_tensor_ptr;
  649. cout << "after convert ge tensor to me tensor" << endl;
  650. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_tensor.data_c());
  651. PrintMeTensor(&me_tensor);
  652. assert(ge_tensor.GetSize() == me_tensor.data().nbytes());
  653. assert(memcmp(ge_data, me_data, ge_tensor.GetSize()) == 0);
  654. }
  655. TEST_F(TestConvert, TestConvertMakeTuple) {
  656. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  657. std::vector<AnfNodePtr> inputs;
  658. inputs.push_back(NewValueNode(std::make_shared<Primitive>("make_tuple")));
  659. for (int i = 0; i < 3; i++) {
  660. auto input = func_graph->add_parameter();
  661. input->set_name("x" + std::to_string(i));
  662. inputs.push_back(input);
  663. }
  664. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  665. inputs.clear();
  666. inputs.push_back(NewValueNode(std::make_shared<Primitive>("return")));
  667. inputs.push_back(cnode_prim);
  668. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  669. func_graph->set_return(cnode_return);
  670. // save the func_graph to manager
  671. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  672. // call resolve
  673. bool ret_ = ResolveAll(manager);
  674. ASSERT_TRUE(ret_);
  675. // draw graph
  676. auto anfGraph = *(manager->func_graphs().begin());
  677. DfGraphConvertor convertor(anfGraph);
  678. convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  679. ASSERT_EQ(convertor.ErrCode(), 0);
  680. }
  681. TEST_F(TestConvert, TestConvertInputTensors) {
  682. #define DTYPE float
  683. MeTensorPtr input_ptr1 = MakeTensor(kF32, {1, 1, 4, 4});
  684. MeTensorPtr input_ptr2 = MakeTensor(kF32, {2, 3, 4, 5});
  685. MeTensorPtr input_ptr3 = MakeTensor(kF32, {9, 9, 1, 1});
  686. std::vector<MeTensorPtr> me_inputs;
  687. me_inputs.emplace_back(input_ptr1);
  688. me_inputs.emplace_back(input_ptr2);
  689. me_inputs.emplace_back(input_ptr3);
  690. std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(me_inputs, kOpFormat_NCHW);
  691. for (int i = 0; i < ge_tensors.size(); i++) {
  692. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_inputs[i]->data_c());
  693. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensors[i]->GetData());
  694. ASSERT_TRUE(ge_tensors[i]->GetSize() == me_inputs[i]->data().nbytes());
  695. ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
  696. ASSERT_TRUE(ge_tensors[i]->GetTensorDesc().GetShape().GetDims() ==
  697. TransformUtil::ConvertMeShape(me_inputs[i]->shape_c()).GetDims());
  698. }
  699. }
  700. TEST_F(TestConvert, TestConvertGeTensors) {
  701. #define DTYPE float
  702. ge::DataType dt = ge::DataType::DT_FLOAT;
  703. std::vector<float> data1(16);
  704. std::vector<float> data2(120);
  705. std::vector<float> data3(81);
  706. ge::Shape shape1({1, 1, 4, 4});
  707. ge::Shape shape2({2, 3, 4, 5});
  708. ge::Shape shape3({9, 9, 1, 1});
  709. ge::Format format = ge::Format::FORMAT_NCHW;
  710. ge::TensorDesc desc1(shape1, format, dt);
  711. ge::TensorDesc desc2(shape2, format, dt);
  712. ge::TensorDesc desc3(shape3, format, dt);
  713. GeTensorPtr ge_tensor_ptr1 =
  714. std::make_shared<GeTensor>(desc1, reinterpret_cast<uint8_t*>(data1.data()), data1.size() * sizeof(DTYPE));
  715. GeTensorPtr ge_tensor_ptr2 =
  716. std::make_shared<GeTensor>(desc2, reinterpret_cast<uint8_t*>(data2.data()), data2.size() * sizeof(DTYPE));
  717. GeTensorPtr ge_tensor_ptr3 =
  718. std::make_shared<GeTensor>(desc3, reinterpret_cast<uint8_t*>(data3.data()), data3.size() * sizeof(DTYPE));
  719. std::vector<GeTensorPtr> ge_tensors;
  720. ge_tensors.emplace_back(ge_tensor_ptr1);
  721. ge_tensors.emplace_back(ge_tensor_ptr2);
  722. ge_tensors.emplace_back(ge_tensor_ptr3);
  723. std::vector<std::vector<int>> request_dims;
  724. std::vector<int> dims1 = {1, 1, 4, 4};
  725. std::vector<int> dims2 = {2, 3, 4, 5};
  726. std::vector<int> dims3 = {9, 9, 1, 1};
  727. request_dims.emplace_back(dims1);
  728. request_dims.emplace_back(dims2);
  729. request_dims.emplace_back(dims3);
  730. std::vector<MeTensorPtr> me_outputs = TransformUtil::ConvertGeTensors(ge_tensors, request_dims);
  731. for (int i = 0; i < ge_tensors.size(); i++) {
  732. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_outputs[i]->data_c());
  733. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensors[i]->GetData());
  734. ASSERT_TRUE(ge_tensors[i]->GetSize() == me_outputs[i]->data().nbytes());
  735. ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
  736. ASSERT_TRUE(request_dims[i] == me_outputs[i]->shape_c());
  737. }
  738. }
  739. TEST_F(TestConvert, TestConvertGeShape1) {
  740. GeShape ge_shape({10, 1, 1, 1});
  741. std::vector<int> request_dims{10};
  742. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  743. }
  744. TEST_F(TestConvert, TestConvertGeShape2) {
  745. GeShape ge_shape({10, 15, 1, 1});
  746. std::vector<int> request_dims{10, 15};
  747. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  748. }
  749. TEST_F(TestConvert, TestConvertGeShape3) {
  750. GeShape ge_shape({10, 13, 18, 1});
  751. std::vector<int> request_dims{10, 13, 18};
  752. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  753. }
  754. TEST_F(TestConvert, TestConvertGeShape4) {
  755. GeShape ge_shape({1, 10, 1, 1});
  756. std::vector<int> request_dims{10};
  757. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  758. }
  759. TEST_F(TestConvert, TestConvertGeShape5) {
  760. GeShape ge_shape({10, 1, 1, 2});
  761. std::vector<int> request_dims{10};
  762. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  763. }
  764. TEST_F(TestConvert, TestConvertGeShape6) {
  765. GeShape ge_shape({5, 2, 1, 1});
  766. std::vector<int> request_dims{10};
  767. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  768. }
  769. TEST_F(TestConvert, TestConvertGeShape7) {
  770. GeShape ge_shape({10});
  771. std::vector<int> request_dims{10, 1};
  772. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  773. }
  774. } // namespace transform
  775. } // namespace mindspore