From: @HilbertDavid Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tongpull/14751/MERGE
| @@ -1,6 +1,6 @@ | |||||
| BASE_DIR=$(realpath ../../../../) | BASE_DIR=$(realpath ../../../../) | ||||
| APP:=bin/net_runner | APP:=bin/net_runner | ||||
| LMSLIB:=-lmindspore-lite | |||||
| LMSLIB:=-lmindspore-lite-train | |||||
| LMDLIB:=-lminddata-lite | LMDLIB:=-lminddata-lite | ||||
| MSDIR:=$(realpath package-$(TARGET)/lib) | MSDIR:=$(realpath package-$(TARGET)/lib) | ||||
| ifneq ("$(wildcard $(MSDIR)/libhiai.so)","") | ifneq ("$(wildcard $(MSDIR)/libhiai.so)","") | ||||
| @@ -99,11 +99,8 @@ void NetRunner::InitAndFigureInputs() { | |||||
| context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; | context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; | ||||
| context.thread_num_ = 2; | context.thread_num_ = 2; | ||||
| model_ = mindspore::lite::Model::Import(ms_file_); | |||||
| if (model_ == nullptr) { | |||||
| MS_LOG(ERROR) << "import model failed"; | |||||
| return nullptr; | |||||
| } | |||||
| model_ = mindspore::lite::Model::Import(ms_file_.c_str()); | |||||
| MS_ASSERT(nullptr != model_); | |||||
| session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true); | session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true); | ||||
| MS_ASSERT(nullptr != session_); | MS_ASSERT(nullptr != session_); | ||||
| @@ -169,7 +166,7 @@ int NetRunner::TrainLoop() { | |||||
| mindspore::lite::LossMonitor lm(100); | mindspore::lite::LossMonitor lm(100); | ||||
| mindspore::lite::ClassificationTrainAccuracyMonitor am(1); | mindspore::lite::ClassificationTrainAccuracyMonitor am(1); | ||||
| mindspore::lite::CkptSaver cs(1000, std::string("lenet")); | |||||
| mindspore::lite::CkptSaver cs(1000, std::string("lenet"), model_); | |||||
| Rescaler rescale(255.0); | Rescaler rescale(255.0); | ||||
| loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched}); | loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched}); | ||||
| @@ -187,7 +184,7 @@ int NetRunner::Main() { | |||||
| if (epochs_ > 0) { | if (epochs_ > 0) { | ||||
| auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms"; | auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms"; | ||||
| Model::Export(model_, trained_fn); | |||||
| mindspore::lite::Model::Export(model_, trained_fn.c_str()); | |||||
| } | } | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -29,14 +29,14 @@ namespace lite { | |||||
| class CkptSaver : public session::TrainLoopCallBack { | class CkptSaver : public session::TrainLoopCallBack { | ||||
| public: | public: | ||||
| CkptSaver(int save_every_n, const std::string &filename_prefix) | |||||
| : save_every_n_(save_every_n), filename_prefix_(filename_prefix) {} | |||||
| CkptSaver(int save_every_n, const std::string &filename_prefix, mindspore::lite::Model *model) | |||||
| : save_every_n_(save_every_n), filename_prefix_(filename_prefix), model_(model) {} | |||||
| int EpochEnd(const session::TrainLoopCallBackData &cb_data) override { | int EpochEnd(const session::TrainLoopCallBackData &cb_data) override { | ||||
| if ((cb_data.epoch_ + 1) % save_every_n_ == 0) { | if ((cb_data.epoch_ + 1) % save_every_n_ == 0) { | ||||
| auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms"; | auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms"; | ||||
| remove(cpkt_fn.c_str()); | remove(cpkt_fn.c_str()); | ||||
| cb_data.session_->SaveToFile(cpkt_fn); | |||||
| Model::Export(model_, cpkt_fn.c_str()); | |||||
| } | } | ||||
| return session::RET_CONTINUE; | return session::RET_CONTINUE; | ||||
| } | } | ||||
| @@ -44,6 +44,7 @@ class CkptSaver : public session::TrainLoopCallBack { | |||||
| private: | private: | ||||
| int save_every_n_; | int save_every_n_; | ||||
| std::string filename_prefix_; | std::string filename_prefix_; | ||||
| mindspore::lite::Model *model_ = nullptr; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||