Browse Source

fix train_lenet demo

fix code-style
pull/14751/head
lz 4 years ago
parent
commit
88106f3f59
3 changed files with 9 additions and 11 deletions
  1. +1
    -1
      mindspore/lite/examples/train_lenet/Makefile
  2. +4
    -7
      mindspore/lite/examples/train_lenet/src/net_runner.cc
  3. +4
    -3
      mindspore/lite/include/train/ckpt_saver.h

+ 1
- 1
mindspore/lite/examples/train_lenet/Makefile View File

@@ -1,6 +1,6 @@
BASE_DIR=$(realpath ../../../../)
APP:=bin/net_runner
LMSLIB:=-lmindspore-lite
LMSLIB:=-lmindspore-lite-train
LMDLIB:=-lminddata-lite
MSDIR:=$(realpath package-$(TARGET)/lib)
ifneq ("$(wildcard $(MSDIR)/libhiai.so)","")


+ 4
- 7
mindspore/lite/examples/train_lenet/src/net_runner.cc View File

@@ -99,11 +99,8 @@ void NetRunner::InitAndFigureInputs() {
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2;

model_ = mindspore::lite::Model::Import(ms_file_);
if (model_ == nullptr) {
MS_LOG(ERROR) << "import model failed";
return nullptr;
}
model_ = mindspore::lite::Model::Import(ms_file_.c_str());
MS_ASSERT(nullptr != model_);
session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true);

MS_ASSERT(nullptr != session_);
@@ -169,7 +166,7 @@ int NetRunner::TrainLoop() {

mindspore::lite::LossMonitor lm(100);
mindspore::lite::ClassificationTrainAccuracyMonitor am(1);
mindspore::lite::CkptSaver cs(1000, std::string("lenet"));
mindspore::lite::CkptSaver cs(1000, std::string("lenet"), model_);
Rescaler rescale(255.0);

loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched});
@@ -187,7 +184,7 @@ int NetRunner::Main() {

if (epochs_ > 0) {
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
Model::Export(model_, trained_fn);
mindspore::lite::Model::Export(model_, trained_fn.c_str());
}
return 0;
}


+ 4
- 3
mindspore/lite/include/train/ckpt_saver.h View File

@@ -29,14 +29,14 @@ namespace lite {

class CkptSaver : public session::TrainLoopCallBack {
public:
CkptSaver(int save_every_n, const std::string &filename_prefix)
: save_every_n_(save_every_n), filename_prefix_(filename_prefix) {}
CkptSaver(int save_every_n, const std::string &filename_prefix, mindspore::lite::Model *model)
: save_every_n_(save_every_n), filename_prefix_(filename_prefix), model_(model) {}

int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
remove(cpkt_fn.c_str());
cb_data.session_->SaveToFile(cpkt_fn);
Model::Export(model_, cpkt_fn.c_str());
}
return session::RET_CONTINUE;
}
@@ -44,6 +44,7 @@ class CkptSaver : public session::TrainLoopCallBack {
private:
int save_every_n_;
std::string filename_prefix_;
mindspore::lite::Model *model_ = nullptr;
};

} // namespace lite


Loading…
Cancel
Save