You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

net_train.cc 24 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/benchmark_train/net_train.h"
  17. #define __STDC_FORMAT_MACROS
  18. #include <cinttypes>
  19. #undef __STDC_FORMAT_MACROS
  20. #include <algorithm>
  21. #include <utility>
  22. #ifdef ENABLE_NEON
  23. #include <arm_neon.h>
  24. #endif
  25. #include "src/common/common.h"
  26. #include "include/ms_tensor.h"
  27. #include "include/context.h"
  28. #include "src/runtime/runtime_api.h"
  29. #include "include/version.h"
  30. #include "include/model.h"
  31. namespace mindspore {
  32. namespace lite {
  33. static const char *DELIM_SLASH = "/";
  34. namespace {
  35. float *ReadFileBuf(const char *file, size_t *size) {
  36. if (file == nullptr) {
  37. MS_LOG(ERROR) << "file is nullptr";
  38. return nullptr;
  39. }
  40. MS_ASSERT(size != nullptr);
  41. std::string real_path = RealPath(file);
  42. std::ifstream ifs(real_path);
  43. if (!ifs.good()) {
  44. MS_LOG(ERROR) << "file: " << real_path << " is not exist";
  45. return nullptr;
  46. }
  47. if (!ifs.is_open()) {
  48. MS_LOG(ERROR) << "file: " << real_path << " open failed";
  49. return nullptr;
  50. }
  51. ifs.seekg(0, std::ios::end);
  52. *size = ifs.tellg();
  53. std::unique_ptr<float[]> buf((new (std::nothrow) float[*size / sizeof(float) + 1]));
  54. if (buf == nullptr) {
  55. MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
  56. ifs.close();
  57. return nullptr;
  58. }
  59. ifs.seekg(0, std::ios::beg);
  60. ifs.read(reinterpret_cast<char *>(buf.get()), *size);
  61. ifs.close();
  62. return buf.release();
  63. }
  64. } // namespace
  65. int NetTrain::GenerateRandomData(size_t size, void *data) {
  66. MS_ASSERT(data != nullptr);
  67. char *casted_data = static_cast<char *>(data);
  68. for (size_t i = 0; i < size; i++) {
  69. casted_data[i] = static_cast<char>(i);
  70. }
  71. return RET_OK;
  72. }
  73. int NetTrain::GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
  74. for (auto tensor : *ms_inputs) {
  75. MS_ASSERT(tensor != nullptr);
  76. auto input_data = tensor->MutableData();
  77. if (input_data == nullptr) {
  78. MS_LOG(ERROR) << "MallocData for inTensor failed";
  79. return RET_ERROR;
  80. }
  81. auto tensor_byte_size = tensor->Size();
  82. auto status = GenerateRandomData(tensor_byte_size, input_data);
  83. if (status != RET_OK) {
  84. std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
  85. MS_LOG(ERROR) << "GenerateRandomData for inTensor failed:" << status;
  86. return status;
  87. }
  88. }
  89. return RET_OK;
  90. }
  91. int NetTrain::LoadInput(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
  92. if (flags_->in_data_file_.empty()) {
  93. auto status = GenerateInputData(ms_inputs);
  94. if (status != RET_OK) {
  95. std::cerr << "Generate input data error " << status << std::endl;
  96. MS_LOG(ERROR) << "Generate input data error " << status;
  97. return status;
  98. }
  99. } else {
  100. auto status = ReadInputFile(ms_inputs);
  101. if (status != RET_OK) {
  102. std::cerr << "ReadInputFile error, " << status << std::endl;
  103. MS_LOG(ERROR) << "ReadInputFile error, " << status;
  104. return status;
  105. }
  106. }
  107. return RET_OK;
  108. }
  109. int NetTrain::ReadInputFile(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
  110. if (ms_inputs->empty()) {
  111. return RET_OK;
  112. }
  113. if (this->flags_->in_data_type_ == kImage) {
  114. MS_LOG(ERROR) << "Not supported image input";
  115. return RET_ERROR;
  116. } else {
  117. for (size_t i = 0; i < ms_inputs->size(); i++) {
  118. auto cur_tensor = ms_inputs->at(i);
  119. MS_ASSERT(cur_tensor != nullptr);
  120. size_t size;
  121. std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
  122. char *bin_buf = ReadFile(file_name.c_str(), &size);
  123. if (bin_buf == nullptr) {
  124. MS_LOG(ERROR) << "ReadFile return nullptr";
  125. return RET_ERROR;
  126. }
  127. auto tensor_data_size = cur_tensor->Size();
  128. if (size != tensor_data_size) {
  129. std::cerr << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
  130. << std::endl;
  131. MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size;
  132. delete bin_buf;
  133. return RET_ERROR;
  134. }
  135. auto input_data = cur_tensor->MutableData();
  136. memcpy(input_data, bin_buf, tensor_data_size);
  137. delete[](bin_buf);
  138. }
  139. }
  140. return RET_OK;
  141. }
  142. int NetTrain::CompareOutput(const session::LiteSession &lite_session) {
  143. std::cout << "================ Comparing Forward Output data ================" << std::endl;
  144. float total_bias = 0;
  145. int total_size = 0;
  146. bool has_error = false;
  147. auto tensors_list = lite_session.GetOutputs();
  148. if (tensors_list.empty()) {
  149. MS_LOG(ERROR) << "Cannot find output tensors, get model output failed";
  150. return RET_ERROR;
  151. }
  152. mindspore::tensor::MSTensor *tensor = nullptr;
  153. int i = 1;
  154. for (auto it = tensors_list.begin(); it != tensors_list.end(); ++it) {
  155. tensor = lite_session.GetOutputByTensorName(it->first);
  156. std::cout << "output is tensor " << it->first << "\n";
  157. auto outputs = tensor->data();
  158. size_t size;
  159. std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
  160. auto *bin_buf = ReadFileBuf(output_file.c_str(), &size);
  161. if (bin_buf == nullptr) {
  162. MS_LOG(ERROR) << "ReadFile return nullptr";
  163. return RET_ERROR;
  164. }
  165. if (size != tensor->Size()) {
  166. MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor->Size()
  167. << ", read size: " << size;
  168. return RET_ERROR;
  169. }
  170. float bias = CompareData<float>(bin_buf, tensor->ElementsNum(), reinterpret_cast<float *>(outputs));
  171. if (bias >= 0) {
  172. total_bias += bias;
  173. total_size++;
  174. } else {
  175. has_error = true;
  176. break;
  177. }
  178. i++;
  179. delete[] bin_buf;
  180. }
  181. if (!has_error) {
  182. float mean_bias;
  183. if (total_size != 0) {
  184. mean_bias = total_bias / total_size * 100;
  185. } else {
  186. mean_bias = 0;
  187. }
  188. std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%"
  189. << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl;
  190. std::cout << "=======================================================" << std::endl << std::endl;
  191. if (mean_bias > this->flags_->accuracy_threshold_) {
  192. MS_LOG(ERROR) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%";
  193. std::cerr << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl;
  194. return RET_ERROR;
  195. } else {
  196. return RET_OK;
  197. }
  198. } else {
  199. MS_LOG(ERROR) << "Error in CompareData";
  200. std::cerr << "Error in CompareData" << std::endl;
  201. std::cout << "=======================================================" << std::endl << std::endl;
  202. return RET_ERROR;
  203. }
  204. }
  205. int NetTrain::MarkPerformance(session::TrainSession *session) {
  206. MS_LOG(INFO) << "Running train loops...";
  207. std::cout << "Running train loops..." << std::endl;
  208. uint64_t time_min = 0xFFFFFFFFFFFFFFFF;
  209. uint64_t time_max = 0;
  210. uint64_t time_avg = 0;
  211. for (int i = 0; i < flags_->epochs_; i++) {
  212. session->BindThread(true);
  213. auto start = GetTimeUs();
  214. auto status =
  215. flags_->time_profiling_ ? session->RunGraph(before_call_back_, after_call_back_) : session->RunGraph();
  216. if (status != 0) {
  217. MS_LOG(ERROR) << "Inference error " << status;
  218. std::cerr << "Inference error " << status;
  219. return status;
  220. }
  221. auto end = GetTimeUs();
  222. auto time = end - start;
  223. time_min = std::min(time_min, time);
  224. time_max = std::max(time_max, time);
  225. time_avg += time;
  226. session->BindThread(false);
  227. }
  228. if (flags_->time_profiling_) {
  229. const std::vector<std::string> per_op_name = {"opName", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
  230. const std::vector<std::string> per_op_type = {"opType", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
  231. PrintResult(per_op_name, op_times_by_name_);
  232. PrintResult(per_op_type, op_times_by_type_);
  233. }
  234. if (flags_->epochs_ > 0) {
  235. time_avg /= flags_->epochs_;
  236. MS_LOG(INFO) << "Model = " << flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
  237. << ", NumThreads = " << flags_->num_threads_ << ", MinRunTime = " << time_min / 1000.0f
  238. << ", MaxRuntime = " << time_max / 1000.0f << ", AvgRunTime = " << time_avg / 1000.0f;
  239. printf("Model = %s, NumThreads = %d, MinRunTime = %f ms, MaxRuntime = %f ms, AvgRunTime = %f ms\n",
  240. flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str(), flags_->num_threads_,
  241. time_min / 1000.0f, time_max / 1000.0f, time_avg / 1000.0f);
  242. }
  243. return RET_OK;
  244. }
  245. int NetTrain::MarkAccuracy(session::LiteSession *session) {
  246. MS_LOG(INFO) << "MarkAccuracy";
  247. for (auto &msInput : session->GetInputs()) {
  248. switch (msInput->data_type()) {
  249. case TypeId::kNumberTypeFloat:
  250. PrintInputData<float>(msInput);
  251. break;
  252. case TypeId::kNumberTypeFloat32:
  253. PrintInputData<float>(msInput);
  254. break;
  255. case TypeId::kNumberTypeInt32:
  256. PrintInputData<int>(msInput);
  257. break;
  258. default:
  259. MS_LOG(ERROR) << "Datatype " << msInput->data_type() << " is not supported.";
  260. return RET_ERROR;
  261. }
  262. }
  263. auto status = session->RunGraph();
  264. if (status != RET_OK) {
  265. MS_LOG(ERROR) << "Inference error " << status;
  266. std::cerr << "Inference error " << status << std::endl;
  267. return status;
  268. }
  269. status = CompareOutput(*session);
  270. if (status != RET_OK) {
  271. MS_LOG(ERROR) << "Compare output error " << status;
  272. std::cerr << "Compare output error " << status << std::endl;
  273. return status;
  274. }
  275. return RET_OK;
  276. }
  277. static CpuBindMode FlagToBindMode(int flag) {
  278. if (flag == 2) {
  279. return MID_CPU;
  280. }
  281. if (flag == 1) {
  282. return HIGHER_CPU;
  283. }
  284. return NO_BIND;
  285. }
// Imports the model in |filename| and runs it, either as a train session
// (train_session != 0; trained when |epochs| > 0) or as a plain inference
// session; then runs accuracy checking when expected-output files were
// supplied.
// NOTE(review): |session| is never deleted on any path, and |model| is not
// freed after a successful train run — presumably the TrainSession takes
// ownership of both; confirm against the session API before changing.
int NetTrain::CreateAndRunNetwork(const std::string &filename, int train_session, int epochs) {
  auto start_prepare_time = GetTimeUs();
  std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
  // Build a single-CPU-device context from the command-line flags.
  Context context;
  context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = FlagToBindMode(flags_->cpu_bind_mode_);
  context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
  context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
  context.thread_num_ = flags_->num_threads_;
  MS_LOG(INFO) << "start reading model file" << filename.c_str();
  std::cout << "start reading model file " << filename.c_str() << std::endl;
  auto *model = mindspore::lite::Model::Import(filename.c_str());
  if (model == nullptr) {
    MS_LOG(ERROR) << "create model for train session failed";
    return RET_ERROR;
  }
  session::LiteSession *session = nullptr;
  session::TrainSession *t_session = nullptr;
  if (train_session) {
    // Train path: TrainSession::CreateSession receives the model directly.
    t_session = session::TrainSession::CreateSession(model, &context);
    if (t_session == nullptr) {
      MS_LOG(ERROR) << "RunNetTrain CreateSession failed while running " << model_name.c_str();
      std::cout << "RunNetTrain CreateSession failed while running " << model_name.c_str() << std::endl;
      delete model;
      return RET_ERROR;
    }
    if (flags_->loss_name_ != "") {
      t_session->SetLossName(flags_->loss_name_);
    }
    if (epochs > 0) {
      // Switch to training mode only when we will actually train.
      t_session->Train();
    }
    session = t_session;
  } else {
    // Inference path: the model is only needed for CompileGraph and is
    // released right after.
    session = session::LiteSession::CreateSession(&context);
    if (session == nullptr) {
      MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
      std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl;
      delete model;
      return RET_ERROR;
    }
    if (session->CompileGraph(model) != RET_OK) {
      MS_LOG(ERROR) << "Cannot compile model";
      delete model;
      return RET_ERROR;
    }
    delete model;
  }
  auto end_prepare_time = GetTimeUs();
  MS_LOG(INFO) << "PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms";
  std::cout << "PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms" << std::endl;
  // Load input
  MS_LOG(INFO) << "Load input data";
  auto ms_inputs = session->GetInputs();
  auto status = LoadInput(&ms_inputs);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Load input data error";
    return status;
  }
  if ((epochs > 0) && (t_session != nullptr)) {
    // Training run: measure per-epoch performance, then export if asked.
    // |model| is only reachable here in the train path, so it is still live.
    status = MarkPerformance(t_session);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run MarkPerformance error: " << status;
      std::cout << "Run MarkPerformance error: " << status << std::endl;
      return status;
    }
    SaveModels(t_session, model);  // save file if flags are on
  }
  if (!flags_->data_file_.empty()) {
    // Expected-output files were given: run a forward pass in eval mode
    // and compare the outputs.
    if (t_session != nullptr) {
      t_session->Eval();
    }
    status = MarkAccuracy(session);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run MarkAccuracy error: " << status;
      std::cout << "Run MarkAccuracy error: " << status << std::endl;
      return status;
    }
  }
  return RET_OK;
}
  366. int NetTrain::RunNetTrain() {
  367. CreateAndRunNetwork(flags_->model_file_, true, flags_->epochs_);
  368. auto status = CheckExecutionOfSavedModels(); // re-initialize sessions according to flags
  369. if (status != RET_OK) {
  370. MS_LOG(ERROR) << "Run CheckExecute error: " << status;
  371. std::cout << "Run CheckExecute error: " << status << std::endl;
  372. return status;
  373. }
  374. return RET_OK;
  375. }
  376. int NetTrain::SaveModels(session::TrainSession *session, mindspore::lite::Model *model) {
  377. if (!flags_->export_file_.empty()) {
  378. auto ret = Model::Export(model, flags_->export_file_.c_str());
  379. if (ret != RET_OK) {
  380. MS_LOG(ERROR) << "SaveToFile error";
  381. std::cout << "Run SaveToFile error";
  382. return RET_ERROR;
  383. }
  384. }
  385. if (!flags_->inference_file_.empty()) {
  386. auto tick = GetTimeUs();
  387. auto status = session->ExportInference(flags_->inference_file_);
  388. if (status != RET_OK) {
  389. MS_LOG(ERROR) << "Save model error: " << status;
  390. std::cout << "Save model error: " << status << std::endl;
  391. return status;
  392. }
  393. std::cout << "ExportInference() execution time is " << GetTimeUs() - tick << "us\n";
  394. }
  395. return RET_OK;
  396. }
  397. int NetTrain::CheckExecutionOfSavedModels() {
  398. int status = RET_OK;
  399. if (!flags_->export_file_.empty()) {
  400. status = NetTrain::CreateAndRunNetwork(flags_->export_file_, true, 0);
  401. if (status != RET_OK) {
  402. MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status;
  403. std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl;
  404. return status;
  405. }
  406. }
  407. if (!flags_->inference_file_.empty()) {
  408. status = NetTrain::CreateAndRunNetwork(flags_->inference_file_ + ".ms", false, 0);
  409. if (status != RET_OK) {
  410. MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status;
  411. std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl;
  412. return status;
  413. }
  414. }
  415. return status;
  416. }
// Installs the before/after kernel callbacks used for per-op time
// profiling and, optionally, per-layer checksum printing.
// Both lambdas capture `this` by reference, so this NetTrain object must
// outlive any session the callbacks are registered with.
int NetTrain::InitCallbackParameter() {
  // before callback
  before_call_back_ = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                          const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                          const mindspore::CallBackParam &callParam) {
    if (before_inputs.empty()) {
      MS_LOG(INFO) << "The num of beforeInputs is empty";
    }
    if (before_outputs.empty()) {
      MS_LOG(INFO) << "The num of beforeOutputs is empty";
    }
    // Ensure an accumulator entry (call count, total cost) exists for this
    // op type and op name before the after-callback updates it.
    if (op_times_by_type_.find(callParam.node_type) == op_times_by_type_.end()) {
      op_times_by_type_.insert(std::make_pair(callParam.node_type, std::make_pair(0, 0.0f)));
    }
    if (op_times_by_name_.find(callParam.node_name) == op_times_by_name_.end()) {
      op_times_by_name_.insert(std::make_pair(callParam.node_name, std::make_pair(0, 0.0f)));
    }
    op_call_times_total_++;
    // Start of this op's execution window; read back in the after-callback.
    op_begin_ = GetTimeUs();
    return true;
  };
  // after callback
  after_call_back_ = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
                         const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
                         const mindspore::CallBackParam &call_param) {
    uint64_t opEnd = GetTimeUs();
    if (after_inputs.empty()) {
      MS_LOG(INFO) << "The num of after inputs is empty";
    }
    if (after_outputs.empty()) {
      MS_LOG(INFO) << "The num of after outputs is empty";
    }
    // Elapsed microseconds converted to milliseconds for this op.
    float cost = static_cast<float>(opEnd - op_begin_) / 1000.0f;
    op_cost_total_ += cost;
    op_times_by_type_[call_param.node_type].first++;
    op_times_by_type_[call_param.node_type].second += cost;
    op_times_by_name_[call_param.node_name].first++;
    op_times_by_name_[call_param.node_name].second += cost;
    if (flags_->layer_checksum_) {
      // Print a quick checksum (element sum) of the op's first output so
      // that layer outputs can be diffed between runs.
      auto out_tensor = after_outputs.at(0);
      void *output = out_tensor->MutableData();
      int tensor_size = out_tensor->ElementsNum();
      TypeId type = out_tensor->data_type();
      std::cout << call_param.node_type << " shape=" << after_outputs.at(0)->shape() << " sum=";
      switch (type) {
        case kNumberTypeFloat32:
          std::cout << TensorSum<float>(output, tensor_size);
          break;
        case kNumberTypeInt32:
          std::cout << TensorSum<int>(output, tensor_size);
          break;
#ifdef ENABLE_FP16
        case kNumberTypeFloat16:
          std::cout << TensorSum<float16_t>(output, tensor_size);
          break;
#endif
        default:
          std::cout << "unsupported type:" << type;
          break;
      }
      std::cout << std::endl;
    }
    return true;
  };
  return RET_OK;
}
  483. int NetTrain::Init() {
  484. if (this->flags_ == nullptr) {
  485. return 1;
  486. }
  487. MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_;
  488. MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_;
  489. MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_;
  490. MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_;
  491. MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_;
  492. MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_;
  493. MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_;
  494. MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_;
  495. MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_;
  496. MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_;
  497. if (this->flags_->epochs_ < 0) {
  498. MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0";
  499. std::cerr << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0" << std::endl;
  500. return RET_ERROR;
  501. }
  502. if (this->flags_->num_threads_ < 1) {
  503. MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0";
  504. std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl;
  505. return RET_ERROR;
  506. }
  507. this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;
  508. if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) {
  509. MS_LOG(ERROR) << "expectedDataFile not supported in case that inDataFile is not provided";
  510. std::cerr << "expectedDataFile is not supported in case that inDataFile is not provided" << std::endl;
  511. return RET_ERROR;
  512. }
  513. if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) {
  514. MS_LOG(ERROR) << "exportDataFile not supported in case that inDataFile is not provided";
  515. std::cerr << "exportDataFile is not supported in case that inDataFile is not provided" << std::endl;
  516. return RET_ERROR;
  517. }
  518. if (flags_->model_file_.empty()) {
  519. MS_LOG(ERROR) << "modelPath is required";
  520. std::cerr << "modelPath is required" << std::endl;
  521. return 1;
  522. }
  523. if (flags_->time_profiling_) {
  524. auto status = InitCallbackParameter();
  525. if (status != RET_OK) {
  526. MS_LOG(ERROR) << "Init callback Parameter failed.";
  527. std::cerr << "Init callback Parameter failed." << std::endl;
  528. return RET_ERROR;
  529. }
  530. }
  531. return RET_OK;
  532. }
// Prints a five-column table (headers from |title|) of per-op profiling
// data: name, average ms per epoch, fraction of total op cost, call count
// and total ms. |result| maps op name/type -> (call count, total cost ms).
// NOTE(review): divides by flags_->epochs_; with epochs_ == 0 and
// profiling enabled the "avg" column would print inf — callers appear to
// only reach this after the epoch loop, confirm.
int NetTrain::PrintResult(const std::vector<std::string> &title,
                          const std::map<std::string, std::pair<int, float>> &result) {
  // Widest cell seen per column (plus padding), used to align the table.
  std::vector<size_t> columnLenMax(5);
  std::vector<std::vector<std::string>> rows;
  for (auto &iter : result) {
    char stringBuf[5][100] = {};
    std::vector<std::string> columns;
    size_t len;
    // Column 0: op name or type.
    len = iter.first.size();
    if (len > columnLenMax.at(0)) {
      columnLenMax.at(0) = len + 4;
    }
    columns.push_back(iter.first);
    // Column 1: average cost per epoch (ms).
    len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / flags_->epochs_);
    if (len > columnLenMax.at(1)) {
      columnLenMax.at(1) = len + 4;
    }
    columns.emplace_back(stringBuf[1]);
    // Column 2: share of the total op cost (0..1 fraction, not percent).
    len = snprintf(stringBuf[2], sizeof(stringBuf[2]), "%f", iter.second.second / op_cost_total_);
    if (len > columnLenMax.at(2)) {
      columnLenMax.at(2) = len + 4;
    }
    columns.emplace_back(stringBuf[2]);
    // Column 3: number of times the op was called.
    len = snprintf(stringBuf[3], sizeof(stringBuf[3]), "%d", iter.second.first);
    if (len > columnLenMax.at(3)) {
      columnLenMax.at(3) = len + 4;
    }
    columns.emplace_back(stringBuf[3]);
    // Column 4: total cost across all calls (ms).
    len = snprintf(stringBuf[4], sizeof(stringBuf[4]), "%f", iter.second.second);
    if (len > columnLenMax.at(4)) {
      columnLenMax.at(4) = len + 4;
    }
    columns.emplace_back(stringBuf[4]);
    rows.push_back(columns);
  }
  printf("-------------------------------------------------------------------------\n");
  // Header row: pad each title to its column width.
  for (int i = 0; i < 5; i++) {
    auto printBuf = title[i];
    if (printBuf.size() > columnLenMax.at(i)) {
      columnLenMax.at(i) = printBuf.size();
    }
    printBuf.resize(columnLenMax.at(i), ' ');
    printf("%s\t", printBuf.c_str());
  }
  printf("\n");
  // Data rows, padded to the same widths as the header.
  for (size_t i = 0; i < rows.size(); i++) {
    for (int j = 0; j < 5; j++) {
      auto printBuf = rows[i][j];
      printBuf.resize(columnLenMax.at(j), ' ');
      printf("%s\t", printBuf.c_str());
    }
    printf("\n");
  }
  return RET_OK;
}
  588. int RunNetTrain(int argc, const char **argv) {
  589. NetTrainFlags flags;
  590. Option<std::string> err = flags.ParseFlags(argc, argv);
  591. if (err.IsSome()) {
  592. std::cerr << err.Get() << std::endl;
  593. std::cerr << flags.Usage() << std::endl;
  594. return RET_ERROR;
  595. }
  596. if (flags.help) {
  597. std::cerr << flags.Usage() << std::endl;
  598. return RET_OK;
  599. }
  600. NetTrain net_trainer(&flags);
  601. auto status = net_trainer.Init();
  602. if (status != RET_OK) {
  603. MS_LOG(ERROR) << "NetTrain init Error : " << status;
  604. std::cerr << "NetTrain init Error : " << status << std::endl;
  605. return RET_ERROR;
  606. }
  607. status = net_trainer.RunNetTrain();
  608. if (status != RET_OK) {
  609. MS_LOG(ERROR) << "Run NetTrain "
  610. << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
  611. << " Failed : " << status;
  612. std::cerr << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
  613. << " Failed : " << status << std::endl;
  614. return RET_ERROR;
  615. }
  616. MS_LOG(INFO) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
  617. << " Success.";
  618. std::cout << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
  619. << " Success." << std::endl;
  620. return RET_OK;
  621. }
  622. } // namespace lite
  623. } // namespace mindspore