You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

benchmark.cc 27 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/benchmark/benchmark.h"
  17. #define __STDC_FORMAT_MACROS
  18. #include <cinttypes>
  19. #undef __STDC_FORMAT_MACROS
  20. #include <algorithm>
  21. #include <utility>
  22. #include "include/context.h"
  23. #include "include/ms_tensor.h"
  24. #include "include/version.h"
  25. #include "src/common/common.h"
  26. #include "src/runtime/runtime_api.h"
namespace mindspore {
namespace lite {
// Delimiters used when parsing command-line flag values:
//   DELIM_COLON separates shapes in --resizeDims,
//   DELIM_COMMA separates dimensions within one shape,
//   DELIM_SLASH splits the model file path to extract the file name.
static const char *DELIM_COLON = ":";
static const char *DELIM_COMMA = ",";
static const char *DELIM_SLASH = "/";
  32. int Benchmark::GenerateRandomData(size_t size, void *data) {
  33. MS_ASSERT(data != nullptr);
  34. char *casted_data = static_cast<char *>(data);
  35. for (size_t i = 0; i < size; i++) {
  36. casted_data[i] = static_cast<char>(i);
  37. }
  38. return RET_OK;
  39. }
  40. int Benchmark::GenerateInputData() {
  41. for (auto tensor : ms_inputs_) {
  42. MS_ASSERT(tensor != nullptr);
  43. auto input_data = tensor->MutableData();
  44. if (input_data == nullptr) {
  45. MS_LOG(ERROR) << "MallocData for inTensor failed";
  46. return RET_ERROR;
  47. }
  48. MS_ASSERT(tensor->GetData() != nullptr);
  49. auto tensor_byte_size = tensor->Size();
  50. auto status = GenerateRandomData(tensor_byte_size, input_data);
  51. if (status != 0) {
  52. std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
  53. MS_LOG(ERROR) << "GenerateRandomData for inTensor failed:" << status;
  54. return status;
  55. }
  56. }
  57. return RET_OK;
  58. }
  59. int Benchmark::LoadInput() {
  60. if (flags_->in_data_file_.empty()) {
  61. auto status = GenerateInputData();
  62. if (status != 0) {
  63. std::cerr << "Generate input data error " << status << std::endl;
  64. MS_LOG(ERROR) << "Generate input data error " << status;
  65. return status;
  66. }
  67. } else {
  68. auto status = ReadInputFile();
  69. if (status != 0) {
  70. std::cerr << "ReadInputFile error, " << status << std::endl;
  71. MS_LOG(ERROR) << "ReadInputFile error, " << status;
  72. return status;
  73. }
  74. }
  75. return RET_OK;
  76. }
  77. int Benchmark::ReadInputFile() {
  78. if (ms_inputs_.empty()) {
  79. return RET_OK;
  80. }
  81. if (this->flags_->in_data_type_ == kImage) {
  82. MS_LOG(ERROR) << "Not supported image input";
  83. return RET_ERROR;
  84. } else {
  85. for (size_t i = 0; i < flags_->input_data_list_.size(); i++) {
  86. auto cur_tensor = ms_inputs_.at(i);
  87. MS_ASSERT(cur_tensor != nullptr);
  88. size_t size;
  89. char *bin_buf = ReadFile(flags_->input_data_list_[i].c_str(), &size);
  90. if (bin_buf == nullptr) {
  91. MS_LOG(ERROR) << "ReadFile return nullptr";
  92. return RET_ERROR;
  93. }
  94. if (cur_tensor->data_type() == kObjectTypeString) {
  95. std::string str(bin_buf, size);
  96. auto ret = StringsToMSTensor({str}, cur_tensor);
  97. if (ret != RET_OK) {
  98. MS_LOG(ERROR) << "write strings to tensor failed";
  99. delete[] bin_buf;
  100. return RET_ERROR;
  101. }
  102. } else {
  103. auto tensor_data_size = cur_tensor->Size();
  104. if (size != tensor_data_size) {
  105. std::cerr << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
  106. << std::endl;
  107. MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size;
  108. delete[] bin_buf;
  109. return RET_ERROR;
  110. }
  111. auto input_data = cur_tensor->MutableData();
  112. memcpy(input_data, bin_buf, tensor_data_size);
  113. }
  114. delete[] bin_buf;
  115. }
  116. }
  117. return RET_OK;
  118. }
  119. // calibData is FP32
  120. int Benchmark::ReadCalibData() {
  121. const char *calib_data_path = flags_->benchmark_data_file_.c_str();
  122. // read calib data
  123. std::ifstream in_file(calib_data_path);
  124. if (!in_file.good()) {
  125. std::cerr << "file: " << calib_data_path << " is not exist" << std::endl;
  126. MS_LOG(ERROR) << "file: " << calib_data_path << " is not exist";
  127. return RET_ERROR;
  128. }
  129. if (!in_file.is_open()) {
  130. std::cerr << "file: " << calib_data_path << " open failed" << std::endl;
  131. MS_LOG(ERROR) << "file: " << calib_data_path << " open failed";
  132. in_file.close();
  133. return RET_ERROR;
  134. }
  135. std::string line;
  136. MS_LOG(INFO) << "Start reading calibData file";
  137. std::string tensor_name;
  138. while (!in_file.eof()) {
  139. getline(in_file, line);
  140. std::stringstream string_line1(line);
  141. size_t dim = 0;
  142. string_line1 >> tensor_name >> dim;
  143. std::vector<size_t> dims;
  144. size_t shape_size = 1;
  145. for (size_t i = 0; i < dim; i++) {
  146. size_t tmp_dim;
  147. string_line1 >> tmp_dim;
  148. dims.push_back(tmp_dim);
  149. shape_size *= tmp_dim;
  150. }
  151. getline(in_file, line);
  152. std::stringstream string_line2(line);
  153. std::vector<float> tensor_data;
  154. for (size_t i = 0; i < shape_size; i++) {
  155. float tmp_data;
  156. string_line2 >> tmp_data;
  157. tensor_data.push_back(tmp_data);
  158. }
  159. auto *check_tensor = new CheckTensor(dims, tensor_data);
  160. this->benchmark_data_.insert(std::make_pair(tensor_name, check_tensor));
  161. }
  162. in_file.close();
  163. MS_LOG(INFO) << "Finish reading calibData file";
  164. return RET_OK;
  165. }
  166. int Benchmark::CompareOutput() {
  167. std::cout << "================ Comparing Output data ================" << std::endl;
  168. float total_bias = 0;
  169. int total_size = 0;
  170. bool has_error = false;
  171. for (const auto &calib_tensor : benchmark_data_) {
  172. std::string node_or_tensor_name = calib_tensor.first;
  173. auto tensors = session_->GetOutputsByNodeName(node_or_tensor_name);
  174. mindspore::tensor::MSTensor *tensor = nullptr;
  175. if (tensors.empty() || tensors.size() != 1) {
  176. MS_LOG(INFO) << "Cannot find output node: " << node_or_tensor_name
  177. << " or node has more than one output tensor, switch to GetOutputByTensorName";
  178. tensor = session_->GetOutputByTensorName(node_or_tensor_name);
  179. if (tensor == nullptr) {
  180. MS_LOG(ERROR) << "Cannot find output tensor " << node_or_tensor_name << ", get model output failed";
  181. return RET_ERROR;
  182. }
  183. } else {
  184. tensor = tensors.front();
  185. }
  186. MS_ASSERT(tensor->MutableData() != nullptr);
  187. float bias = 0;
  188. switch (msCalibDataType) {
  189. case TypeId::kNumberTypeFloat: {
  190. bias = CompareData<float>(node_or_tensor_name, tensor->shape(), static_cast<float *>(tensor->MutableData()));
  191. break;
  192. }
  193. case TypeId::kNumberTypeInt8: {
  194. bias = CompareData<int8_t>(node_or_tensor_name, tensor->shape(), static_cast<int8_t *>(tensor->MutableData()));
  195. break;
  196. }
  197. case TypeId::kNumberTypeUInt8: {
  198. bias =
  199. CompareData<uint8_t>(node_or_tensor_name, tensor->shape(), static_cast<uint8_t *>(tensor->MutableData()));
  200. break;
  201. }
  202. case TypeId::kNumberTypeInt32: {
  203. bias =
  204. CompareData<int32_t>(node_or_tensor_name, tensor->shape(), static_cast<int32_t *>(tensor->MutableData()));
  205. break;
  206. }
  207. default:
  208. MS_LOG(ERROR) << "Datatype " << msCalibDataType << " is not supported.";
  209. return RET_ERROR;
  210. }
  211. if (bias >= 0) {
  212. total_bias += bias;
  213. total_size++;
  214. } else {
  215. has_error = true;
  216. break;
  217. }
  218. }
  219. if (!has_error) {
  220. float mean_bias;
  221. if (total_size != 0) {
  222. mean_bias = total_bias / total_size * 100;
  223. } else {
  224. mean_bias = 0;
  225. }
  226. std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" << std::endl;
  227. std::cout << "=======================================================" << std::endl << std::endl;
  228. if (mean_bias > this->flags_->accuracy_threshold_) {
  229. MS_LOG(ERROR) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%";
  230. std::cerr << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl;
  231. return RET_ERROR;
  232. } else {
  233. return RET_OK;
  234. }
  235. } else {
  236. MS_LOG(ERROR) << "Error in CompareData";
  237. std::cerr << "Error in CompareData" << std::endl;
  238. std::cout << "=======================================================" << std::endl << std::endl;
  239. return RET_ERROR;
  240. }
  241. }
  242. int Benchmark::MarkPerformance() {
  243. MS_LOG(INFO) << "Running warm up loops...";
  244. std::cout << "Running warm up loops..." << std::endl;
  245. for (int i = 0; i < flags_->warm_up_loop_count_; i++) {
  246. auto status = session_->RunGraph();
  247. if (status != 0) {
  248. MS_LOG(ERROR) << "Inference error " << status;
  249. std::cerr << "Inference error " << status << std::endl;
  250. return status;
  251. }
  252. }
  253. MS_LOG(INFO) << "Running benchmark loops...";
  254. std::cout << "Running benchmark loops..." << std::endl;
  255. uint64_t time_min = 1000000;
  256. uint64_t time_max = 0;
  257. uint64_t time_avg = 0;
  258. for (int i = 0; i < flags_->loop_count_; i++) {
  259. session_->BindThread(true);
  260. auto start = GetTimeUs();
  261. auto status =
  262. flags_->time_profiling_ ? session_->RunGraph(before_call_back_, after_call_back_) : session_->RunGraph();
  263. if (status != 0) {
  264. MS_LOG(ERROR) << "Inference error " << status;
  265. std::cerr << "Inference error " << status;
  266. return status;
  267. }
  268. auto end = GetTimeUs();
  269. auto time = end - start;
  270. time_min = std::min(time_min, time);
  271. time_max = std::max(time_max, time);
  272. time_avg += time;
  273. session_->BindThread(false);
  274. }
  275. if (flags_->time_profiling_) {
  276. const std::vector<std::string> per_op_name = {"opName", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
  277. const std::vector<std::string> per_op_type = {"opType", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
  278. PrintResult(per_op_name, op_times_by_name_);
  279. PrintResult(per_op_type, op_times_by_type_);
  280. }
  281. if (flags_->loop_count_ > 0) {
  282. time_avg /= flags_->loop_count_;
  283. MS_LOG(INFO) << "Model = " << flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
  284. << ", NumThreads = " << flags_->num_threads_ << ", MinRunTime = " << time_min / 1000.0f
  285. << ", MaxRuntime = " << time_max / 1000.0f << ", AvgRunTime = " << time_avg / 1000.0f;
  286. printf("Model = %s, NumThreads = %d, MinRunTime = %f ms, MaxRuntime = %f ms, AvgRunTime = %f ms\n",
  287. flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str(), flags_->num_threads_,
  288. time_min / 1000.0f, time_max / 1000.0f, time_avg / 1000.0f);
  289. }
  290. return RET_OK;
  291. }
  292. int Benchmark::MarkAccuracy() {
  293. MS_LOG(INFO) << "MarkAccuracy";
  294. std::cout << "MarkAccuracy" << std::endl;
  295. for (auto &msInput : ms_inputs_) {
  296. switch (msInput->data_type()) {
  297. case TypeId::kNumberTypeFloat:
  298. PrintInputData<float>(msInput);
  299. break;
  300. case TypeId::kNumberTypeFloat32:
  301. PrintInputData<float>(msInput);
  302. break;
  303. case TypeId::kNumberTypeInt8:
  304. PrintInputData<int8_t>(msInput);
  305. break;
  306. case TypeId::kNumberTypeUInt8:
  307. PrintInputData<uint8_t>(msInput);
  308. break;
  309. case TypeId::kNumberTypeInt32:
  310. PrintInputData<int>(msInput);
  311. break;
  312. default:
  313. MS_LOG(ERROR) << "Datatype " << msInput->data_type() << " is not supported.";
  314. return RET_ERROR;
  315. }
  316. }
  317. auto status = session_->RunGraph();
  318. if (status != RET_OK) {
  319. MS_LOG(ERROR) << "Inference error " << status;
  320. std::cerr << "Inference error " << status << std::endl;
  321. return status;
  322. }
  323. status = ReadCalibData();
  324. if (status != RET_OK) {
  325. MS_LOG(ERROR) << "Read calib data error " << status;
  326. std::cerr << "Read calib data error " << status << std::endl;
  327. return status;
  328. }
  329. status = CompareOutput();
  330. if (status != RET_OK) {
  331. MS_LOG(ERROR) << "Compare output error " << status;
  332. std::cerr << "Compare output error " << status << std::endl;
  333. return status;
  334. }
  335. return RET_OK;
  336. }
  337. int Benchmark::RunBenchmark() {
  338. auto start_prepare_time = GetTimeUs();
  339. // Load graph
  340. std::string model_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1);
  341. MS_LOG(INFO) << "start reading model file";
  342. std::cout << "start reading model file" << std::endl;
  343. size_t size = 0;
  344. char *graph_buf = ReadFile(flags_->model_file_.c_str(), &size);
  345. if (graph_buf == nullptr) {
  346. MS_LOG(ERROR) << "Read model file failed while running " << model_name.c_str();
  347. std::cerr << "Read model file failed while running " << model_name.c_str() << std::endl;
  348. return RET_ERROR;
  349. }
  350. auto model = std::shared_ptr<Model>(lite::Model::Import(graph_buf, size));
  351. delete[](graph_buf);
  352. if (model == nullptr) {
  353. MS_LOG(ERROR) << "Import model file failed while running " << model_name.c_str();
  354. std::cerr << "Import model file failed while running " << model_name.c_str() << std::endl;
  355. return RET_ERROR;
  356. }
  357. auto context = std::make_shared<Context>();
  358. if (context == nullptr) {
  359. MS_LOG(ERROR) << "New context failed while running " << model_name.c_str();
  360. std::cerr << "New context failed while running " << model_name.c_str() << std::endl;
  361. return RET_ERROR;
  362. }
  363. auto &cpu_device_ctx = context->device_list_[0];
  364. if (flags_->cpu_bind_mode_ == MID_CPU) {
  365. cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU;
  366. } else if (flags_->cpu_bind_mode_ == HIGHER_CPU) {
  367. cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
  368. } else {
  369. cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
  370. }
  371. cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
  372. if (flags_->device_ == "GPU") {
  373. DeviceContext gpu_device_ctx{DT_GPU, {false}};
  374. gpu_device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_;
  375. context->device_list_.push_back(gpu_device_ctx);
  376. }
  377. context->thread_num_ = flags_->num_threads_;
  378. session_ = session::LiteSession::CreateSession(context.get());
  379. if (session_ == nullptr) {
  380. MS_LOG(ERROR) << "CreateSession failed while running ", model_name.c_str();
  381. std::cout << "CreateSession failed while running ", model_name.c_str();
  382. return RET_ERROR;
  383. }
  384. auto ret = session_->CompileGraph(model.get());
  385. if (ret != RET_OK) {
  386. MS_LOG(ERROR) << "CompileGraph failed while running ", model_name.c_str();
  387. std::cout << "CompileGraph failed while running ", model_name.c_str();
  388. return ret;
  389. }
  390. if (!flags_->input_shape_list_.empty()) {
  391. std::vector<std::vector<int>> input_shapes;
  392. std::string input_dims_list = flags_->input_shape_list_;
  393. while (!input_dims_list.empty()) {
  394. auto position =
  395. input_dims_list.find(";") != input_dims_list.npos ? input_dims_list.find(";") + 1 : input_dims_list.length();
  396. std::string input_dims = input_dims_list.substr(0, position);
  397. std::vector<int> input_shape;
  398. while (!input_dims.empty()) {
  399. auto pos = input_dims.find(",") != input_dims.npos ? input_dims.find(",") + 1 : input_dims.length();
  400. std::string dim = input_dims.substr(0, pos);
  401. input_shape.emplace_back(std::stoi(dim));
  402. input_dims = input_dims.substr(pos);
  403. }
  404. input_shapes.emplace_back(input_shape);
  405. input_dims_list = input_dims_list.substr(position);
  406. }
  407. ret = session_->Resize(session_->GetInputs(), input_shapes);
  408. if (ret != RET_OK) {
  409. MS_LOG(ERROR) << "Input tensor resize failed.";
  410. std::cout << "Input tensor resize failed.";
  411. return ret;
  412. }
  413. }
  414. model->Free();
  415. ms_inputs_ = session_->GetInputs();
  416. auto end_prepare_time = GetTimeUs();
  417. MS_LOG(INFO) << "PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms";
  418. std::cout << "PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms" << std::endl;
  419. // Load input
  420. MS_LOG(INFO) << "start generate input data";
  421. auto status = LoadInput();
  422. if (status != 0) {
  423. MS_LOG(ERROR) << "Generate input data error";
  424. return status;
  425. }
  426. if (!flags_->benchmark_data_file_.empty()) {
  427. status = MarkAccuracy();
  428. for (auto &data : benchmark_data_) {
  429. data.second->shape.clear();
  430. data.second->data.clear();
  431. delete data.second;
  432. }
  433. benchmark_data_.clear();
  434. if (status != 0) {
  435. MS_LOG(ERROR) << "Run MarkAccuracy error: " << status;
  436. std::cout << "Run MarkAccuracy error: " << status << std::endl;
  437. return status;
  438. }
  439. } else {
  440. status = MarkPerformance();
  441. if (status != 0) {
  442. MS_LOG(ERROR) << "Run MarkPerformance error: " << status;
  443. std::cout << "Run MarkPerformance error: " << status << std::endl;
  444. return status;
  445. }
  446. }
  447. return RET_OK;
  448. }
  449. void BenchmarkFlags::InitInputDataList() {
  450. char *input_list = new char[this->in_data_file_.length() + 1];
  451. snprintf(input_list, this->in_data_file_.length() + 1, "%s", this->in_data_file_.c_str());
  452. char *cur_input;
  453. const char *split_c = ",";
  454. cur_input = strtok(input_list, split_c);
  455. while (cur_input != nullptr) {
  456. input_data_list_.emplace_back(cur_input);
  457. cur_input = strtok(nullptr, split_c);
  458. }
  459. delete[] input_list;
  460. }
  461. void BenchmarkFlags::InitResizeDimsList() {
  462. std::string content;
  463. content = this->resize_dims_in_;
  464. std::vector<int64_t> shape;
  465. auto shape_strs = StringSplit(content, std::string(DELIM_COLON));
  466. for (const auto &shape_str : shape_strs) {
  467. shape.clear();
  468. auto dim_strs = StringSplit(shape_str, std::string(DELIM_COMMA));
  469. std::cout << "Resize Dims: ";
  470. for (const auto &dim_str : dim_strs) {
  471. std::cout << dim_str << " ";
  472. shape.emplace_back(static_cast<int64_t>(std::stoi(dim_str)));
  473. }
  474. std::cout << std::endl;
  475. this->resize_dims_.emplace_back(shape);
  476. }
  477. }
  478. int Benchmark::InitCallbackParameter() {
  479. // before callback
  480. before_call_back_ = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
  481. const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
  482. const CallBackParam &callParam) {
  483. if (before_inputs.empty()) {
  484. MS_LOG(INFO) << "The num of beforeInputs is empty";
  485. }
  486. if (before_outputs.empty()) {
  487. MS_LOG(INFO) << "The num of beforeOutputs is empty";
  488. }
  489. if (op_times_by_type_.find(callParam.node_type) == op_times_by_type_.end()) {
  490. op_times_by_type_.insert(std::make_pair(callParam.node_type, std::make_pair(0, 0.0f)));
  491. }
  492. if (op_times_by_name_.find(callParam.node_name) == op_times_by_name_.end()) {
  493. op_times_by_name_.insert(std::make_pair(callParam.node_name, std::make_pair(0, 0.0f)));
  494. }
  495. op_call_times_total_++;
  496. op_begin_ = GetTimeUs();
  497. return true;
  498. };
  499. // after callback
  500. after_call_back_ = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
  501. const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
  502. const CallBackParam &call_param) {
  503. uint64_t opEnd = GetTimeUs();
  504. if (after_inputs.empty()) {
  505. MS_LOG(INFO) << "The num of after inputs is empty";
  506. }
  507. if (after_outputs.empty()) {
  508. MS_LOG(INFO) << "The num of after outputs is empty";
  509. }
  510. float cost = static_cast<float>(opEnd - op_begin_) / 1000.0f;
  511. op_cost_total_ += cost;
  512. op_times_by_type_[call_param.node_type].first++;
  513. op_times_by_type_[call_param.node_type].second += cost;
  514. op_times_by_name_[call_param.node_name].first++;
  515. op_times_by_name_[call_param.node_name].second += cost;
  516. return true;
  517. };
  518. return RET_OK;
  519. }
// Validates the parsed command-line flags and derives runtime settings:
// input data type, calibration data type, and (for --timeProfiling) the
// per-op callbacks. Logs the effective configuration.
// Returns RET_OK when the configuration is usable; 1 when flags_ is null or
// the model path is missing, RET_ERROR on any other invalid setting.
int Benchmark::Init() {
  if (this->flags_ == nullptr) {
    return 1;
  }
  // Echo the effective configuration for debugging.
  MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_;
  MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_;
  MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_;
  MS_LOG(INFO) << "LoopCount = " << this->flags_->loop_count_;
  MS_LOG(INFO) << "DeviceType = " << this->flags_->device_;
  MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_;
  MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_;
  MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_;
  MS_LOG(INFO) << "Fp16Priority = " << this->flags_->enable_fp16_;
  MS_LOG(INFO) << "calibDataPath = " << this->flags_->benchmark_data_file_;
  if (this->flags_->loop_count_ < 1) {
    MS_LOG(ERROR) << "LoopCount:" << this->flags_->loop_count_ << " must be greater than 0";
    std::cerr << "LoopCount:" << this->flags_->loop_count_ << " must be greater than 0" << std::endl;
    return RET_ERROR;
  }
  if (this->flags_->num_threads_ < 1) {
    MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0";
    std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl;
    return RET_ERROR;
  }
  // Bind-mode flag values: 2 = MID_CPU, 1 = HIGHER_CPU, anything else = NO_BIND
  // (mirrors the mapping applied in RunBenchmark()).
  if (this->flags_->cpu_bind_mode_ == 2) {
    MS_LOG(INFO) << "cpuBindMode = MID_CPU";
    std::cout << "cpuBindMode = MID_CPU" << std::endl;
  } else if (this->flags_->cpu_bind_mode_ == 1) {
    MS_LOG(INFO) << "cpuBindMode = HIGHER_CPU";
    std::cout << "cpuBindMode = HIGHER_CPU" << std::endl;
  } else {
    MS_LOG(INFO) << "cpuBindMode = NO_BIND";
    std::cout << "cpuBindMode = NO_BIND" << std::endl;
  }
  // Any value other than "img" is treated as binary input.
  this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;
  if (!flags_->benchmark_data_type_.empty()) {
    if (data_type_map_.find(flags_->benchmark_data_type_) == data_type_map_.end()) {
      MS_LOG(ERROR) << "CalibDataType not supported: " << flags_->benchmark_data_type_.c_str();
      return RET_ERROR;
    }
    msCalibDataType = data_type_map_.at(flags_->benchmark_data_type_);
    MS_LOG(INFO) << "CalibDataType = " << flags_->benchmark_data_type_.c_str();
    std::cout << "CalibDataType = " << flags_->benchmark_data_type_.c_str() << std::endl;
  }
  if (flags_->model_file_.empty()) {
    MS_LOG(ERROR) << "modelPath is required";
    std::cerr << "modelPath is required" << std::endl;
    return 1;
  }
  flags_->InitInputDataList();
  flags_->InitResizeDimsList();
  // When both are given, each input file needs exactly one resize shape.
  if (!flags_->resize_dims_.empty() && flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
    MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath";
    std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl;
    return RET_ERROR;
  }
  if (flags_->device_ != "CPU" && flags_->device_ != "GPU") {
    MS_LOG(ERROR) << "Device type:" << flags_->device_ << " is not supported.";
    std::cerr << "Device type:" << flags_->device_ << " is not supported." << std::endl;
    return RET_ERROR;
  }
  if (flags_->time_profiling_) {
    auto status = InitCallbackParameter();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Init callback Parameter failed.";
      std::cerr << "Init callback Parameter failed." << std::endl;
      return RET_ERROR;
    }
  }
  return RET_OK;
}
  591. int Benchmark::PrintResult(const std::vector<std::string> &title,
  592. const std::map<std::string, std::pair<int, float>> &result) {
  593. std::vector<size_t> columnLenMax(5);
  594. std::vector<std::vector<std::string>> rows;
  595. for (auto &iter : result) {
  596. char stringBuf[5][100] = {};
  597. std::vector<std::string> columns;
  598. size_t len;
  599. len = iter.first.size();
  600. if (len > columnLenMax.at(0)) {
  601. columnLenMax.at(0) = len + 4;
  602. }
  603. columns.push_back(iter.first);
  604. len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / flags_->loop_count_);
  605. if (len > columnLenMax.at(1)) {
  606. columnLenMax.at(1) = len + 4;
  607. }
  608. columns.emplace_back(stringBuf[1]);
  609. len = snprintf(stringBuf[2], sizeof(stringBuf[2]), "%f", iter.second.second / op_cost_total_);
  610. if (len > columnLenMax.at(2)) {
  611. columnLenMax.at(2) = len + 4;
  612. }
  613. columns.emplace_back(stringBuf[2]);
  614. len = snprintf(stringBuf[3], sizeof(stringBuf[3]), "%d", iter.second.first);
  615. if (len > columnLenMax.at(3)) {
  616. columnLenMax.at(3) = len + 4;
  617. }
  618. columns.emplace_back(stringBuf[3]);
  619. len = snprintf(stringBuf[4], sizeof(stringBuf[4]), "%f", iter.second.second);
  620. if (len > columnLenMax.at(4)) {
  621. columnLenMax.at(4) = len + 4;
  622. }
  623. columns.emplace_back(stringBuf[4]);
  624. rows.push_back(columns);
  625. }
  626. printf("-------------------------------------------------------------------------\n");
  627. for (int i = 0; i < 5; i++) {
  628. auto printBuf = title[i];
  629. if (printBuf.size() > columnLenMax.at(i)) {
  630. columnLenMax.at(i) = printBuf.size();
  631. }
  632. printBuf.resize(columnLenMax.at(i), ' ');
  633. printf("%s\t", printBuf.c_str());
  634. }
  635. printf("\n");
  636. for (size_t i = 0; i < rows.size(); i++) {
  637. for (int j = 0; j < 5; j++) {
  638. auto printBuf = rows[i][j];
  639. printBuf.resize(columnLenMax.at(j), ' ');
  640. printf("%s\t", printBuf.c_str());
  641. }
  642. printf("\n");
  643. }
  644. return RET_OK;
  645. }
  646. Benchmark::~Benchmark() {
  647. for (auto iter : this->benchmark_data_) {
  648. delete (iter.second);
  649. }
  650. this->benchmark_data_.clear();
  651. delete (session_);
  652. }
  653. int RunBenchmark(int argc, const char **argv) {
  654. BenchmarkFlags flags;
  655. Option<std::string> err = flags.ParseFlags(argc, argv);
  656. if (err.IsSome()) {
  657. std::cerr << err.Get() << std::endl;
  658. std::cerr << flags.Usage() << std::endl;
  659. return RET_ERROR;
  660. }
  661. if (flags.help) {
  662. std::cerr << flags.Usage() << std::endl;
  663. return RET_OK;
  664. }
  665. Benchmark benchmark(&flags);
  666. auto status = benchmark.Init();
  667. if (status != 0) {
  668. MS_LOG(ERROR) << "Benchmark init Error : " << status;
  669. std::cerr << "Benchmark init Error : " << status << std::endl;
  670. return RET_ERROR;
  671. }
  672. status = benchmark.RunBenchmark();
  673. if (status != 0) {
  674. MS_LOG(ERROR) << "Run Benchmark "
  675. << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
  676. << " Failed : " << status;
  677. std::cerr << "Run Benchmark " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
  678. << " Failed : " << status << std::endl;
  679. return RET_ERROR;
  680. }
  681. MS_LOG(INFO) << "Run Benchmark " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
  682. << " Success.";
  683. std::cout << "Run Benchmark " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
  684. << " Success." << std::endl;
  685. return RET_OK;
  686. }
  687. } // namespace lite
  688. } // namespace mindspore