Browse Source

!13302 Fixed a bug in benchmark_train

From: @louisncu
Reviewed-by: @zhang_xue_tong,@HilbertDavid
Signed-off-by: @zhang_xue_tong
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
d831aba239
3 changed files with 6 additions and 2 deletions
  1. +2
    -1
      mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt
  2. +2
    -1
      mindspore/lite/tools/benchmark_train/net_train.cc
  3. +2
    -0
      mindspore/lite/tools/benchmark_train/net_train.h

+ 2
- 1
mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt View File

@@ -21,7 +21,8 @@ if(PLATFORM_ARM64)
if(ENABLE_FP16)
file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc)
if(SUPPORT_TRAIN)
file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
file(GLOB FP16_KERNEL_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
set(FP16_KERNEL_SRC ${FP16_KERNEL_SRC} ${FP16_KERNEL_TRAIN_SRC})
endif()
add_library(cpu_fp16_kernel_mid OBJECT ${FP16_KERNEL_SRC})
add_dependencies(cpu_fp16_kernel_mid fbs_src)


+ 2
- 1
mindspore/lite/tools/benchmark_train/net_train.cc View File

@@ -385,7 +385,7 @@ int NetTrain::RunNetTrain() {
} else {
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
}
context->device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
layer_checksum_ = flags_->layer_checksum_;
context->thread_num_ = flags_->num_threads_;
session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get());
@@ -545,6 +545,7 @@ int NetTrain::Init() {
MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_;
MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_;
MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_;
MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_;

if (this->flags_->epochs_ < 0) {
MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0";


+ 2
- 0
mindspore/lite/tools/benchmark_train/net_train.h View File

@@ -66,6 +66,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", "");
AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false);
AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false);
}

~NetTrainFlags() override = default;
@@ -82,6 +83,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
DataType in_data_type_;
std::string in_data_type_in_ = "bin";
int cpu_bind_mode_ = 1;
bool enable_fp16_ = false;
// MarkPerformance
int num_threads_ = 1;
int warm_up_loop_count_ = 0;


Loading…
Cancel
Save