You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_control.cc 20 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string>
  17. #include <vector>
  18. #include "common/common_test.h"
  19. #include "include/api/model.h"
  20. #include "include/api/serialization.h"
  21. #include "include/api/context.h"
  22. using namespace mindspore;
  namespace {
  // Absolute paths of the serialized MindIR models loaded by the tests in this file.
  // One model per control-flow test case.
  constexpr char kIfbyIfFile[] = "/home/workspace/mindspore_dataset/mindir/control/ifbyif.mindir";
  constexpr char kSimpleWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/simple_while.mindir";
  constexpr char kMixIfWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/mix_while_if.mindir";
  constexpr char kRecursiveFile[] = "/home/workspace/mindspore_dataset/mindir/control/fibonacci.mindir";
  constexpr char kSingleForFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_for.mindir";
  constexpr char kSingleOrFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_or.mindir";
  constexpr char kSingleSwitchFile[] = "/home/workspace/mindspore_dataset/mindir/control/switch_layer_net.mindir";

  // Fill value of the shared float input buffer; expected outputs are expressed as multiples of it.
  constexpr float kConstValue = 0.1234;

  // 2 * 3 * 4 * 5 = 120 floats, every element equal to kConstValue.
  const std::vector<float> input_data(2 * 3 * 4 * 5, kConstValue);
  }  // namespace
  32. class TestControl : public ST::Common {
  33. public:
  34. TestControl() {}
  35. };
  36. TEST_F(TestControl, InferIfbyIf) {
  37. auto context = ContextAutoSet();
  38. Graph graph;
  39. ASSERT_TRUE(Serialization::Load(kIfbyIfFile, ModelType::kMindIR, &graph));
  40. Model control_model;
  41. ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
  42. // assert inputs
  43. std::vector<MSTensor> inputs_before = control_model.GetInputs();
  44. ASSERT_EQ(5, inputs_before.size());
  45. EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  46. EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32);
  47. EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeBool);
  48. EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeBool);
  49. EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeFloat32);
  50. ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float));
  51. ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float));
  52. ASSERT_EQ(inputs_before[2].DataSize(), sizeof(bool));
  53. ASSERT_EQ(inputs_before[3].DataSize(), sizeof(bool));
  54. ASSERT_EQ(inputs_before[4].DataSize(), sizeof(float) * input_data.size());
  55. ASSERT_EQ(inputs_before[0].Shape().size(), 1);
  56. EXPECT_EQ(inputs_before[0].Shape()[0], 1);
  57. ASSERT_EQ(inputs_before[1].Shape().size(), 1);
  58. EXPECT_EQ(inputs_before[1].Shape()[0], 1);
  59. ASSERT_EQ(inputs_before[2].Shape().size(), 1);
  60. EXPECT_EQ(inputs_before[2].Shape()[0], 1);
  61. ASSERT_EQ(inputs_before[3].Shape().size(), 1);
  62. EXPECT_EQ(inputs_before[3].Shape()[0], 1);
  63. ASSERT_EQ(inputs_before[4].Shape().size(), 4);
  64. EXPECT_EQ(inputs_before[4].Shape()[0], 2);
  65. EXPECT_EQ(inputs_before[4].Shape()[1], 3);
  66. EXPECT_EQ(inputs_before[4].Shape()[2], 4);
  67. EXPECT_EQ(inputs_before[4].Shape()[3], 5);
  68. // assert outputs
  69. std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  70. ASSERT_EQ(1, outputs_before.size());
  71. EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  72. ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * input_data.size());
  73. ASSERT_EQ(outputs_before[0].Shape().size(), 4);
  74. EXPECT_EQ(outputs_before[0].Shape()[0], 2);
  75. EXPECT_EQ(outputs_before[0].Shape()[1], 3);
  76. EXPECT_EQ(outputs_before[0].Shape()[2], 4);
  77. EXPECT_EQ(outputs_before[0].Shape()[3], 5);
  78. // prepare input
  79. std::vector<MSTensor> outputs;
  80. std::vector<MSTensor> inputs;
  81. float x = 2.345678, y = 1.234567;
  82. bool cond1 = true, cond2 = false;
  83. inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
  84. sizeof(float));
  85. inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
  86. sizeof(float));
  87. inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &cond1,
  88. sizeof(bool));
  89. inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &cond2,
  90. sizeof(bool));
  91. inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), input_data.data(),
  92. sizeof(float) * input_data.size());
  93. // infer
  94. ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
  95. // assert output
  96. ASSERT_TRUE(outputs.size() == 1);
  97. auto out = outputs[0];
  98. ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size());
  99. auto out_data = out.Data();
  100. auto p = reinterpret_cast<const float *>(out_data.get());
  101. for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
  102. ASSERT_LE(std::abs(p[i] - kConstValue * 24), 1e-3);
  103. }
  104. }
  105. TEST_F(TestControl, InferSimpleWhile) {
  106. auto context = ContextAutoSet();
  107. Graph graph;
  108. ASSERT_TRUE(Serialization::Load(kSimpleWhileFile, ModelType::kMindIR, &graph));
  109. Model control_model;
  110. ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
  111. // assert inputs
  112. std::vector<MSTensor> inputs_before = control_model.GetInputs();
  113. ASSERT_EQ(3, inputs_before.size());
  114. EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeBool);
  115. EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeBool);
  116. EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeFloat32);
  117. ASSERT_EQ(inputs_before[0].DataSize(), sizeof(bool));
  118. ASSERT_EQ(inputs_before[1].DataSize(), sizeof(bool));
  119. ASSERT_EQ(inputs_before[2].DataSize(), sizeof(float) * input_data.size());
  120. ASSERT_EQ(inputs_before[0].Shape().size(), 1);
  121. EXPECT_EQ(inputs_before[0].Shape()[0], 1);
  122. ASSERT_EQ(inputs_before[1].Shape().size(), 1);
  123. EXPECT_EQ(inputs_before[1].Shape()[0], 1);
  124. ASSERT_EQ(inputs_before[2].Shape().size(), 4);
  125. EXPECT_EQ(inputs_before[2].Shape()[0], 2);
  126. EXPECT_EQ(inputs_before[2].Shape()[1], 3);
  127. EXPECT_EQ(inputs_before[2].Shape()[2], 4);
  128. EXPECT_EQ(inputs_before[2].Shape()[3], 5);
  129. // assert outputs
  130. std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  131. ASSERT_EQ(1, outputs_before.size());
  132. EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  133. ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * input_data.size());
  134. ASSERT_EQ(outputs_before[0].Shape().size(), 4);
  135. EXPECT_EQ(outputs_before[0].Shape()[0], 2);
  136. EXPECT_EQ(outputs_before[0].Shape()[1], 3);
  137. EXPECT_EQ(outputs_before[0].Shape()[2], 4);
  138. EXPECT_EQ(outputs_before[0].Shape()[3], 5);
  139. // prepare input
  140. std::vector<MSTensor> outputs;
  141. std::vector<MSTensor> inputs;
  142. {
  143. bool x = true, y = false;
  144. inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
  145. sizeof(bool));
  146. inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
  147. sizeof(bool));
  148. inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(),
  149. input_data.data(), sizeof(float) * input_data.size());
  150. }
  151. // infer
  152. ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
  153. // assert output
  154. ASSERT_TRUE(outputs.size() == 1);
  155. auto out = outputs[0];
  156. ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size());
  157. auto out_data = out.Data();
  158. auto p = reinterpret_cast<const float *>(out_data.get());
  159. for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
  160. ASSERT_LE(std::abs(p[i] - kConstValue * 3), 1e-3);
  161. }
  162. }
  163. TEST_F(TestControl, InferRecursive) {
  164. auto context = ContextAutoSet();
  165. Graph graph;
  166. ASSERT_TRUE(Serialization::Load(kRecursiveFile, ModelType::kMindIR, &graph));
  167. Model control_model;
  168. ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
  169. // assert inputs
  170. std::vector<MSTensor> inputs_before = control_model.GetInputs();
  171. ASSERT_EQ(1, inputs_before.size());
  172. EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
  173. ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
  174. ASSERT_EQ(inputs_before[0].Shape().size(), 1);
  175. EXPECT_EQ(inputs_before[0].Shape()[0], 1);
  176. // assert outputs
  177. std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  178. ASSERT_EQ(1, outputs_before.size());
  179. EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
  180. ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t));
  181. ASSERT_EQ(outputs_before[0].Shape().size(), 1);
  182. EXPECT_EQ(outputs_before[0].Shape()[0], 1);
  183. // prepare input
  184. std::vector<MSTensor> outputs;
  185. std::vector<MSTensor> inputs;
  186. {
  187. int32_t x = 7;
  188. inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
  189. sizeof(int32_t));
  190. }
  191. // infer
  192. ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
  193. // assert output
  194. ASSERT_TRUE(outputs.size() == 1);
  195. auto out = outputs[0];
  196. ASSERT_TRUE(out.DataSize() == sizeof(int32_t));
  197. auto out_data = out.Data();
  198. auto p = reinterpret_cast<const int32_t *>(out_data.get());
  199. ASSERT_EQ(*p, 21);
  200. }
  201. TEST_F(TestControl, InferMixedWhileIf) {
  202. auto context = ContextAutoSet();
  203. Graph graph;
  204. ASSERT_TRUE(Serialization::Load(kMixIfWhileFile, ModelType::kMindIR, &graph));
  205. Model control_model;
  206. ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
  207. // assert inputs
  208. std::vector<MSTensor> inputs_before = control_model.GetInputs();
  209. ASSERT_EQ(inputs_before.size(), 5);
  210. EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
  211. EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
  212. EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
  213. EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeInt32);
  214. EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeInt32);
  215. ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
  216. ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
  217. ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
  218. ASSERT_EQ(inputs_before[3].DataSize(), sizeof(int32_t));
  219. ASSERT_EQ(inputs_before[4].DataSize(), sizeof(int32_t));
  220. ASSERT_EQ(inputs_before[0].Shape().size(), 1);
  221. EXPECT_EQ(inputs_before[0].Shape()[0], 1);
  222. ASSERT_EQ(inputs_before[1].Shape().size(), 1);
  223. EXPECT_EQ(inputs_before[1].Shape()[0], 1);
  224. ASSERT_EQ(inputs_before[2].Shape().size(), 1);
  225. EXPECT_EQ(inputs_before[2].Shape()[0], 1);
  226. ASSERT_EQ(inputs_before[3].Shape().size(), 1);
  227. EXPECT_EQ(inputs_before[3].Shape()[0], 1);
  228. ASSERT_EQ(inputs_before[4].Shape().size(), 1);
  229. EXPECT_EQ(inputs_before[4].Shape()[0], 1);
  230. // assert outputs
  231. std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  232. ASSERT_EQ(1, outputs_before.size());
  233. EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
  234. ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t));
  235. ASSERT_EQ(outputs_before[0].Shape().size(), 1);
  236. EXPECT_EQ(outputs_before[0].Shape()[0], 1);
  237. // prepare input
  238. std::vector<MSTensor> outputs;
  239. std::vector<MSTensor> inputs;
  240. {
  241. int32_t x = 2, y = 14, z = 1, c2 = 14, c4 = 0;
  242. inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
  243. sizeof(int32_t));
  244. inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
  245. sizeof(int32_t));
  246. inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z,
  247. sizeof(int32_t));
  248. inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &c2,
  249. sizeof(int32_t));
  250. inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), &c4,
  251. sizeof(int32_t));
  252. }
  253. // infer
  254. ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
  255. // assert output
  256. ASSERT_TRUE(outputs.size() == 1);
  257. auto out = outputs[0];
  258. ASSERT_TRUE(out.DataSize() == sizeof(int32_t));
  259. auto out_data = out.Data();
  260. auto p = reinterpret_cast<const int32_t *>(out_data.get());
  261. ASSERT_EQ(*p, 350);
  262. }
  263. TEST_F(TestControl, InferSingleFor) {
  264. auto context = ContextAutoSet();
  265. Graph graph;
  266. ASSERT_TRUE(Serialization::Load(kSingleForFile, ModelType::kMindIR, &graph));
  267. Model control_model;
  268. ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
  269. // assert inputs
  270. std::vector<MSTensor> inputs_before = control_model.GetInputs();
  271. ASSERT_EQ(inputs_before.size(), 3);
  272. EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
  273. EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
  274. EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
  275. ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
  276. ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
  277. ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
  278. ASSERT_EQ(inputs_before[0].Shape().size(), 1);
  279. EXPECT_EQ(inputs_before[0].Shape()[0], 1);
  280. ASSERT_EQ(inputs_before[1].Shape().size(), 1);
  281. EXPECT_EQ(inputs_before[1].Shape()[0], 1);
  282. ASSERT_EQ(inputs_before[2].Shape().size(), 1);
  283. EXPECT_EQ(inputs_before[2].Shape()[0], 1);
  284. // assert outputs
  285. std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  286. ASSERT_EQ(1, outputs_before.size());
  287. EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
  288. ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t));
  289. ASSERT_EQ(outputs_before[0].Shape().size(), 1);
  290. EXPECT_EQ(outputs_before[0].Shape()[0], 1);
  291. // prepare input
  292. std::vector<MSTensor> outputs;
  293. std::vector<MSTensor> inputs;
  294. {
  295. int32_t x = 2, y = 5, z = 4;
  296. inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
  297. sizeof(int32_t));
  298. inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
  299. sizeof(int32_t));
  300. inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z,
  301. sizeof(int32_t));
  302. }
  303. // infer
  304. ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
  305. // assert output
  306. ASSERT_TRUE(outputs.size() == 1);
  307. auto out = outputs[0];
  308. ASSERT_TRUE(out.DataSize() == sizeof(int32_t));
  309. auto out_data = out.Data();
  310. auto p = reinterpret_cast<const int32_t *>(out_data.get());
  311. ASSERT_EQ(*p, 125);
  312. }
  313. TEST_F(TestControl, InferSingleOr) {
  314. auto context = ContextAutoSet();
  315. Graph graph;
  316. ASSERT_TRUE(Serialization::Load(kSingleOrFile, ModelType::kMindIR, &graph));
  317. Model control_model;
  318. ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
  319. // assert inputs
  320. std::vector<MSTensor> inputs_before = control_model.GetInputs();
  321. ASSERT_EQ(inputs_before.size(), 2);
  322. EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  323. EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32);
  324. ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 2);
  325. ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float) * 2);
  326. ASSERT_EQ(inputs_before[0].Shape().size(), 1);
  327. EXPECT_EQ(inputs_before[0].Shape()[0], 2);
  328. ASSERT_EQ(inputs_before[1].Shape().size(), 1);
  329. EXPECT_EQ(inputs_before[1].Shape()[0], 2);
  330. // assert outputs
  331. std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  332. ASSERT_EQ(1, outputs_before.size());
  333. EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  334. ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float));
  335. // prepare input
  336. std::vector<MSTensor> outputs;
  337. std::vector<MSTensor> inputs;
  338. {
  339. static const std::vector<float> input_data1 = {0, 1};
  340. static const std::vector<float> input_data2 = {0, 0};
  341. inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(),
  342. input_data1.data(), sizeof(float) * input_data1.size());
  343. inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(),
  344. input_data2.data(), sizeof(int32_t) * input_data2.size());
  345. }
  346. // infer
  347. ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
  348. // assert outputs
  349. std::vector<MSTensor> outputs_after = control_model.GetOutputs();
  350. ASSERT_EQ(1, outputs_after.size());
  351. EXPECT_EQ(outputs_after[0].DataType(), DataType::kNumberTypeFloat32);
  352. ASSERT_TRUE(outputs_after[0].DataSize() == sizeof(float));
  353. EXPECT_EQ(outputs_after[0].Shape().size(), outputs_before[0].Shape().size());
  354. // assert output
  355. ASSERT_TRUE(outputs.size() == 1);
  356. auto out = outputs[0];
  357. ASSERT_TRUE(out.DataSize() == sizeof(float));
  358. auto out_data = out.Data();
  359. auto p = reinterpret_cast<const float *>(out_data.get());
  360. ASSERT_EQ(*p, 1);
  361. }
  362. TEST_F(TestControl, InferSingleSwitch) {
  363. auto context = ContextAutoSet();
  364. Graph graph;
  365. ASSERT_TRUE(Serialization::Load(kSingleSwitchFile, ModelType::kMindIR, &graph));
  366. Model control_model;
  367. ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
  368. // assert inputs
  369. std::vector<MSTensor> inputs_before = control_model.GetInputs();
  370. ASSERT_EQ(inputs_before.size(), 3);
  371. EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  372. EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
  373. EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
  374. ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 224 * 224);
  375. ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
  376. ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
  377. ASSERT_EQ(inputs_before[0].Shape().size(), 4);
  378. EXPECT_EQ(inputs_before[0].Shape()[0], 1);
  379. EXPECT_EQ(inputs_before[0].Shape()[1], 1);
  380. EXPECT_EQ(inputs_before[0].Shape()[2], 224);
  381. EXPECT_EQ(inputs_before[0].Shape()[3], 224);
  382. ASSERT_EQ(inputs_before[1].Shape().size(), 1);
  383. EXPECT_EQ(inputs_before[1].Shape()[0], 1);
  384. ASSERT_EQ(inputs_before[2].Shape().size(), 1);
  385. EXPECT_EQ(inputs_before[2].Shape()[0], 1);
  386. // assert outputs
  387. std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  388. ASSERT_EQ(1, outputs_before.size());
  389. EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  390. ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * 224 * 224);
  391. ASSERT_EQ(outputs_before[0].Shape().size(), 4);
  392. EXPECT_EQ(outputs_before[0].Shape()[0], 1);
  393. EXPECT_EQ(outputs_before[0].Shape()[1], 1);
  394. EXPECT_EQ(outputs_before[0].Shape()[2], 224);
  395. EXPECT_EQ(outputs_before[0].Shape()[3], 224);
  396. // prepare input
  397. std::vector<MSTensor> outputs;
  398. std::vector<MSTensor> inputs;
  399. {
  400. static const std::vector<float> input_data1(1 * 1 * 224 * 224, 1);
  401. int32_t index1 = 0;
  402. int32_t index2 = -1;
  403. inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(),
  404. input_data1.data(), sizeof(float) * input_data1.size());
  405. inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &index1,
  406. sizeof(int32_t));
  407. inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &index2,
  408. sizeof(int32_t));
  409. }
  410. // infer
  411. ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
  412. // assert output
  413. ASSERT_TRUE(outputs.size() == 1);
  414. auto out = outputs[0];
  415. ASSERT_TRUE(out.DataSize() == sizeof(float) * 224 * 224);
  416. auto out_data = out.Data();
  417. auto p = reinterpret_cast<const float *>(out_data.get());
  418. for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
  419. ASSERT_EQ(p[i], 1);
  420. }
  421. }