|
|
|
@@ -66,6 +66,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { |
|
|
|
AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); |
|
|
|
AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); |
|
|
|
AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); |
|
|
|
AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false); |
|
|
|
} |
|
|
|
|
|
|
|
~NetTrainFlags() override = default; |
|
|
|
@@ -82,6 +83,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { |
|
|
|
DataType in_data_type_; |
|
|
|
std::string in_data_type_in_ = "bin"; |
|
|
|
int cpu_bind_mode_ = 1; |
|
|
|
bool enable_fp16_ = false; |
|
|
|
// MarkPerformance |
|
|
|
int num_threads_ = 1; |
|
|
|
int warm_up_loop_count_ = 0; |
|
|
|
|