You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.cc 5.9 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <algorithm>
  17. #include <random>
  18. #include <iostream>
  19. #include <fstream>
  20. #include <cstring>
  21. #include <memory>
  22. #include "include/api/model.h"
  23. #include "include/api/context.h"
  24. #include "include/api/status.h"
  25. #include "include/api/types.h"
  26. namespace {
  27. constexpr int kNumPrintOfOutData = 50;
  28. }
  29. std::string RealPath(const char *path) {
  30. const size_t max = 4096;
  31. if (path == nullptr) {
  32. std::cerr << "path is nullptr" << std::endl;
  33. return "";
  34. }
  35. if ((strlen(path)) >= max) {
  36. std::cerr << "path is too long" << std::endl;
  37. return "";
  38. }
  39. auto resolved_path = std::make_unique<char[]>(max);
  40. if (resolved_path == nullptr) {
  41. std::cerr << "new resolved_path failed" << std::endl;
  42. return "";
  43. }
  44. #ifdef _WIN32
  45. char *real_path = _fullpath(resolved_path.get(), path, 1024);
  46. #else
  47. char *real_path = realpath(path, resolved_path.get());
  48. #endif
  49. if (real_path == nullptr || strlen(real_path) == 0) {
  50. std::cerr << "file path is not valid : " << path << std::endl;
  51. return "";
  52. }
  53. std::string res = resolved_path.get();
  54. return res;
  55. }
  56. char *ReadFile(const char *file, size_t *size) {
  57. if (file == nullptr) {
  58. std::cerr << "file is nullptr." << std::endl;
  59. return nullptr;
  60. }
  61. std::ifstream ifs(file);
  62. if (!ifs.good()) {
  63. std::cerr << "file: " << file << " is not exist." << std::endl;
  64. return nullptr;
  65. }
  66. if (!ifs.is_open()) {
  67. std::cerr << "file: " << file << " open failed." << std::endl;
  68. return nullptr;
  69. }
  70. ifs.seekg(0, std::ios::end);
  71. *size = ifs.tellg();
  72. std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
  73. if (buf == nullptr) {
  74. std::cerr << "malloc buf failed, file: " << file << std::endl;
  75. ifs.close();
  76. return nullptr;
  77. }
  78. ifs.seekg(0, std::ios::beg);
  79. ifs.read(buf.get(), *size);
  80. ifs.close();
  81. return buf.release();
  82. }
  83. template <typename T, typename Distribution>
  84. void GenerateRandomData(int size, void *data, Distribution distribution) {
  85. std::mt19937 random_engine;
  86. int elements_num = size / sizeof(T);
  87. (void)std::generate_n(static_cast<T *>(data), elements_num,
  88. [&distribution, &random_engine]() { return static_cast<T>(distribution(random_engine)); });
  89. }
  90. int GenerateInputDataWithRandom(std::vector<mindspore::MSTensor> inputs) {
  91. for (auto tensor : inputs) {
  92. auto input_data = tensor.MutableData();
  93. if (input_data == nullptr) {
  94. std::cerr << "MallocData for inTensor failed." << std::endl;
  95. return -1;
  96. }
  97. GenerateRandomData<float>(tensor.DataSize(), input_data, std::uniform_real_distribution<float>(0.1f, 1.0f));
  98. }
  99. return mindspore::kSuccess;
  100. }
  101. int QuickStart(int argc, const char **argv) {
  102. if (argc < 2) {
  103. std::cerr << "Model file must be provided.\n";
  104. return -1;
  105. }
  106. // Read model file.
  107. auto model_path = RealPath(argv[1]);
  108. if (model_path.empty()) {
  109. std::cerr << "Model path " << argv[1] << " is invalid.";
  110. return -1;
  111. }
  112. size_t size = 0;
  113. char *model_buf = ReadFile(model_path.c_str(), &size);
  114. if (model_buf == nullptr) {
  115. std::cerr << "Read model file failed." << std::endl;
  116. return -1;
  117. }
  118. // Create and init context, add CPU device info
  119. auto context = std::make_shared<mindspore::Context>();
  120. if (context == nullptr) {
  121. delete[](model_buf);
  122. std::cerr << "New context failed." << std::endl;
  123. return -1;
  124. }
  125. auto &device_list = context->MutableDeviceInfo();
  126. auto device_info = std::make_shared<mindspore::CPUDeviceInfo>();
  127. if (device_info == nullptr) {
  128. delete[](model_buf);
  129. std::cerr << "New CPUDeviceInfo failed." << std::endl;
  130. return -1;
  131. }
  132. device_list.push_back(device_info);
  133. // Create model
  134. auto model = new (std::nothrow) mindspore::Model();
  135. if (model == nullptr) {
  136. delete[](model_buf);
  137. std::cerr << "New Model failed." << std::endl;
  138. return -1;
  139. }
  140. // Build model
  141. auto build_ret = model->Build(model_buf, size, mindspore::kMindIR, context);
  142. delete[](model_buf);
  143. if (build_ret != mindspore::kSuccess) {
  144. delete model;
  145. std::cerr << "Build model failed." << std::endl;
  146. return -1;
  147. }
  148. // Get Input
  149. auto inputs = model->GetInputs();
  150. // Generate random data as input data.
  151. auto ret = GenerateInputDataWithRandom(inputs);
  152. if (ret != mindspore::kSuccess) {
  153. delete model;
  154. std::cerr << "Generate Random Input Data failed." << std::endl;
  155. return -1;
  156. }
  157. // Get Output
  158. auto outputs = model->GetOutputs();
  159. // Model Predict
  160. auto predict_ret = model->Predict(inputs, &outputs);
  161. if (predict_ret != mindspore::kSuccess) {
  162. delete model;
  163. std::cerr << "Predict error " << ret << std::endl;
  164. return ret;
  165. }
  166. // Print Output Tensor Data.
  167. for (auto tensor : outputs) {
  168. std::cout << "tensor name is:" << tensor.Name() << " tensor size is:" << tensor.DataSize()
  169. << " tensor elements num is:" << tensor.ElementNum() << std::endl;
  170. auto out_data = reinterpret_cast<const float *>(tensor.Data().get());
  171. std::cout << "output data is:";
  172. for (int i = 0; i < tensor.ElementNum() && i <= 50; i++) {
  173. std::cout << out_data[i] << " ";
  174. }
  175. std::cout << std::endl;
  176. }
  177. // Delete model.
  178. delete model;
  179. return mindspore::kSuccess;
  180. }
  181. int main(int argc, const char **argv) { return QuickStart(argc, argv); }