main.cc — MindSpore Lite runtime_cpp example (24 kB).
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include <vector>
#include "include/errorcode.h"
#include "include/model.h"
#include "include/context.h"
#include "include/lite_session.h"
#include "include/version.h"
  27. std::string RealPath(const char *path) {
  28. const size_t max = 4096;
  29. if (path == nullptr) {
  30. std::cerr << "path is nullptr" << std::endl;
  31. return "";
  32. }
  33. if ((strlen(path)) >= max) {
  34. std::cerr << "path is too long" << std::endl;
  35. return "";
  36. }
  37. auto resolved_path = std::make_unique<char[]>(max);
  38. if (resolved_path == nullptr) {
  39. std::cerr << "new resolved_path failed" << std::endl;
  40. return "";
  41. }
  42. #ifdef _WIN32
  43. char *real_path = _fullpath(resolved_path.get(), path, 1024);
  44. #else
  45. char *real_path = realpath(path, resolved_path.get());
  46. #endif
  47. if (real_path == nullptr || strlen(real_path) == 0) {
  48. std::cerr << "file path is not valid : " << path << std::endl;
  49. return "";
  50. }
  51. std::string res = resolved_path.get();
  52. return res;
  53. }
  54. char *ReadFile(const char *file, size_t *size) {
  55. if (file == nullptr) {
  56. std::cerr << "file is nullptr." << std::endl;
  57. return nullptr;
  58. }
  59. std::ifstream ifs(file);
  60. if (!ifs.good()) {
  61. std::cerr << "file: " << file << " is not exist." << std::endl;
  62. return nullptr;
  63. }
  64. if (!ifs.is_open()) {
  65. std::cerr << "file: " << file << " open failed." << std::endl;
  66. return nullptr;
  67. }
  68. ifs.seekg(0, std::ios::end);
  69. *size = ifs.tellg();
  70. std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
  71. if (buf == nullptr) {
  72. std::cerr << "malloc buf failed, file: " << file << std::endl;
  73. ifs.close();
  74. return nullptr;
  75. }
  76. ifs.seekg(0, std::ios::beg);
  77. ifs.read(buf.get(), *size);
  78. ifs.close();
  79. return buf.release();
  80. }
  81. template <typename T, typename Distribution>
  82. void GenerateRandomData(int size, void *data, Distribution distribution) {
  83. if (data == nullptr) {
  84. std::cerr << "data is nullptr." << std::endl;
  85. return;
  86. }
  87. std::mt19937 random_engine;
  88. int elements_num = size / sizeof(T);
  89. (void)std::generate_n(static_cast<T *>(data), elements_num,
  90. [&]() { return static_cast<T>(distribution(random_engine)); });
  91. }
  92. std::shared_ptr<mindspore::lite::Context> CreateCPUContext() {
  93. auto context = std::make_shared<mindspore::lite::Context>();
  94. if (context == nullptr) {
  95. std::cerr << "New context failed while running." << std::endl;
  96. return nullptr;
  97. }
  98. // Configure the number of worker threads in the thread pool to 2, including the main thread.
  99. context->thread_num_ = 2;
  100. // CPU device context has default values.
  101. auto &cpu_device_info = context->device_list_[0].device_info_.cpu_device_info_;
  102. // The large core takes priority in thread and core binding methods. This parameter will work in the BindThread
  103. // interface. For specific binding effect, see the "Run Graph" section.
  104. cpu_device_info.cpu_bind_mode_ = mindspore::lite::HIGHER_CPU;
  105. // Use float16 operator as priority.
  106. cpu_device_info.enable_float16_ = true;
  107. return context;
  108. }
  109. std::shared_ptr<mindspore::lite::Context> CreateGPUContext() {
  110. auto context = std::make_shared<mindspore::lite::Context>();
  111. if (context == nullptr) {
  112. std::cerr << "New context failed while running. " << std::endl;
  113. return nullptr;
  114. }
  115. // If GPU device context is set. The preferred backend is GPU, which means, if there is a GPU operator, it will run on
  116. // the GPU first, otherwise it will run on the CPU.
  117. mindspore::lite::DeviceContext gpu_device_ctx{mindspore::lite::DT_GPU, {false}};
  118. // GPU use float16 operator as priority.
  119. gpu_device_ctx.device_info_.gpu_device_info_.enable_float16_ = true;
  120. // The GPU device context needs to be push_back into device_list to work.
  121. context->device_list_.push_back(gpu_device_ctx);
  122. return context;
  123. }
  124. std::shared_ptr<mindspore::lite::Context> CreateNPUContext() {
  125. auto context = std::make_shared<mindspore::lite::Context>();
  126. if (context == nullptr) {
  127. std::cerr << "New context failed while running. " << std::endl;
  128. return nullptr;
  129. }
  130. mindspore::lite::DeviceContext npu_device_ctx{mindspore::lite::DT_NPU};
  131. npu_device_ctx.device_info_.npu_device_info_.frequency_ = 3;
  132. // The NPU device context needs to be push_back into device_list to work.
  133. context->device_list_.push_back(npu_device_ctx);
  134. return context;
  135. }
  136. int GetInputsAndSetData(mindspore::session::LiteSession *session) {
  137. auto inputs = session->GetInputs();
  138. // The model has only one input tensor.
  139. auto in_tensor = inputs.front();
  140. if (in_tensor == nullptr) {
  141. std::cerr << "Input tensor is nullptr" << std::endl;
  142. return -1;
  143. }
  144. auto input_data = in_tensor->MutableData();
  145. if (input_data == nullptr) {
  146. std::cerr << "MallocData for inTensor failed." << std::endl;
  147. return -1;
  148. }
  149. GenerateRandomData<float>(in_tensor->Size(), input_data, std::uniform_real_distribution<float>(0.1f, 1.0f));
  150. return 0;
  151. }
  152. int GetInputsByTensorNameAndSetData(mindspore::session::LiteSession *session) {
  153. auto in_tensor = session->GetInputsByTensorName("2031_2030_1_construct_wrapper:x");
  154. if (in_tensor == nullptr) {
  155. std::cerr << "Input tensor is nullptr" << std::endl;
  156. return -1;
  157. }
  158. auto input_data = in_tensor->MutableData();
  159. if (input_data == nullptr) {
  160. std::cerr << "MallocData for inTensor failed." << std::endl;
  161. return -1;
  162. }
  163. GenerateRandomData<float>(in_tensor->Size(), input_data, std::uniform_real_distribution<float>(0.1f, 1.0f));
  164. return 0;
  165. }
  166. void GetOutputsByNodeName(mindspore::session::LiteSession *session) {
  167. // model has a output node named output_node_name_0.
  168. auto output_vec = session->GetOutputsByNodeName("Default/head-MobileNetV2Head/Softmax-op204");
  169. // output node named output_node_name_0 has only one output tensor.
  170. auto out_tensor = output_vec.front();
  171. if (out_tensor == nullptr) {
  172. std::cerr << "Output tensor is nullptr" << std::endl;
  173. return;
  174. }
  175. std::cout << "tensor size is:" << out_tensor->Size() << " tensor elements num is:" << out_tensor->ElementsNum()
  176. << std::endl;
  177. // The model output data is float 32.
  178. if (out_tensor->data_type() != mindspore::TypeId::kNumberTypeFloat32) {
  179. std::cerr << "Output should in float32" << std::endl;
  180. return;
  181. }
  182. auto out_data = reinterpret_cast<float *>(out_tensor->MutableData());
  183. if (out_data == nullptr) {
  184. std::cerr << "Data of out_tensor is nullptr" << std::endl;
  185. return;
  186. }
  187. std::cout << "output data is:";
  188. for (int i = 0; i < out_tensor->ElementsNum() && i < 10; i++) {
  189. std::cout << out_data[i] << " ";
  190. }
  191. std::cout << std::endl;
  192. }
  193. void GetOutputByTensorName(mindspore::session::LiteSession *session) {
  194. // We can use GetOutputTensorNames method to get all name of output tensor of model which is in order.
  195. auto tensor_names = session->GetOutputTensorNames();
  196. // Use output tensor name returned by GetOutputTensorNames as key
  197. for (const auto &tensor_name : tensor_names) {
  198. auto out_tensor = session->GetOutputByTensorName(tensor_name);
  199. if (out_tensor == nullptr) {
  200. std::cerr << "Output tensor is nullptr" << std::endl;
  201. return;
  202. }
  203. std::cout << "tensor size is:" << out_tensor->Size() << " tensor elements num is:" << out_tensor->ElementsNum()
  204. << std::endl;
  205. // The model output data is float 32.
  206. if (out_tensor->data_type() != mindspore::TypeId::kNumberTypeFloat32) {
  207. std::cerr << "Output should in float32" << std::endl;
  208. return;
  209. }
  210. auto out_data = reinterpret_cast<float *>(out_tensor->MutableData());
  211. if (out_data == nullptr) {
  212. std::cerr << "Data of out_tensor is nullptr" << std::endl;
  213. return;
  214. }
  215. std::cout << "output data is:";
  216. for (int i = 0; i < out_tensor->ElementsNum() && i < 10; i++) {
  217. std::cout << out_data[i] << " ";
  218. }
  219. std::cout << std::endl;
  220. }
  221. }
  222. void GetOutputs(mindspore::session::LiteSession *session) {
  223. auto out_tensors = session->GetOutputs();
  224. for (auto out_tensor : out_tensors) {
  225. std::cout << "tensor name is:" << out_tensor.first << " tensor size is:" << out_tensor.second->Size()
  226. << " tensor elements num is:" << out_tensor.second->ElementsNum() << std::endl;
  227. // The model output data is float 32.
  228. if (out_tensor.second->data_type() != mindspore::TypeId::kNumberTypeFloat32) {
  229. std::cerr << "Output should in float32" << std::endl;
  230. return;
  231. }
  232. auto out_data = reinterpret_cast<float *>(out_tensor.second->MutableData());
  233. if (out_data == nullptr) {
  234. std::cerr << "Data of out_tensor is nullptr" << std::endl;
  235. return;
  236. }
  237. std::cout << "output data is:";
  238. for (int i = 0; i < out_tensor.second->ElementsNum() && i < 10; i++) {
  239. std::cout << out_data[i] << " ";
  240. }
  241. std::cout << std::endl;
  242. }
  243. }
  244. mindspore::session::LiteSession *CreateSessionAndCompileByModel(mindspore::lite::Model *model) {
  245. // Create and init CPU context.
  246. // If you need to use GPU or NPU, you can refer to CreateGPUContext() or CreateNPUContext().
  247. auto context = CreateCPUContext();
  248. if (context == nullptr) {
  249. std::cerr << "New context failed while." << std::endl;
  250. return nullptr;
  251. }
  252. // Create the session.
  253. mindspore::session::LiteSession *session = mindspore::session::LiteSession::CreateSession(context.get());
  254. if (session == nullptr) {
  255. std::cerr << "CreateSession failed while running." << std::endl;
  256. return nullptr;
  257. }
  258. // Compile graph.
  259. auto ret = session->CompileGraph(model);
  260. if (ret != mindspore::lite::RET_OK) {
  261. delete session;
  262. std::cerr << "Compile failed while running." << std::endl;
  263. return nullptr;
  264. }
  265. return session;
  266. }
  267. mindspore::session::LiteSession *CreateSessionAndCompileByModelBuffer(char *model_buf, size_t size) {
  268. auto context = std::make_shared<mindspore::lite::Context>();
  269. if (context == nullptr) {
  270. std::cerr << "New context failed while running" << std::endl;
  271. return nullptr;
  272. }
  273. // Use model buffer and context to create Session.
  274. auto session = mindspore::session::LiteSession::CreateSession(model_buf, size, context.get());
  275. if (session == nullptr) {
  276. std::cerr << "CreateSession failed while running" << std::endl;
  277. return nullptr;
  278. }
  279. return session;
  280. }
  281. int ResizeInputsTensorShape(mindspore::session::LiteSession *session) {
  282. auto inputs = session->GetInputs();
  283. std::vector<int> resize_shape = {1, 128, 128, 3};
  284. // Assume the model has only one input,resize input shape to [1, 128, 128, 3]
  285. std::vector<std::vector<int>> new_shapes;
  286. new_shapes.push_back(resize_shape);
  287. return session->Resize(inputs, new_shapes);
  288. }
  289. int Run(const char *model_path) {
  290. // Read model file.
  291. size_t size = 0;
  292. char *model_buf = ReadFile(model_path, &size);
  293. if (model_buf == nullptr) {
  294. std::cerr << "Read model file failed." << std::endl;
  295. return -1;
  296. }
  297. // Load the .ms model.
  298. auto model = mindspore::lite::Model::Import(model_buf, size);
  299. delete[](model_buf);
  300. if (model == nullptr) {
  301. std::cerr << "Import model file failed." << std::endl;
  302. return -1;
  303. }
  304. // Compile MindSpore Lite model.
  305. auto session = CreateSessionAndCompileByModel(model);
  306. if (session == nullptr) {
  307. delete model;
  308. std::cerr << "Create session failed." << std::endl;
  309. return -1;
  310. }
  311. // Note: when use model->Free(), the model can not be compiled again.
  312. model->Free();
  313. // Set inputs data.
  314. // You can also get input through other methods, and you can refer to GetInputsAndSetData()
  315. GetInputsByTensorNameAndSetData(session);
  316. session->BindThread(true);
  317. auto ret = session->RunGraph();
  318. if (ret != mindspore::lite::RET_OK) {
  319. delete model;
  320. delete session;
  321. std::cerr << "Inference error " << ret << std::endl;
  322. return ret;
  323. }
  324. session->BindThread(false);
  325. // Get outputs data.
  326. // You can also get output through other methods,
  327. // and you can refer to GetOutputByTensorName() or GetOutputs().
  328. GetOutputsByNodeName(session);
  329. // Delete model buffer.
  330. delete model;
  331. // Delete session buffer.
  332. delete session;
  333. return 0;
  334. }
  335. int RunResize(const char *model_path) {
  336. size_t size = 0;
  337. char *model_buf = ReadFile(model_path, &size);
  338. if (model_buf == nullptr) {
  339. std::cerr << "Read model file failed." << std::endl;
  340. return -1;
  341. }
  342. // Load the .ms model.
  343. auto model = mindspore::lite::Model::Import(model_buf, size);
  344. delete[](model_buf);
  345. if (model == nullptr) {
  346. std::cerr << "Import model file failed." << std::endl;
  347. return -1;
  348. }
  349. // Compile MindSpore Lite model.
  350. auto session = CreateSessionAndCompileByModel(model);
  351. if (session == nullptr) {
  352. delete model;
  353. std::cerr << "Create session failed." << std::endl;
  354. return -1;
  355. }
  356. // Resize inputs tensor shape.
  357. auto ret = ResizeInputsTensorShape(session);
  358. if (ret != mindspore::lite::RET_OK) {
  359. delete model;
  360. delete session;
  361. std::cerr << "Resize input tensor shape error." << ret << std::endl;
  362. return ret;
  363. }
  364. // Set inputs data.
  365. // You can also get input through other methods, and you can refer to GetInputsAndSetData()
  366. GetInputsByTensorNameAndSetData(session);
  367. session->BindThread(true);
  368. ret = session->RunGraph();
  369. if (ret != mindspore::lite::RET_OK) {
  370. delete model;
  371. delete session;
  372. std::cerr << "Inference error " << ret << std::endl;
  373. return ret;
  374. }
  375. session->BindThread(false);
  376. // Get outputs data.
  377. // You can also get output through other methods,
  378. // and you can refer to GetOutputByTensorName() or GetOutputs().
  379. GetOutputsByNodeName(session);
  380. // Delete model buffer.
  381. delete model;
  382. // Delete session buffer.
  383. delete session;
  384. return 0;
  385. }
  386. int RunCreateSessionSimplified(const char *model_path) {
  387. size_t size = 0;
  388. char *model_buf = ReadFile(model_path, &size);
  389. if (model_buf == nullptr) {
  390. std::cerr << "Read model file failed." << std::endl;
  391. return -1;
  392. }
  393. // Compile MindSpore Lite model.
  394. auto session = CreateSessionAndCompileByModelBuffer(model_buf, size);
  395. if (session == nullptr) {
  396. std::cerr << "Create session failed." << std::endl;
  397. return -1;
  398. }
  399. // Set inputs data.
  400. // You can also get input through other methods, and you can refer to GetInputsAndSetData()
  401. GetInputsByTensorNameAndSetData(session);
  402. session->BindThread(true);
  403. auto ret = session->RunGraph();
  404. if (ret != mindspore::lite::RET_OK) {
  405. delete session;
  406. std::cerr << "Inference error " << ret << std::endl;
  407. return ret;
  408. }
  409. session->BindThread(false);
  410. // Get outputs data.
  411. // You can also get output through other methods,
  412. // and you can refer to GetOutputByTensorName() or GetOutputs().
  413. GetOutputsByNodeName(session);
  414. // Delete session buffer.
  415. delete session;
  416. return 0;
  417. }
  418. int RunSessionParallel(const char *model_path) {
  419. size_t size = 0;
  420. char *model_buf = ReadFile(model_path, &size);
  421. if (model_buf == nullptr) {
  422. std::cerr << "Read model file failed." << std::endl;
  423. return -1;
  424. }
  425. // Load the .ms model.
  426. auto model = mindspore::lite::Model::Import(model_buf, size);
  427. delete[](model_buf);
  428. if (model == nullptr) {
  429. std::cerr << "Import model file failed." << std::endl;
  430. return -1;
  431. }
  432. // Compile MindSpore Lite model.
  433. auto session1 = CreateSessionAndCompileByModel(model);
  434. if (session1 == nullptr) {
  435. delete model;
  436. std::cerr << "Create session failed." << std::endl;
  437. return -1;
  438. }
  439. // Compile MindSpore Lite model.
  440. auto session2 = CreateSessionAndCompileByModel(model);
  441. if (session2 == nullptr) {
  442. delete model;
  443. std::cerr << "Create session failed." << std::endl;
  444. return -1;
  445. }
  446. // Note: when use model->Free(), the model can not be compiled again.
  447. model->Free();
  448. std::thread thread1([&]() {
  449. GetInputsByTensorNameAndSetData(session1);
  450. auto status = session1->RunGraph();
  451. if (status != 0) {
  452. std::cerr << "Inference error " << status << std::endl;
  453. return;
  454. }
  455. std::cout << "Session1 inference success" << std::endl;
  456. });
  457. std::thread thread2([&]() {
  458. GetInputsByTensorNameAndSetData(session2);
  459. auto status = session2->RunGraph();
  460. if (status != 0) {
  461. std::cerr << "Inference error " << status << std::endl;
  462. return;
  463. }
  464. std::cout << "Session2 inference success" << std::endl;
  465. });
  466. thread1.join();
  467. thread2.join();
  468. // Get outputs data.
  469. // You can also get output through other methods,
  470. // and you can refer to GetOutputByTensorName() or GetOutputs().
  471. GetOutputsByNodeName(session1);
  472. GetOutputsByNodeName(session2);
  473. // Delete model buffer.
  474. if (model != nullptr) {
  475. delete model;
  476. model = nullptr;
  477. }
  478. // Delete session buffer.
  479. delete session1;
  480. delete session2;
  481. return 0;
  482. }
  483. int RunWithSharedMemoryPool(const char *model_path) {
  484. size_t size = 0;
  485. char *model_buf = ReadFile(model_path, &size);
  486. if (model_buf == nullptr) {
  487. std::cerr << "Read model file failed." << std::endl;
  488. return -1;
  489. }
  490. auto model = mindspore::lite::Model::Import(model_buf, size);
  491. delete[](model_buf);
  492. if (model == nullptr) {
  493. std::cerr << "Import model file failed." << std::endl;
  494. return -1;
  495. }
  496. auto context1 = std::make_shared<mindspore::lite::Context>();
  497. if (context1 == nullptr) {
  498. delete model;
  499. std::cerr << "New context failed while running." << std::endl;
  500. return -1;
  501. }
  502. auto session1 = mindspore::session::LiteSession::CreateSession(context1.get());
  503. if (session1 == nullptr) {
  504. delete model;
  505. std::cerr << "CreateSession failed while running." << std::endl;
  506. return -1;
  507. }
  508. auto ret = session1->CompileGraph(model);
  509. if (ret != mindspore::lite::RET_OK) {
  510. delete model;
  511. delete session1;
  512. std::cerr << "Compile failed while running." << std::endl;
  513. return -1;
  514. }
  515. auto context2 = std::make_shared<mindspore::lite::Context>();
  516. if (context2 == nullptr) {
  517. delete model;
  518. std::cerr << "New context failed while running." << std::endl;
  519. return -1;
  520. }
  521. // Use the same allocator to share the memory pool.
  522. context2->allocator = context1->allocator;
  523. auto session2 = mindspore::session::LiteSession::CreateSession(context2.get());
  524. if (session2 == nullptr) {
  525. delete model;
  526. delete session1;
  527. std::cerr << "CreateSession failed while running " << std::endl;
  528. return -1;
  529. }
  530. ret = session2->CompileGraph(model);
  531. if (ret != mindspore::lite::RET_OK) {
  532. delete model;
  533. delete session1;
  534. delete session2;
  535. std::cerr << "Compile failed while running " << std::endl;
  536. return -1;
  537. }
  538. // Note: when use model->Free(), the model can not be compiled again.
  539. model->Free();
  540. // Set inputs data.
  541. // You can also get input through other methods, and you can refer to GetInputsAndSetData()
  542. GetInputsByTensorNameAndSetData(session1);
  543. GetInputsByTensorNameAndSetData(session2);
  544. ret = session1->RunGraph();
  545. if (ret != mindspore::lite::RET_OK) {
  546. std::cerr << "Inference error " << ret << std::endl;
  547. return ret;
  548. }
  549. ret = session2->RunGraph();
  550. if (ret != mindspore::lite::RET_OK) {
  551. delete model;
  552. delete session1;
  553. delete session2;
  554. std::cerr << "Inference error " << ret << std::endl;
  555. return ret;
  556. }
  557. // Get outputs data.
  558. // You can also get output through other methods,
  559. // and you can refer to GetOutputByTensorName() or GetOutputs().
  560. GetOutputsByNodeName(session1);
  561. GetOutputsByNodeName(session2);
  562. // Delete model buffer.
  563. delete model;
  564. // Delete session buffer.
  565. delete session1;
  566. delete session2;
  567. return 0;
  568. }
  569. int RunCallback(const char *model_path) {
  570. size_t size = 0;
  571. char *model_buf = ReadFile(model_path, &size);
  572. if (model_buf == nullptr) {
  573. std::cerr << "Read model file failed." << std::endl;
  574. return -1;
  575. }
  576. // Load the .ms model.
  577. auto model = mindspore::lite::Model::Import(model_buf, size);
  578. delete[](model_buf);
  579. if (model == nullptr) {
  580. std::cerr << "Import model file failed." << std::endl;
  581. return -1;
  582. }
  583. // Compile MindSpore Lite model.
  584. auto session = CreateSessionAndCompileByModel(model);
  585. if (session == nullptr) {
  586. delete model;
  587. std::cerr << "Create session failed." << std::endl;
  588. return -1;
  589. }
  590. // Note: when use model->Free(), the model can not be compiled again.
  591. model->Free();
  592. // Set inputs data.
  593. // You can also get input through other methods, and you can refer to GetInputsAndSetData()
  594. GetInputsByTensorNameAndSetData(session);
  595. // Definition of callback function before forwarding operator.
  596. auto before_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
  597. const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
  598. const mindspore::CallBackParam &call_param) {
  599. std::cout << "Before forwarding " << call_param.node_name << " " << call_param.node_type << std::endl;
  600. return true;
  601. };
  602. // Definition of callback function after forwarding operator.
  603. auto after_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
  604. const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
  605. const mindspore::CallBackParam &call_param) {
  606. std::cout << "After forwarding " << call_param.node_name << " " << call_param.node_type << std::endl;
  607. return true;
  608. };
  609. session->BindThread(true);
  610. auto ret = session->RunGraph(before_call_back, after_call_back);
  611. if (ret != mindspore::lite::RET_OK) {
  612. delete model;
  613. delete session;
  614. std::cerr << "Inference error " << ret << std::endl;
  615. return ret;
  616. }
  617. session->BindThread(false);
  618. // Get outputs data.
  619. // You can also get output through other methods,
  620. // and you can refer to GetOutputByTensorName() or GetOutputs().
  621. GetOutputsByNodeName(session);
  622. // Delete model buffer.
  623. delete model;
  624. // Delete session buffer.
  625. delete session;
  626. return 0;
  627. }
  628. int main(int argc, const char **argv) {
  629. if (argc < 3) {
  630. std::cerr << "Usage: ./runtime_cpp model_path Option" << std::endl;
  631. std::cerr << "Example: ./runtime_cpp ../model/mobilenetv2.ms 0" << std::endl;
  632. std::cerr << "When your Option is 0, you will run MindSpore Lite inference." << std::endl;
  633. std::cerr << "When your Option is 1, you will run MindSpore Lite inference with resize." << std::endl;
  634. std::cerr << "When your Option is 2, you will run MindSpore Lite inference with CreateSession simplified API."
  635. << std::endl;
  636. std::cerr << "When your Option is 3, you will run MindSpore Lite inference with session parallel." << std::endl;
  637. std::cerr << "When your Option is 4, you will run MindSpore Lite inference with shared memory pool." << std::endl;
  638. std::cerr << "When your Option is 5, you will run MindSpore Lite inference with callback." << std::endl;
  639. return -1;
  640. }
  641. std::string version = mindspore::lite::Version();
  642. std::cout << "MindSpore Lite Version is " << version << std::endl;
  643. auto model_path = RealPath(argv[1]);
  644. if (model_path.empty()) {
  645. std::cerr << "model path " << argv[1] << " is invalid.";
  646. return -1;
  647. }
  648. auto flag = argv[2];
  649. if (strcmp(flag, "0") == 0) {
  650. return Run(model_path.c_str());
  651. } else if (strcmp(flag, "1") == 0) {
  652. return RunResize(model_path.c_str());
  653. } else if (strcmp(flag, "2") == 0) {
  654. return RunCreateSessionSimplified(model_path.c_str());
  655. } else if (strcmp(flag, "3") == 0) {
  656. return RunSessionParallel(model_path.c_str());
  657. } else if (strcmp(flag, "4") == 0) {
  658. return RunWithSharedMemoryPool(model_path.c_str());
  659. } else if (strcmp(flag, "5") == 0) {
  660. return RunCallback(model_path.c_str());
  661. } else {
  662. std::cerr << "Unsupported Flag " << flag << std::endl;
  663. return -1;
  664. }
  665. }