|
|
|
@@ -99,11 +99,8 @@ void NetRunner::InitAndFigureInputs() { |
|
|
|
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; |
|
|
|
context.thread_num_ = 2; |
|
|
|
|
|
|
|
model_ = mindspore::lite::Model::Import(ms_file_); |
|
|
|
if (model_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "import model failed"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
model_ = mindspore::lite::Model::Import(ms_file_.c_str()); |
|
|
|
MS_ASSERT(nullptr != model_); |
|
|
|
session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true); |
|
|
|
|
|
|
|
MS_ASSERT(nullptr != session_); |
|
|
|
@@ -169,7 +166,7 @@ int NetRunner::TrainLoop() { |
|
|
|
|
|
|
|
mindspore::lite::LossMonitor lm(100); |
|
|
|
mindspore::lite::ClassificationTrainAccuracyMonitor am(1); |
|
|
|
mindspore::lite::CkptSaver cs(1000, std::string("lenet")); |
|
|
|
mindspore::lite::CkptSaver cs(1000, std::string("lenet"), model_); |
|
|
|
Rescaler rescale(255.0); |
|
|
|
|
|
|
|
loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched}); |
|
|
|
@@ -187,7 +184,7 @@ int NetRunner::Main() { |
|
|
|
|
|
|
|
if (epochs_ > 0) { |
|
|
|
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms"; |
|
|
|
Model::Export(model_, trained_fn); |
|
|
|
mindspore::lite::Model::Export(model_, trained_fn.c_str()); |
|
|
|
} |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|