From 88486f149140f4b9644ea695b7a4e376c59f1fa6 Mon Sep 17 00:00:00 2001 From: liujiahan Date: Mon, 15 Mar 2021 10:26:03 +0800 Subject: [PATCH] fixed a bug in benchmark_train fixed clang-format error --- mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt | 3 ++- mindspore/lite/tools/benchmark_train/net_train.cc | 3 ++- mindspore/lite/tools/benchmark_train/net_train.h | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt index 324ce653ab..6e6c4f4043 100644 --- a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt +++ b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt @@ -21,7 +21,8 @@ if(PLATFORM_ARM64) if(ENABLE_FP16) file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc) if(SUPPORT_TRAIN) - file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc) + file(GLOB FP16_KERNEL_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc) + set(FP16_KERNEL_SRC ${FP16_KERNEL_SRC} ${FP16_KERNEL_TRAIN_SRC}) endif() add_library(cpu_fp16_kernel_mid OBJECT ${FP16_KERNEL_SRC}) add_dependencies(cpu_fp16_kernel_mid fbs_src) diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc index b91e3473b0..7ad4c8ec3e 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.cc +++ b/mindspore/lite/tools/benchmark_train/net_train.cc @@ -385,7 +385,7 @@ int NetTrain::RunNetTrain() { } else { context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; } - + context->device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; layer_checksum_ = flags_->layer_checksum_; context->thread_num_ = flags_->num_threads_; session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get()); @@ -545,6 +545,7 @@ int NetTrain::Init() { MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_; MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_; MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_; + MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_; if (this->flags_->epochs_ < 0) { MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0"; diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h index 7abd01b96c..23a3c7e628 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.h +++ b/mindspore/lite/tools/benchmark_train/net_train.h @@ -66,6 +66,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); + AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false); } ~NetTrainFlags() override = default; @@ -82,6 +83,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { DataType in_data_type_; std::string in_data_type_in_ = "bin"; int cpu_bind_mode_ = 1; + bool enable_fp16_ = false; // MarkPerformance int num_threads_ = 1; int warm_up_loop_count_ = 0;