You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_test.cc 33 kB

4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <unordered_map>
  18. #include "pybind11/pybind11.h"
  19. #include "transform/transform_base_test.h"
  20. #include "common/py_func_graph_fetcher.h"
  21. #include "pipeline/jit/parse/parse.h"
  22. #include "debug/draw.h"
  23. #include "debug/anf_ir_dump.h"
  24. #include "pipeline/jit/static_analysis/prim.h"
  25. #include "frontend/operator/ops.h"
  26. #include "common/common_test.h"
  27. #define private public
  28. #include "transform/graph_ir/types.h"
  29. #include "transform/graph_ir/convert.h"
  30. #include "securec/include/securec.h"
  31. #include "utils/utils.h"
  32. using std::cout;
  33. using std::endl;
  34. using std::string;
  35. using std::unordered_map;
  36. namespace mindspore {
  37. namespace transform {
  38. using AbstractScalar = abstract::AbstractScalar;
  39. using mindspore::parse::ResolveAll;
  40. class TestConvert : public UT::Common {
  41. public:
  42. TestConvert() {}
  43. virtual void SetUp();
  44. virtual void TearDown();
  45. static const std::shared_ptr<Float> kF32;
  46. };
  47. void TestConvert::SetUp() { UT::InitPythonPath(); }
  48. void TestConvert::TearDown() {}
  49. const std::shared_ptr<Float> TestConvert::kF32 = std::make_shared<Float>(32);
  50. AnfGraphPtr createAnfGraph() { return std::make_shared<AnfGraph>(); }
  51. TEST_F(TestConvert, TestConstruct) {
  52. AnfGraphPtr func_graph = std::make_shared<AnfGraph>();
  53. DfGraphConvertor converter(func_graph);
  54. converter.ConvertAllNode().GetComputeGraph();
  55. ASSERT_NE(converter.ErrCode(), SUCCESS);
  56. }
  57. #if (!defined ENABLE_GE)
  58. namespace {
  59. bool MakeDfGraph(PrimitivePtr prim, unsigned int nparam) {
  60. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, nparam);
  61. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  62. draw::Draw("ut_prim_" + prim->name() + ".dot", anf_graph);
  63. DumpIR("ut_prim_" + prim->name() + ".ir", anf_graph);
  64. DfGraphConvertor converter(anf_graph);
  65. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  66. converter.DrawComputeGraph(prim->name() + ".dot");
  67. if (converter.ErrCode() != 0) {
  68. MS_LOG(ERROR) << "DfGraphConvertor convert " << prim->name() << " error, error code is: " << converter.ErrCode();
  69. return false;
  70. }
  71. if (df_graph == nullptr) {
  72. MS_LOG(ERROR) << "DfGraphConvertor get " << prim->name() << " compute func_graph failed";
  73. return false;
  74. }
  75. return true;
  76. }
  77. } // namespace
  78. TEST_F(TestConvert, TestConvertConv2d) {
  79. PrimitivePtr conv2d = prim::kPrimConv2D;
  80. conv2d->AddAttr("stride", MakeValue(static_cast<int64_t>(2)));
  81. conv2d->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  82. conv2d->AddAttr("dilation", MakeValue(static_cast<int64_t>(0)));
  83. FuncGraphPtr anf_graph = MakeFuncGraph(conv2d, 2);
  84. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  85. draw::Draw("ut_prim_conv2d1.dot", anf_graph);
  86. DumpIR("ut_prim_conv2d1.ir", anf_graph);
  87. DfGraphConvertor converter(anf_graph);
  88. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  89. converter.DrawComputeGraph("conv2d.dot");
  90. ASSERT_EQ(converter.ErrCode(), 0);
  91. ASSERT_NE(df_graph, nullptr);
  92. }
  93. TEST_F(TestConvert, TestConvertMaxpooling) {
  94. auto prim = std::make_shared<Primitive>("MaxPool");
  95. FuncGraphPtr anf_graph = MakeFuncGraph(prim, 5); // ary, ksize, stride, padding, data_format
  96. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  97. draw::Draw("ut_prim_maxpooling.dot", anf_graph);
  98. DumpIR("ut_prim_maxpooling.ir", anf_graph);
  99. DfGraphConvertor converter(anf_graph);
  100. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  101. converter.DrawComputeGraph("maxpooling.dot");
  102. ASSERT_EQ(converter.ErrCode(), 0);
  103. ASSERT_NE(df_graph, nullptr);
  104. }
  105. TEST_F(TestConvert, TestReluOps) {
  106. auto prim = prim::kPrimRelu;
  107. prim->AddAttr("T", MakeValue(static_cast<int64_t>(0)));
  108. auto func_graph = MakeFuncGraph(prim, 1);
  109. ASSERT_TRUE(nullptr != func_graph);
  110. // save the func_graph to manager
  111. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  112. // call resolve
  113. bool ret_ = ResolveAll(manager);
  114. ASSERT_TRUE(ret_);
  115. // draw graph
  116. auto anfGraph = *(manager->func_graphs().begin());
  117. DfGraphConvertor converter(anfGraph);
  118. converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  119. ASSERT_EQ(converter.ErrCode(), 0);
  120. }
  121. TEST_F(TestConvert, TestConvertBatchNorm) {
  122. PrimitivePtr batch_norm = prim::kPrimBatchNorm;
  123. batch_norm->AddAttr("epsilon", MakeValue(0.001f));
  124. batch_norm->AddAttr("momentum", MakeValue(0.1f));
  125. FuncGraphPtr anf_graph = std::make_shared<FuncGraph>();
  126. std::vector<AnfNodePtr> inputs;
  127. inputs.push_back(NewValueNode(batch_norm));
  128. for (unsigned int i = 0; i < 5; i++) {
  129. inputs.push_back(anf_graph->add_parameter());
  130. }
  131. CNodePtr cnode_prim = anf_graph->NewCNode(inputs);
  132. inputs.clear();
  133. inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
  134. inputs.push_back(cnode_prim);
  135. inputs.push_back(NewValueNode(static_cast<int64_t>(2)));
  136. CNodePtr cnode_getitem = anf_graph->NewCNode(inputs);
  137. inputs.clear();
  138. inputs.push_back(NewValueNode(prim::kPrimRelu));
  139. inputs.push_back(cnode_getitem);
  140. CNodePtr cnode_relu = anf_graph->NewCNode(inputs);
  141. inputs.clear();
  142. inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
  143. inputs.push_back(cnode_relu);
  144. CNodePtr cnode_return = anf_graph->NewCNode(inputs);
  145. anf_graph->set_return(cnode_return);
  146. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  147. draw::Draw("ut_prim_batchnorm.dot", anf_graph);
  148. DumpIR("ut_prim_batchnorm.ir", anf_graph);
  149. DfGraphConvertor converter(anf_graph);
  150. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  151. converter.DrawComputeGraph("batchnrom.dot");
  152. ASSERT_EQ(converter.ErrCode(), 0);
  153. ASSERT_NE(df_graph, nullptr);
  154. }
  155. TEST_F(TestConvert, TestConvertConvBackpropInput) {
  156. auto prim = prim::kPrimConv2DBackpropInput;
  157. const std::vector<int64_t> list{1,1};
  158. prim->AddAttr("stride", MakeValue(list));
  159. prim->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  160. prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  161. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  162. prim->AddAttr("group", MakeValue(static_cast<int64_t>(1)));
  163. prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
  164. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  165. auto func_graph = MakeFuncGraph(prim, 3);
  166. ASSERT_NE(func_graph, nullptr);
  167. // save the func_graph to manager
  168. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  169. // call resolve
  170. bool ret_ = ResolveAll(manager);
  171. ASSERT_TRUE(ret_);
  172. // draw graph
  173. auto anf_graph = *(manager->func_graphs().begin());
  174. DfGraphConvertor converter(anf_graph);
  175. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  176. converter.DrawComputeGraph("Conv2DBackpropInput.dot");
  177. ASSERT_EQ(converter.ErrCode(), 0);
  178. ASSERT_NE(df_graph, nullptr);
  179. }
  180. TEST_F(TestConvert, TestConvertConvBackpropFilter) {
  181. auto prim = prim::kPrimConv2DBackpropFilter;
  182. const std::vector<int64_t> list{1,1};
  183. prim->AddAttr("stride", MakeValue(list));
  184. prim->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  185. prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  186. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  187. prim->AddAttr("group", MakeValue(static_cast<int64_t>(1)));
  188. prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
  189. prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  190. auto func_graph = MakeFuncGraph(prim, 3);
  191. ASSERT_NE(func_graph, nullptr);
  192. // save the func_graph to manager
  193. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  194. // call resolve
  195. bool ret_ = ResolveAll(manager);
  196. ASSERT_TRUE(ret_);
  197. // draw graph
  198. auto anf_graph = *(manager->func_graphs().begin());
  199. DfGraphConvertor converter(anf_graph);
  200. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  201. converter.DrawComputeGraph("Conv2DBackpropFilter.dot");
  202. ASSERT_EQ(converter.ErrCode(), 0);
  203. ASSERT_NE(df_graph, nullptr);
  204. }
  205. TEST_F(TestConvert, TestConvertReluGrad) {
  206. auto prim = prim::kPrimReluGrad;
  207. prim->AddAttr("alpha", MakeValue(0.1f));
  208. prim->AddAttr("beta", MakeValue(0.1f));
  209. prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
  210. auto func_graph = MakeFuncGraph(prim, 2);
  211. ASSERT_NE(func_graph, nullptr);
  212. // save the func_graph to manager
  213. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  214. // call resolve
  215. bool ret_ = ResolveAll(manager);
  216. ASSERT_TRUE(ret_);
  217. // draw graph
  218. auto anf_graph = *(manager->func_graphs().begin());
  219. DfGraphConvertor converter(anf_graph);
  220. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  221. converter.DrawComputeGraph("ReluGrad.dot");
  222. ASSERT_EQ(converter.ErrCode(), 0);
  223. ASSERT_NE(df_graph, nullptr);
  224. }
  225. TEST_F(TestConvert, TestConvertBiasAdd) {
  226. auto prim = std::make_shared<Primitive>("BiasAdd");
  227. prim->AddAttr("alpha", MakeValue(0.0f));
  228. prim->AddAttr("beta", MakeValue(1.0f));
  229. auto func_graph = MakeFuncGraph(prim, 2);
  230. ASSERT_NE(func_graph, nullptr);
  231. // save the func_graph to manager
  232. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  233. // call resolve
  234. bool ret_ = ResolveAll(manager);
  235. ASSERT_TRUE(ret_);
  236. // draw graph
  237. auto anf_graph = *(manager->func_graphs().begin());
  238. DfGraphConvertor converter(anf_graph);
  239. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  240. converter.DrawComputeGraph("BiasAdd.dot");
  241. ASSERT_EQ(converter.ErrCode(), 0);
  242. ASSERT_NE(df_graph, nullptr);
  243. }
  244. TEST_F(TestConvert, TestConvertBiasAddGrad) {
  245. auto prim = prim::kPrimBiasAddGrad;
  246. prim->AddAttr("alpha", MakeValue(0.0f));
  247. prim->AddAttr("beta", MakeValue(1.0f));
  248. auto func_graph = MakeFuncGraph(prim, 2);
  249. ASSERT_NE(func_graph, nullptr);
  250. // save the func_graph to manager
  251. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  252. // call resolve
  253. bool ret_ = ResolveAll(manager);
  254. ASSERT_TRUE(ret_);
  255. // draw graph
  256. auto anf_graph = *(manager->func_graphs().begin());
  257. DfGraphConvertor converter(anf_graph);
  258. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  259. converter.DrawComputeGraph("BiasAddGrad.dot");
  260. ASSERT_EQ(converter.ErrCode(), 0);
  261. ASSERT_NE(df_graph, nullptr);
  262. }
  263. TEST_F(TestConvert, TestConvertMaxPoolGradWithArgmax) {
  264. auto prim = std::make_shared<Primitive>("MaxPoolGradWithArgmax");
  265. prim->AddAttr("alpha", MakeValue(0.0f));
  266. prim->AddAttr("beta", MakeValue(1.0f));
  267. prim->AddAttr("window", MakeValue(static_cast<int64_t>(2)));
  268. prim->AddAttr("stride", MakeValue(static_cast<int64_t>(1)));
  269. prim->AddAttr("ceil_mode", MakeValue(static_cast<int64_t>(0)));
  270. prim->AddAttr("data_mode", MakeValue(static_cast<int64_t>(0)));
  271. prim->AddAttr("alpha", MakeValue(0.1f));
  272. prim->AddAttr("beta", MakeValue(1.0f));
  273. auto func_graph = MakeFuncGraph(prim, 2);
  274. ASSERT_NE(func_graph, nullptr);
  275. // save the func_graph to manager
  276. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  277. // call resolve
  278. bool ret_ = ResolveAll(manager);
  279. ASSERT_TRUE(ret_);
  280. // draw graph
  281. auto anf_graph = *(manager->func_graphs().begin());
  282. DfGraphConvertor converter(anf_graph);
  283. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  284. converter.DrawComputeGraph("MaxPoolGradWithArgmax.dot");
  285. ASSERT_EQ(converter.ErrCode(), 0);
  286. ASSERT_NE(df_graph, nullptr);
  287. }
  288. TEST_F(TestConvert, TestConcat) {
  289. auto prim = prim::kPrimConcat;
  290. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  291. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  292. draw::Draw("ut_prim_concat.dot", anf_graph);
  293. DumpIR("ut_prim_concat.ir", anf_graph);
  294. DfGraphConvertor converter(anf_graph);
  295. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  296. converter.DrawComputeGraph("concat.dot");
  297. ASSERT_EQ(converter.ErrCode(), 0);
  298. ASSERT_NE(df_graph, nullptr);
  299. }
  300. TEST_F(TestConvert, TestGatherV2) {
  301. auto prim = prim::kPrimGather;
  302. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 3);
  303. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  304. draw::Draw("ut_prim_gatherv2.dot", anf_graph);
  305. DumpIR("ut_prim_gatherv2.ir", anf_graph);
  306. DfGraphConvertor converter(anf_graph);
  307. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  308. converter.DrawComputeGraph("gatherv2.dot");
  309. ASSERT_EQ(converter.ErrCode(), 0);
  310. ASSERT_NE(df_graph, nullptr);
  311. }
  312. TEST_F(TestConvert, TestCast) {
  313. auto prim = prim::kPrimCast;
  314. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  315. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  316. draw::Draw("ut_prim_cast.dot", anf_graph);
  317. DumpIR("ut_prim_cast.ir", anf_graph);
  318. DfGraphConvertor converter(anf_graph);
  319. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  320. converter.DrawComputeGraph("cast.dot");
  321. ASSERT_EQ(converter.ErrCode(), 0);
  322. ASSERT_NE(df_graph, nullptr);
  323. }
  324. TEST_F(TestConvert, TestExp) {
  325. auto prim = std::make_shared<Primitive>("Exp");
  326. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  327. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  328. draw::Draw("ut_prim_exp.dot", anf_graph);
  329. DumpIR("ut_prim_exp.ir", anf_graph);
  330. DfGraphConvertor converter(anf_graph);
  331. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  332. converter.DrawComputeGraph("exp.dot");
  333. ASSERT_EQ(converter.ErrCode(), 0);
  334. ASSERT_NE(df_graph, nullptr);
  335. }
  336. TEST_F(TestConvert, TestFloor) {
  337. auto prim = std::make_shared<Primitive>("Floor");
  338. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  339. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  340. draw::Draw("ut_prim_floor.dot", anf_graph);
  341. DumpIR("ut_prim_floor.ir", anf_graph);
  342. DfGraphConvertor converter(anf_graph);
  343. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  344. converter.DrawComputeGraph("floor.dot");
  345. ASSERT_EQ(converter.ErrCode(), 0);
  346. ASSERT_NE(df_graph, nullptr);
  347. }
  348. TEST_F(TestConvert, TestGreaterEqual) {
  349. auto prim = std::make_shared<Primitive>("GreaterEqual");
  350. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  351. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  352. draw::Draw("ut_prim_greater_equal.dot", anf_graph);
  353. DumpIR("ut_prim_greater_equal.ir", anf_graph);
  354. DfGraphConvertor converter(anf_graph);
  355. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  356. converter.DrawComputeGraph("greater_equal.dot");
  357. ASSERT_EQ(converter.ErrCode(), 0);
  358. ASSERT_NE(df_graph, nullptr);
  359. }
  360. TEST_F(TestConvert, TestLess) {
  361. auto prim = std::make_shared<Primitive>("Less");
  362. prim->AddAttr("T", MakeValue(kFloat32));
  363. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  364. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  365. draw::Draw("ut_prim_less.dot", anf_graph);
  366. DumpIR("ut_prim_less.ir", anf_graph);
  367. DfGraphConvertor converter(anf_graph);
  368. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  369. converter.DrawComputeGraph("less.dot");
  370. ASSERT_EQ(converter.ErrCode(), 0);
  371. ASSERT_NE(df_graph, nullptr);
  372. }
  373. TEST_F(TestConvert, TestLessEqual) {
  374. auto prim = std::make_shared<Primitive>("LessEqual");
  375. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  376. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  377. draw::Draw("ut_prim_less_equal.dot", anf_graph);
  378. DumpIR("ut_prim_less_equal.ir", anf_graph);
  379. DfGraphConvertor converter(anf_graph);
  380. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  381. converter.DrawComputeGraph("less_equal.dot");
  382. ASSERT_EQ(converter.ErrCode(), 0);
  383. ASSERT_NE(df_graph, nullptr);
  384. }
  385. TEST_F(TestConvert, TestLogicalNot) {
  386. auto prim = std::make_shared<Primitive>("LogicalNot");
  387. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  388. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  389. draw::Draw("ut_prim_logical_not.dot", anf_graph);
  390. DumpIR("ut_prim_logical_not.ir", anf_graph);
  391. DfGraphConvertor converter(anf_graph);
  392. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  393. converter.DrawComputeGraph("logical_not.dot");
  394. ASSERT_EQ(converter.ErrCode(), 0);
  395. ASSERT_NE(df_graph, nullptr);
  396. }
  397. TEST_F(TestConvert, TestAssignAdd) {
  398. auto prim = prim::kPrimAssignAdd;
  399. prim->AddAttr("use_locking", MakeValue(true));
  400. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  401. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  402. draw::Draw("ut_prim_assign_add.dot", anf_graph);
  403. DumpIR("ut_prim_assign_add.ir", anf_graph);
  404. DfGraphConvertor converter(anf_graph);
  405. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  406. converter.DrawComputeGraph("assign_add.dot");
  407. ASSERT_EQ(converter.ErrCode(), 0);
  408. ASSERT_NE(df_graph, nullptr);
  409. }
  410. TEST_F(TestConvert, LogSoftmax) {
  411. auto prim = prim::kPrimLogSoftmax;
  412. prim->AddAttr("axis", MakeValue(static_cast<int64_t>(0)));
  413. std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  414. std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  415. draw::Draw("ut_prim_log_softmax.dot", anf_graph);
  416. DumpIR("ut_prim_log_softmax.ir", anf_graph);
  417. DfGraphConvertor converter(anf_graph);
  418. auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  419. converter.DrawComputeGraph("log_softmax.dot");
  420. ASSERT_EQ(converter.ErrCode(), 0);
  421. ASSERT_NE(df_graph, nullptr);
  422. }
  423. TEST_F(TestConvert, TestMaximumOps) {
  424. auto prim = prim::kPrimMaximum;
  425. bool ret = MakeDfGraph(prim, 2);
  426. ASSERT_TRUE(ret);
  427. }
  428. TEST_F(TestConvert, TestReduceMeanOps) {
  429. auto prim = prim::kPrimReduceMean;
  430. prim->AddAttr("keepdims", MakeValue(true));
  431. bool ret = MakeDfGraph(prim, 2);
  432. ASSERT_TRUE(ret);
  433. }
  434. TEST_F(TestConvert, TestMinimumOps) {
  435. auto prim = prim::kPrimMinimum;
  436. bool ret = MakeDfGraph(prim, 2);
  437. ASSERT_TRUE(ret);
  438. }
  439. TEST_F(TestConvert, TestFusedMinOrMaxGradOps) {
  440. // Add infer step to this test case
  441. ASSERT_TRUE(true);
  442. }
  443. TEST_F(TestConvert, TestSqueezeOps) {
  444. auto prim = prim::kPrimSqueeze;
  445. bool ret = MakeDfGraph(prim, 2);
  446. ASSERT_TRUE(ret);
  447. }
  448. TEST_F(TestConvert, TestMulOps) {
  449. auto prim = prim::kPrimMul;
  450. bool ret = MakeDfGraph(prim, 2);
  451. ASSERT_TRUE(ret);
  452. }
  453. TEST_F(TestConvert, TestNegOps) {
  454. auto prim = prim::kPrimNeg;
  455. bool ret = MakeDfGraph(prim, 1);
  456. ASSERT_TRUE(ret);
  457. }
  458. TEST_F(TestConvert, TestOneHotOps) {
  459. auto prim = prim::kPrimOneHot;
  460. prim->AddAttr("axis", MakeValue(static_cast<int64_t>(0)));
  461. bool ret = MakeDfGraph(prim, 4);
  462. ASSERT_TRUE(ret);
  463. }
  464. TEST_F(TestConvert, TestPowOps) {
  465. auto prim = std::make_shared<Primitive>("Pow");
  466. bool ret = MakeDfGraph(prim, 2);
  467. ASSERT_TRUE(ret);
  468. }
  469. TEST_F(TestConvert, TestReciprocalOps) {
  470. auto prim = std::make_shared<Primitive>("Reciprocal");
  471. bool ret = MakeDfGraph(prim, 1);
  472. ASSERT_TRUE(ret);
  473. }
  474. TEST_F(TestConvert, TestSelectOps) {
  475. auto prim = prim::kPrimSelect;
  476. bool ret = MakeDfGraph(prim, 3);
  477. ASSERT_TRUE(ret);
  478. }
  479. TEST_F(TestConvert, TestSqrtOps) {
  480. auto prim = std::make_shared<Primitive>("Sqrt");
  481. bool ret = MakeDfGraph(prim, 1);
  482. ASSERT_TRUE(ret);
  483. }
  484. TEST_F(TestConvert, TestSquareOps) {
  485. auto prim = std::make_shared<Primitive>("Square");
  486. bool ret = MakeDfGraph(prim, 1);
  487. ASSERT_TRUE(ret);
  488. }
  489. TEST_F(TestConvert, TestScalarSummaryOps) {
  490. auto prim = prim::kPrimScalarSummary;
  491. // should have only 1 input.
  492. bool ret = MakeDfGraph(prim, 2);
  493. ASSERT_TRUE(ret);
  494. }
  495. TEST_F(TestConvert, TestTensorSummaryOps) {
  496. auto prim = prim::kPrimTensorSummary;
  497. bool ret = MakeDfGraph(prim, 2);
  498. ASSERT_TRUE(ret);
  499. }
  500. TEST_F(TestConvert, TestHistogramSummaryOps) {
  501. auto prim = prim::kPrimHistogramSummary;
  502. bool ret = MakeDfGraph(prim, 2);
  503. ASSERT_TRUE(ret);
  504. }
  505. TEST_F(TestConvert, TestGreaterOps) {
  506. auto prim = std::make_shared<Primitive>("Greater");
  507. bool ret = MakeDfGraph(prim, 2);
  508. ASSERT_TRUE(ret);
  509. }
  510. TEST_F(TestConvert, TestEqualOps) {
  511. auto prim = std::make_shared<Primitive>("Equal");
  512. bool ret = MakeDfGraph(prim, 2);
  513. ASSERT_TRUE(ret);
  514. }
  515. TEST_F(TestConvert, TestArgMaxiOps) {
  516. auto prim = std::make_shared<Primitive>("Argmax");
  517. bool ret = MakeDfGraph(prim, 2);
  518. ASSERT_TRUE(ret);
  519. }
  520. TEST_F(TestConvert, TestResizeNearestNeighborOps) {
  521. auto prim = std::make_shared<Primitive>("ResizeNearestNeighbor");
  522. bool ret = MakeDfGraph(prim, 1);
  523. ASSERT_TRUE(ret);
  524. }
  525. TEST_F(TestConvert, TestApplyMomentumOps) {
  526. auto prim = std::make_shared<Primitive>("ApplyMomentum");
  527. bool ret = MakeDfGraph(prim, 5);
  528. ASSERT_TRUE(ret);
  529. }
  530. TEST_F(TestConvert, TestNPUGetFloatStatusOps) {
  531. auto prim = std::make_shared<Primitive>("NPUGetFloatStatus");
  532. bool ret = MakeDfGraph(prim, 1);
  533. ASSERT_TRUE(ret);
  534. }
  535. TEST_F(TestConvert, TestNPUAllocFloatStatusOps) {
  536. auto prim = std::make_shared<Primitive>("NPUAllocFloatStatus");
  537. bool ret = MakeDfGraph(prim, 0);
  538. ASSERT_TRUE(ret);
  539. }
  540. TEST_F(TestConvert, TestNPUClearFloatStatusOps) {
  541. auto prim = std::make_shared<Primitive>("NPUClearFloatStatus");
  542. bool ret = MakeDfGraph(prim, 1);
  543. ASSERT_TRUE(ret);
  544. }
  545. #endif
  546. TEST_F(TestConvert, TestAddOps) {
  547. auto prim = std::make_shared<Primitive>("Add");
  548. auto func_graph = MakeFuncGraph(prim, 2);
  549. ASSERT_TRUE(nullptr != func_graph);
  550. // save the func_graph to manager
  551. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  552. // call resolve
  553. bool ret_ = ResolveAll(manager);
  554. ASSERT_TRUE(ret_);
  555. // draw graph
  556. auto anfGraph = *(manager->func_graphs().begin());
  557. DfGraphConvertor converter(anfGraph);
  558. converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  559. ASSERT_EQ(converter.ErrCode(), 0);
  560. }
  561. TEST_F(TestConvert, TestConvertTensor) {
  562. float data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  563. // Create a tensor with wanted data type and shape
  564. std::vector<int64_t> dims{2, 2, 3};
  565. std::vector<int64_t> ge_dims{2, 2, 3};
  566. auto type_id = kNumberTypeFloat32;
  567. MeTensor me_tensor(type_id, dims);
  568. // Get the writable data pointer of the tensor and cast it to its data type
  569. uint8_t* me_data_ptr = reinterpret_cast<uint8_t*>(me_tensor.data_c());
  570. // Copy or use the writable data pointer of the ME tensor
  571. memcpy_s(me_data_ptr, me_tensor.data().nbytes(), data, 12 * sizeof(float));
  572. auto me_tensor_ptr = std::make_shared<MeTensor>(me_tensor);
  573. auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW);
  574. ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetFormat(), GeFormat::FORMAT_NCHW);
  575. ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetDataType(), GeDataType::DT_FLOAT);
  576. // ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().array().GetDims(), ge_dims);
  577. int i = 0;
  578. for (i = 0; i < ge_dims.size(); i++) {
  579. ASSERT_EQ(ge_dims[i], ge_tensor_ptr->GetTensorDesc().GetShape().GetDims()[i]);
  580. }
  581. for (i = 0; i < ge_tensor_ptr->GetTensorDesc().GetShape().GetShapeSize(); i++) {
  582. ASSERT_EQ(data[i], (reinterpret_cast<float*>(ge_tensor_ptr->GetData()))[i]);
  583. }
  584. }
  585. TEST_F(TestConvert, TestConvertTensor0Dims) {
  586. // shape with 0 dims is also valid
  587. std::vector<int64_t> dims{};
  588. auto type_id = kNumberTypeFloat32;
  589. auto me_tensor_ptr = std::make_shared<MeTensor>(type_id, dims);
  590. ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW), nullptr);
  591. }
  592. TEST_F(TestConvert, TestConvertTensorError) {
  593. std::vector<int64_t> dims2{2, 3, 4};
  594. auto type_id_2 = kNumberTypeFloat32;
  595. auto me_tensor_ptr_2 = std::make_shared<MeTensor>(type_id_2, dims2);
  596. ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
  597. }
  598. TEST_F(TestConvert, TestUtilsConvertDataType) {
  599. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat16), GeDataType::DT_FLOAT16);
  600. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat32), GeDataType::DT_FLOAT);
  601. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat64), GeDataType::DT_DOUBLE);
  602. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt8), GeDataType::DT_INT8);
  603. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt16), GeDataType::DT_INT16);
  604. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt32), GeDataType::DT_INT32);
  605. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt64), GeDataType::DT_INT64);
  606. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeUInt32), GeDataType::DT_UINT32);
  607. ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeBool), GeDataType::DT_BOOL);
  608. }
  609. TEST_F(TestConvert, TestUtilsConvertFormat) {
  610. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NCHW), GeFormat::FORMAT_NCHW);
  611. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NC1HWC0), GeFormat::FORMAT_NC1HWC0);
  612. ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NHWC), GeFormat::FORMAT_NHWC);
  613. ASSERT_EQ(TransformUtil::ConvertFormat("xyz"), GeFormat::FORMAT_ND);
  614. }
  615. TEST_F(TestConvert, TestUtilsDataSize) {
  616. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat32), 4);
  617. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat16), 2);
  618. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat64), 8);
  619. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt8), 1);
  620. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt16), 2);
  621. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt32), 4);
  622. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt64), 8);
  623. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeUInt32), 4);
  624. ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeBool), 1);
  625. }
  626. TEST_F(TestConvert, TestConvertGeTensor) {
  627. #define DTYPE float
  628. ge::DataType dt = ge::DataType::DT_FLOAT;
  629. std::vector<float> data1 = {1.1, 2.2, 3.3, 4.4, 6.6, 7.7, 8.8, 9.9};
  630. std::vector<DTYPE> data2 = {1, 2, 3, 4, 6, 7, 8, 9};
  631. auto data = data1;
  632. ge::Shape shape({2, 2, 2});
  633. ge::Format format = ge::Format::FORMAT_NCHW;
  634. ge::TensorDesc desc(shape, format, dt);
  635. GeTensorPtr ge_tensor_ptr =
  636. std::make_shared<GeTensor>(desc, reinterpret_cast<uint8_t*>(data.data()), data.size() * sizeof(DTYPE));
  637. GeTensor& ge_tensor = *ge_tensor_ptr;
  638. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensor.GetData());
  639. // make sure GetData()'s return is a reference
  640. assert(ge_data == reinterpret_cast<DTYPE*>(ge_tensor.GetData()));
  641. cout << "ge data size is: " << std::dec << ge_tensor.GetSize() << " bytes" << endl;
  642. for (int i = 0; i < ge_tensor.GetSize() / sizeof(DTYPE); i++) {
  643. cout << "ge data is: " << static_cast<DTYPE>(*(ge_data + i)) << endl;
  644. }
  645. MeTensorPtr me_tensor_ptr = TransformUtil::ConvertGeTensor(ge_tensor_ptr);
  646. MeTensor& me_tensor = *me_tensor_ptr;
  647. cout << "after convert ge tensor to me tensor" << endl;
  648. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_tensor.data_c());
  649. PrintMeTensor(&me_tensor);
  650. assert(ge_tensor.GetSize() == me_tensor.data().nbytes());
  651. assert(memcmp(ge_data, me_data, ge_tensor.GetSize()) == 0);
  652. }
  653. TEST_F(TestConvert, TestConvertMakeTuple) {
  654. FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  655. std::vector<AnfNodePtr> inputs;
  656. inputs.push_back(NewValueNode(std::make_shared<Primitive>("MakeTuple")));
  657. for (int i = 0; i < 3; i++) {
  658. auto input = func_graph->add_parameter();
  659. input->set_name("x" + std::to_string(i));
  660. inputs.push_back(input);
  661. }
  662. CNodePtr cnode_prim = func_graph->NewCNode(inputs);
  663. inputs.clear();
  664. inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
  665. inputs.push_back(cnode_prim);
  666. CNodePtr cnode_return = func_graph->NewCNode(inputs);
  667. func_graph->set_return(cnode_return);
  668. // save the func_graph to manager
  669. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  670. // call resolve
  671. bool ret_ = ResolveAll(manager);
  672. ASSERT_TRUE(ret_);
  673. // draw graph
  674. auto anfGraph = *(manager->func_graphs().begin());
  675. DfGraphConvertor converter(anfGraph);
  676. converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  677. ASSERT_EQ(converter.ErrCode(), 0);
  678. }
  679. TEST_F(TestConvert, TestConvertInputTensors) {
  680. #define DTYPE float
  681. std::initializer_list<int64_t> list0 = {1, 1, 4, 4};
  682. std::initializer_list<int64_t> list1 = {2, 3, 4, 5};
  683. std::initializer_list<int64_t> list2 = {9, 9, 1, 1};
  684. MeTensorPtr input_ptr1 = MakeTensor(kF32, list0);
  685. MeTensorPtr input_ptr2 = MakeTensor(kF32, list1);
  686. MeTensorPtr input_ptr3 = MakeTensor(kF32, list2);
  687. std::vector<MeTensorPtr> me_inputs;
  688. me_inputs.emplace_back(input_ptr1);
  689. me_inputs.emplace_back(input_ptr2);
  690. me_inputs.emplace_back(input_ptr3);
  691. std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(me_inputs, kOpFormat_NCHW);
  692. for (int i = 0; i < ge_tensors.size(); i++) {
  693. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_inputs[i]->data_c());
  694. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensors[i]->GetData());
  695. ASSERT_TRUE(ge_tensors[i]->GetSize() == me_inputs[i]->data().nbytes());
  696. ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
  697. ASSERT_TRUE(ge_tensors[i]->GetTensorDesc().GetShape().GetDims() ==
  698. TransformUtil::ConvertMeShape(me_inputs[i]->shape_c()).GetDims());
  699. }
  700. }
  701. TEST_F(TestConvert, TestConvertGeTensors) {
  702. #define DTYPE float
  703. ge::DataType dt = ge::DataType::DT_FLOAT;
  704. std::vector<float> data1(16);
  705. std::vector<float> data2(120);
  706. std::vector<float> data3(81);
  707. ge::Shape shape1({1, 1, 4, 4});
  708. ge::Shape shape2({2, 3, 4, 5});
  709. ge::Shape shape3({9, 9, 1, 1});
  710. ge::Format format = ge::Format::FORMAT_NCHW;
  711. ge::TensorDesc desc1(shape1, format, dt);
  712. ge::TensorDesc desc2(shape2, format, dt);
  713. ge::TensorDesc desc3(shape3, format, dt);
  714. GeTensorPtr ge_tensor_ptr1 =
  715. std::make_shared<GeTensor>(desc1, reinterpret_cast<uint8_t*>(data1.data()), data1.size() * sizeof(DTYPE));
  716. GeTensorPtr ge_tensor_ptr2 =
  717. std::make_shared<GeTensor>(desc2, reinterpret_cast<uint8_t*>(data2.data()), data2.size() * sizeof(DTYPE));
  718. GeTensorPtr ge_tensor_ptr3 =
  719. std::make_shared<GeTensor>(desc3, reinterpret_cast<uint8_t*>(data3.data()), data3.size() * sizeof(DTYPE));
  720. std::vector<GeTensorPtr> ge_tensors;
  721. ge_tensors.emplace_back(ge_tensor_ptr1);
  722. ge_tensors.emplace_back(ge_tensor_ptr2);
  723. ge_tensors.emplace_back(ge_tensor_ptr3);
  724. std::vector<std::vector<int64_t>> request_dims;
  725. std::vector<int64_t> dims1 = {1, 1, 4, 4};
  726. std::vector<int64_t> dims2 = {2, 3, 4, 5};
  727. std::vector<int64_t> dims3 = {9, 9, 1, 1};
  728. request_dims.emplace_back(dims1);
  729. request_dims.emplace_back(dims2);
  730. request_dims.emplace_back(dims3);
  731. std::vector<MeTensorPtr> me_outputs = TransformUtil::ConvertGeTensors(ge_tensors, request_dims);
  732. for (int i = 0; i < ge_tensors.size(); i++) {
  733. DTYPE* me_data = reinterpret_cast<DTYPE*>(me_outputs[i]->data_c());
  734. const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensors[i]->GetData());
  735. ASSERT_TRUE(ge_tensors[i]->GetSize() == me_outputs[i]->data().nbytes());
  736. ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
  737. ASSERT_TRUE(request_dims[i] == me_outputs[i]->shape_c());
  738. }
  739. }
  740. TEST_F(TestConvert, TestConvertGeShape1) {
  741. GeShape ge_shape({10, 1, 1, 1});
  742. std::vector<int64_t> request_dims{10};
  743. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  744. }
  745. TEST_F(TestConvert, TestConvertGeShape2) {
  746. GeShape ge_shape({10, 15, 1, 1});
  747. std::vector<int64_t> request_dims{10, 15};
  748. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  749. }
  750. TEST_F(TestConvert, TestConvertGeShape3) {
  751. GeShape ge_shape({10, 13, 18, 1});
  752. std::vector<int64_t> request_dims{10, 13, 18};
  753. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  754. }
  755. TEST_F(TestConvert, TestConvertGeShape4) {
  756. GeShape ge_shape({1, 10, 1, 1});
  757. std::vector<int64_t> request_dims{10};
  758. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
  759. }
  760. TEST_F(TestConvert, TestConvertGeShape5) {
  761. GeShape ge_shape({10, 1, 1, 2});
  762. std::vector<int64_t> request_dims{10};
  763. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  764. }
  765. TEST_F(TestConvert, TestConvertGeShape6) {
  766. GeShape ge_shape({5, 2, 1, 1});
  767. std::vector<int64_t> request_dims{10};
  768. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  769. }
  770. TEST_F(TestConvert, TestConvertGeShape7) {
  771. GeShape ge_shape({10});
  772. std::vector<int64_t> request_dims{10, 1};
  773. ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
  774. }
  775. } // namespace transform
  776. } // namespace mindspore