You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

main.cc 6.5 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <algorithm>
  17. #include <random>
  18. #include <iostream>
  19. #include <fstream>
  20. #include <cstring>
  21. #include <cmath>
  22. #include <vector>
  23. #include <memory>
  24. #include "include/errorcode.h"
  25. #include "include/context.h"
  26. #include "include/api/types.h"
  27. #include "include/api/model.h"
  28. namespace mindspore {
  29. namespace lite {
  30. namespace {
  31. constexpr int kNumPrintOfOutData = 20;
  32. std::string RealPath(const char *path) {
  33. const size_t max = 4096;
  34. if (path == nullptr) {
  35. std::cerr << "path is nullptr" << std::endl;
  36. return "";
  37. }
  38. if ((strlen(path)) >= max) {
  39. std::cerr << "path is too long" << std::endl;
  40. return "";
  41. }
  42. auto resolved_path = std::make_unique<char[]>(max);
  43. if (resolved_path == nullptr) {
  44. std::cerr << "new resolved_path failed" << std::endl;
  45. return "";
  46. }
  47. char *real_path = realpath(path, resolved_path.get());
  48. if (real_path == nullptr || strlen(real_path) == 0) {
  49. std::cerr << "file path is not valid : " << path << std::endl;
  50. return "";
  51. }
  52. std::string res = resolved_path.get();
  53. return res;
  54. }
  55. char *ReadFile(const char *file, size_t *size) {
  56. if (file == nullptr) {
  57. std::cerr << "file is nullptr." << std::endl;
  58. return nullptr;
  59. }
  60. std::ifstream ifs(file);
  61. if (!ifs.good()) {
  62. std::cerr << "file: " << file << " is not exist." << std::endl;
  63. return nullptr;
  64. }
  65. if (!ifs.is_open()) {
  66. std::cerr << "file: " << file << " open failed." << std::endl;
  67. return nullptr;
  68. }
  69. ifs.seekg(0, std::ios::end);
  70. *size = ifs.tellg();
  71. std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
  72. if (buf == nullptr) {
  73. std::cerr << "malloc buf failed, file: " << file << std::endl;
  74. ifs.close();
  75. return nullptr;
  76. }
  77. ifs.seekg(0, std::ios::beg);
  78. ifs.read(buf.get(), *size);
  79. ifs.close();
  80. return buf.release();
  81. }
  82. } // namespace
  83. template <typename T, typename Distribution>
  84. void GenerateRandomData(int size, void *data, Distribution distribution) {
  85. std::mt19937 random_engine;
  86. int elements_num = size / sizeof(T);
  87. (void)std::generate_n(static_cast<T *>(data), elements_num,
  88. [&distribution, &random_engine]() { return static_cast<T>(distribution(random_engine)); });
  89. }
  90. void InitMSContext(const std::shared_ptr<mindspore::Context> &context) {
  91. context->SetThreadNum(1);
  92. context->SetEnableParallel(false);
  93. context->SetThreadAffinity(HIGHER_CPU);
  94. auto &device_list = context->MutableDeviceInfo();
  95. std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
  96. device_info->SetEnableFP16(false);
  97. device_list.push_back(device_info);
  98. std::shared_ptr<GPUDeviceInfo> provider_gpu_device_info = std::make_shared<GPUDeviceInfo>();
  99. provider_gpu_device_info->SetEnableFP16(false);
  100. provider_gpu_device_info->SetProviderDevice("GPU");
  101. provider_gpu_device_info->SetProvider("Tutorial");
  102. device_list.push_back(provider_gpu_device_info);
  103. }
  104. int CompileAndRun(int argc, const char **argv) {
  105. if (argc < 2) {
  106. std::cerr << "Model file must be provided.\n";
  107. return RET_ERROR;
  108. }
  109. // Read model file.
  110. auto model_path = RealPath(argv[1]);
  111. if (model_path.empty()) {
  112. std::cerr << "model path " << argv[1] << " is invalid.";
  113. return RET_ERROR;
  114. }
  115. auto context = std::make_shared<mindspore::Context>();
  116. if (context == nullptr) {
  117. std::cerr << "New context failed." << std::endl;
  118. return RET_ERROR;
  119. }
  120. (void)InitMSContext(context);
  121. mindspore::Model ms_model;
  122. size_t size = 0;
  123. char *model_buf = ReadFile(model_path.c_str(), &size);
  124. if (model_buf == nullptr) {
  125. std::cerr << "Read model file failed." << std::endl;
  126. return RET_ERROR;
  127. }
  128. auto ret = ms_model.Build(model_buf, size, kMindIR, context);
  129. delete[](model_buf);
  130. if (ret != kSuccess) {
  131. std::cerr << "ms_model.Build failed." << std::endl;
  132. return RET_ERROR;
  133. }
  134. std::vector<mindspore::MSTensor> ms_inputs_for_api = ms_model.GetInputs();
  135. for (auto tensor : ms_inputs_for_api) {
  136. auto input_data = tensor.MutableData();
  137. if (input_data == nullptr) {
  138. std::cerr << "MallocData for inTensor failed." << std::endl;
  139. return RET_ERROR;
  140. }
  141. GenerateRandomData<float>(tensor.DataSize(), input_data, std::uniform_real_distribution<float>(1.0f, 1.0f));
  142. }
  143. std::cout << "\n------- print inputs ----------" << std::endl;
  144. for (auto tensor : ms_inputs_for_api) {
  145. std::cout << "in tensor name is:" << tensor.Name() << "\nin tensor size is:" << tensor.DataSize()
  146. << "\nin tensor elements num is:" << tensor.ElementNum() << std::endl;
  147. auto out_data = reinterpret_cast<float *>(tensor.MutableData());
  148. std::cout << "input data is:";
  149. for (int i = 0; i < tensor.ElementNum() && i <= kNumPrintOfOutData; i++) {
  150. std::cout << out_data[i] << " ";
  151. }
  152. std::cout << std::endl;
  153. }
  154. std::cout << "------- print end ----------\n" << std::endl;
  155. std::vector<MSTensor> outputs;
  156. auto status = ms_model.Predict(ms_inputs_for_api, &outputs);
  157. if (status != kSuccess) {
  158. std::cerr << "Inference error." << std::endl;
  159. return RET_ERROR;
  160. }
  161. // Get Output Tensor Data.
  162. auto out_tensors = ms_model.GetOutputs();
  163. std::cout << "\n------- print outputs ----------" << std::endl;
  164. for (auto tensor : out_tensors) {
  165. std::cout << "out tensor name is:" << tensor.Name() << "\nout tensor size is:" << tensor.DataSize()
  166. << "\nout tensor elements num is:" << tensor.ElementNum() << std::endl;
  167. auto out_data = reinterpret_cast<float *>(tensor.MutableData());
  168. std::cout << "output data is:";
  169. for (int i = 0; i < tensor.ElementNum() && i <= kNumPrintOfOutData; i++) {
  170. std::cout << out_data[i] << " ";
  171. }
  172. std::cout << std::endl;
  173. }
  174. std::cout << "------- print end ----------\n" << std::endl;
  175. return RET_OK;
  176. }
  177. } // namespace lite
  178. } // namespace mindspore
  179. int main(int argc, const char **argv) { return mindspore::lite::CompileAndRun(argc, argv); }