From: @louisncu Reviewed-by: @zhang_xue_tong,@HilbertDavid Signed-off-by: @zhang_xue_tongtags/v1.2.0-rc1
| @@ -21,7 +21,8 @@ if(PLATFORM_ARM64) | |||||
| if(ENABLE_FP16) | if(ENABLE_FP16) | ||||
| file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc) | file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc) | ||||
| if(SUPPORT_TRAIN) | if(SUPPORT_TRAIN) | ||||
| file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc) | |||||
| file(GLOB FP16_KERNEL_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc) | |||||
| set(FP16_KERNEL_SRC ${FP16_KERNEL_SRC} ${FP16_KERNEL_TRAIN_SRC}) | |||||
| endif() | endif() | ||||
| add_library(cpu_fp16_kernel_mid OBJECT ${FP16_KERNEL_SRC}) | add_library(cpu_fp16_kernel_mid OBJECT ${FP16_KERNEL_SRC}) | ||||
| add_dependencies(cpu_fp16_kernel_mid fbs_src) | add_dependencies(cpu_fp16_kernel_mid fbs_src) | ||||
| @@ -385,7 +385,7 @@ int NetTrain::RunNetTrain() { | |||||
| } else { | } else { | ||||
| context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; | context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; | ||||
| } | } | ||||
| context->device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; | |||||
| layer_checksum_ = flags_->layer_checksum_; | layer_checksum_ = flags_->layer_checksum_; | ||||
| context->thread_num_ = flags_->num_threads_; | context->thread_num_ = flags_->num_threads_; | ||||
| session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get()); | session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get()); | ||||
| @@ -545,6 +545,7 @@ int NetTrain::Init() { | |||||
| MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_; | MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_; | ||||
| MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_; | MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_; | ||||
| MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_; | MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_; | ||||
| MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_; | |||||
| if (this->flags_->epochs_ < 0) { | if (this->flags_->epochs_ < 0) { | ||||
| MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0"; | MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0"; | ||||
| @@ -66,6 +66,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { | |||||
| AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); | AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); | ||||
| AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); | AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); | ||||
| AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); | AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); | ||||
| AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false); | |||||
| } | } | ||||
| ~NetTrainFlags() override = default; | ~NetTrainFlags() override = default; | ||||
| @@ -82,6 +83,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { | |||||
| DataType in_data_type_; | DataType in_data_type_; | ||||
| std::string in_data_type_in_ = "bin"; | std::string in_data_type_in_ = "bin"; | ||||
| int cpu_bind_mode_ = 1; | int cpu_bind_mode_ = 1; | ||||
| bool enable_fp16_ = false; | |||||
| // MarkPerformance | // MarkPerformance | ||||
| int num_threads_ = 1; | int num_threads_ = 1; | ||||
| int warm_up_loop_count_ = 0; | int warm_up_loop_count_ = 0; | ||||