From: @HilbertDavid Reviewed-by: @zhanghaibo5,@ddwsky Signed-off-by: @zhanghaibo5pull/14563/MERGE
| @@ -99,7 +99,13 @@ void NetRunner::InitAndFigureInputs() { | |||
| context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; | |||
| context.thread_num_ = 2; | |||
| session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context); | |||
| model_ = mindspore::lite::Model::Import(ms_file_); | |||
| if (model_ == nullptr) { | |||
| MS_LOG(ERROR) << "import model failed"; | |||
| return nullptr; | |||
| } | |||
| session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true); | |||
| MS_ASSERT(nullptr != session_); | |||
| loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_); | |||
| @@ -154,7 +160,6 @@ int NetRunner::InitDB() { | |||
| std::cout << "No relevant data was found in " << data_dir_ << std::endl; | |||
| MS_ASSERT(train_ds_->GetDatasetSize() != 0); | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -182,7 +187,7 @@ int NetRunner::Main() { | |||
| if (epochs_ > 0) { | |||
| auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms"; | |||
| session_->SaveToFile(trained_fn); | |||
| Model::Export(model_, trained_fn); | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -27,6 +27,7 @@ | |||
| #include "include/train/accuracy_metrics.h" | |||
| #include "include/ms_tensor.h" | |||
| #include "include/datasets.h" | |||
| #include "include/model.h" | |||
| using mindspore::dataset::Dataset; | |||
| using mindspore::lite::AccuracyMetrics; | |||
| @@ -36,6 +37,7 @@ class NetRunner { | |||
| int Main(); | |||
| bool ReadArgs(int argc, char *argv[]); | |||
| ~NetRunner(); | |||
| mindspore::lite::Model *model_ = nullptr; | |||
| private: | |||
| void Usage(); | |||
| @@ -45,13 +45,17 @@ struct MS_API Model { | |||
| SubGraphPtrVector sub_graphs_; | |||
| /// \brief Static method to create a Model pointer. | |||
| /// | |||
| /// \param[in] model_buf Define the buffer read from a model file. | |||
| /// \param[in] size Define bytes number of model buffer. | |||
| /// | |||
| /// \return Pointer of MindSpore Lite Model. | |||
| static Model *Import(const char *model_buf, size_t size); | |||
| /// \brief Static method to create a Model pointer. | |||
| static Model *Import(const char *filename); | |||
| /// \brief method to export model to file. | |||
| static int Export(Model *model, const char *filename); | |||
| /// \brief method to export model to buffer. | |||
| static int Export(Model *model, char *buf, size_t *size); | |||
| /// \brief Free meta graph temporary buffer | |||
| virtual void Free() = 0; | |||
| @@ -32,23 +32,12 @@ class TrainSession : public session::LiteSession { | |||
| /// \brief Static method to create a TrainSession object | |||
| /// | |||
| /// \param[in] model_buf A buffer that was read from a MS model file | |||
| /// \param[in] size Length of the buffer | |||
| /// \param[in] model A buffer that was read from a MS model file | |||
| /// \param[in] context Defines the context of the session to be created | |||
| /// \param[in] train_mode training mode to initialize Session with | |||
| /// | |||
| /// \return Pointer of MindSpore Lite TrainSession | |||
| static TrainSession *CreateSession(const char *model_buf, size_t size, lite::Context *context, | |||
| bool train_mode = false); | |||
| /// \brief Static method to create a TrainSession object | |||
| /// | |||
| /// \param[in] filename Filename to read flatbuffer from | |||
| /// \param[in] context Defines the context of the session to be created | |||
| /// \param[in] train_mode training mode to initialize Session with | |||
| /// | |||
| /// \return Pointer of MindSpore Lite TrainSession | |||
| static TrainSession *CreateSession(const std::string &filename, lite::Context *context, bool train_mode = false); | |||
| static TrainSession *CreateSession(mindspore::lite::Model *model, lite::Context *context, bool train_mode = false); | |||
| /// \brief Static method to create a transfer lernning support TrainSession object | |||
| /// | |||
| @@ -75,21 +64,6 @@ class TrainSession : public session::LiteSession { | |||
| static TrainSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head, | |||
| lite::Context *context, bool train_mode = false); | |||
| /// \brief Export the trained model into a buffer | |||
| /// | |||
| /// \param[in] buf The buffer to Export into. If equal to nullptr, buf will be allocated | |||
| /// \param[in,out] len Size of the pre-allocated buffer, and returned size of the exported buffer | |||
| /// | |||
| /// \return pointer to the export buffer | |||
| virtual void *ExportToBuf(char *buf, size_t *len) const = 0; | |||
| /// \brief Save the trained model into a flatbuffer file | |||
| /// | |||
| /// \param[in] filename Filename to save flatbuffer to | |||
| /// | |||
| /// \return 0 on success or -1 in case of error | |||
| virtual int SaveToFile(const std::string &filename) const = 0; | |||
| /// \brief Set model to train mode | |||
| /// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h | |||
| virtual int Train() = 0; | |||
| @@ -111,7 +111,6 @@ if(SUPPORT_TRAIN) | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/transfer_session.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_model.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_utils.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/loss_monitor.cc | |||
| @@ -15,9 +15,13 @@ | |||
| */ | |||
| #include "src/lite_model.h" | |||
| #include <sys/stat.h> | |||
| #include <iostream> | |||
| #include <fstream> | |||
| #include <vector> | |||
| #include <set> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include "src/common/prim_util.h" | |||
| #ifdef ENABLE_V0 | |||
| #include "src/ops/compat/compat_register.h" | |||
| @@ -343,5 +347,102 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { | |||
| return model; | |||
| } | |||
| std::unique_ptr<char[]> ReadFileToBuf(const std::string &filename, size_t *size) { | |||
| std::ifstream ifs(filename); | |||
| if (!ifs.good()) { | |||
| MS_LOG(ERROR) << "File: " << filename << " does not exist"; | |||
| return std::unique_ptr<char[]>(nullptr); | |||
| } | |||
| if (!ifs.is_open()) { | |||
| MS_LOG(ERROR) << "File: " << filename << " open failed"; | |||
| return std::unique_ptr<char[]>(nullptr); | |||
| } | |||
| ifs.seekg(0, std::ios::end); | |||
| auto tellg_ret = ifs.tellg(); | |||
| if (tellg_ret <= 0) { | |||
| MS_LOG(ERROR) << "Could not read file " << filename; | |||
| return std::unique_ptr<char[]>(nullptr); | |||
| } | |||
| size_t fsize = static_cast<size_t>(tellg_ret); | |||
| std::unique_ptr<char[]> buf(new (std::nothrow) char[fsize]); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "malloc buf failed, file: " << filename; | |||
| ifs.close(); | |||
| return std::unique_ptr<char[]>(nullptr); | |||
| } | |||
| ifs.seekg(0, std::ios::beg); | |||
| ifs.read(buf.get(), fsize); | |||
| if (!ifs) { | |||
| MS_LOG(ERROR) << "only read " << ifs.gcount() << "bytes in " << filename; | |||
| ifs.close(); | |||
| return std::unique_ptr<char[]>(nullptr); | |||
| } | |||
| ifs.close(); | |||
| if (size != nullptr) { | |||
| *size = fsize; | |||
| } | |||
| return buf; | |||
| } | |||
| Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); } | |||
| Model *Model::Import(const char *filename) { | |||
| size_t size = -1; | |||
| auto buf = ReadFileToBuf(filename, &size); | |||
| if (buf == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return ImportFromBuffer(buf.get(), size, false); | |||
| } | |||
| int Model::Export(Model *model, char *buffer, size_t *len) { | |||
| if (len == nullptr) { | |||
| MS_LOG(ERROR) << "len is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto *liteModel = reinterpret_cast<LiteModel *>(model); | |||
| if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) { | |||
| MS_LOG(ERROR) << "model buffer is invalid"; | |||
| return RET_ERROR; | |||
| } | |||
| if (*len < liteModel->buf_size_ && buffer != nullptr) { | |||
| MS_LOG(ERROR) << "Buffer is too small, Export Failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (buffer == nullptr) { | |||
| buffer = reinterpret_cast<char *>(malloc(liteModel->buf_size_)); | |||
| if (buffer == nullptr) { | |||
| MS_LOG(ERROR) << "allocated model buf fail!"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| memcpy(buffer, liteModel->buf, liteModel->buf_size_); | |||
| *len = liteModel->buf_size_; | |||
| return RET_OK; | |||
| } | |||
| int Model::Export(Model *model, const char *filename) { | |||
| auto *liteModel = reinterpret_cast<LiteModel *>(model); | |||
| if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) { | |||
| MS_LOG(ERROR) << "model buf is invalid"; | |||
| return RET_ERROR; | |||
| } | |||
| std::ofstream ofs(filename); | |||
| if (!ofs.good() || !ofs.is_open()) { | |||
| MS_LOG(ERROR) << "Could not open file \"" << filename << "\" for writing"; | |||
| return RET_ERROR; | |||
| } | |||
| ofs.seekp(0, std::ios::beg); | |||
| ofs.write(liteModel->buf, liteModel->buf_size_); | |||
| ofs.close(); | |||
| return chmod(filename, S_IRUSR); | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -1,91 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/train/train_model.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/graph_util.h" | |||
| namespace mindspore::lite { | |||
| TrainModel *TrainModel::Import(const char *model_buf, size_t size) { | |||
| if (model_buf == nullptr) { | |||
| MS_LOG(ERROR) << "The model buf is nullptr"; | |||
| return nullptr; | |||
| } | |||
| TrainModel *model = new (std::nothrow) TrainModel(); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "new model fail!"; | |||
| return nullptr; | |||
| } | |||
| model->buf = reinterpret_cast<char *>(malloc(size)); | |||
| if (model->buf == nullptr) { | |||
| delete model; | |||
| MS_LOG(ERROR) << "malloc inner model buf fail!"; | |||
| return nullptr; | |||
| } | |||
| memcpy(model->buf, model_buf, size); | |||
| model->buf_size_ = size; | |||
| auto status = model->ConstructModel(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "construct model failed."; | |||
| delete model; | |||
| return nullptr; | |||
| } | |||
| return model; | |||
| } | |||
| void TrainModel::Free() {} | |||
| char *TrainModel::ExportBuf(char *buffer, size_t *len) const { | |||
| if (len == nullptr) { | |||
| MS_LOG(ERROR) << "len is nullptr"; | |||
| return nullptr; | |||
| } | |||
| if (buf_size_ == 0 || buf == nullptr) { | |||
| MS_LOG(ERROR) << "Model::Export is only available for Train Session"; | |||
| return nullptr; | |||
| } | |||
| if (*len < buf_size_ && buffer != nullptr) { | |||
| MS_LOG(ERROR) << "Buffer is too small, Export Failed"; | |||
| return nullptr; | |||
| } | |||
| if (buffer == nullptr) { | |||
| buffer = reinterpret_cast<char *>(malloc(buf_size_)); | |||
| if (buffer == nullptr) { | |||
| MS_LOG(ERROR) << "allocated model buf fail!"; | |||
| return nullptr; | |||
| } | |||
| } | |||
| memcpy(buffer, buf, buf_size_); | |||
| *len = buf_size_; | |||
| return buffer; | |||
| } | |||
| char *TrainModel::GetBuffer(size_t *len) const { | |||
| if (len == nullptr) { | |||
| MS_LOG(ERROR) << "len is nullptr"; | |||
| return nullptr; | |||
| } | |||
| if (buf_size_ == 0 || buf == nullptr) { | |||
| MS_LOG(ERROR) << "Model::Export is only available for Train Session"; | |||
| return nullptr; | |||
| } | |||
| *len = buf_size_; | |||
| return buf; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -1,57 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ | |||
| #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ | |||
| #include <vector> | |||
| #include "src/lite_model.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| /// \brief TrainModel Defines a class that allows to import and export a mindsport trainable model | |||
| struct TrainModel : public lite::LiteModel { | |||
| /// \brief Static method to create a TrainModel object | |||
| /// | |||
| /// \param[in] model_buf A buffer that was read from a MS model file | |||
| /// \param[in] size Length of the buffer | |||
| // | |||
| /// \return Pointer to MindSpore Lite TrainModel | |||
| static TrainModel *Import(const char *model_buf, size_t size); | |||
| /// \brief Free meta graph related data | |||
| void Free() override; | |||
| /// \brief Class destructor, free all memory | |||
| virtual ~TrainModel() = default; | |||
| /// \brief Export Model into a buffer | |||
| /// | |||
| /// \param[in] buf The buffer to Export into. If equal to nullptr, buf will be allocated | |||
| /// \param[in,out] len Size of the pre-allocated buffer, and returned size of the exported buffer | |||
| /// | |||
| /// \return Pointer to buffer with exported model | |||
| char *ExportBuf(char *buf, size_t *len) const; | |||
| /// \brief Get Model buffer | |||
| /// | |||
| /// \param[in,out] len Return size of the buffer | |||
| /// | |||
| /// \return Pointer to model buffer | |||
| char *GetBuffer(size_t *len) const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ | |||
| @@ -25,6 +25,7 @@ | |||
| #include "include/errorcode.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/tensor.h" | |||
| #include "src/lite_model.h" | |||
| #include "src/train/loss_kernel.h" | |||
| #include "src/train/optimizer_kernel.h" | |||
| #include "src/sub_graph_kernel.h" | |||
| @@ -40,47 +41,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| std::unique_ptr<char[]> ReadFileToBuf(const std::string &filename, size_t *size) { | |||
| std::ifstream ifs(filename); | |||
| if (!ifs.good()) { | |||
| MS_LOG(ERROR) << "File: " << filename << " does not exist"; | |||
| return std::unique_ptr<char[]>(nullptr); | |||
| } | |||
| if (!ifs.is_open()) { | |||
| MS_LOG(ERROR) << "File: " << filename << " open failed"; | |||
| return std::unique_ptr<char[]>(nullptr); | |||
| } | |||
| ifs.seekg(0, std::ios::end); | |||
| auto tellg_ret = ifs.tellg(); | |||
| if (tellg_ret <= 0) { | |||
| MS_LOG(ERROR) << "Could not read file " << filename; | |||
| return std::unique_ptr<char[]>(nullptr); | |||
| } | |||
| size_t fsize = static_cast<size_t>(tellg_ret); | |||
| std::unique_ptr<char[]> buf(new (std::nothrow) char[fsize]); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "malloc buf failed, file: " << filename; | |||
| ifs.close(); | |||
| return std::unique_ptr<char[]>(nullptr); | |||
| } | |||
| ifs.seekg(0, std::ios::beg); | |||
| ifs.read(buf.get(), fsize); | |||
| if (!ifs) { | |||
| MS_LOG(ERROR) << "only read " << ifs.gcount() << "bytes in " << filename; | |||
| ifs.close(); | |||
| return std::unique_ptr<char[]>(nullptr); | |||
| } | |||
| ifs.close(); | |||
| if (size != nullptr) { | |||
| *size = fsize; | |||
| } | |||
| return buf; | |||
| } | |||
| static size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter) { | |||
| for (size_t i = 0; i < where.size(); i++) { | |||
| if (where[i] == searchParameter) { | |||
| @@ -140,7 +100,7 @@ void TrainSession::AllocWorkSpace() { | |||
| int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; } | |||
| int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) { | |||
| int TrainSession::CompileTrainGraph(mindspore::lite::Model *model) { | |||
| model_ = model; | |||
| auto restore = ReplaceOps(); | |||
| @@ -172,8 +132,6 @@ TrainSession::~TrainSession() { | |||
| } | |||
| } | |||
| void *TrainSession::ExportToBuf(char *buf, size_t *len) const { return model_->ExportBuf(buf, len); } | |||
| int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) { | |||
| this->outputs_.clear(); | |||
| @@ -231,25 +189,6 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a | |||
| return RET_OK; | |||
| } | |||
| int TrainSession::SaveToFile(const std::string &filename) const { | |||
| size_t fb_size = 0; | |||
| const auto *buf = reinterpret_cast<char *>(model_->GetBuffer(&fb_size)); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "Could not Export Trained model"; | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| std::ofstream ofs(filename); | |||
| if ((true != ofs.good()) || (true != ofs.is_open())) { | |||
| MS_LOG(ERROR) << "Could not open file \"" << filename << "\" for writing"; | |||
| return RET_ERROR; | |||
| } | |||
| ofs.seekp(0, std::ios::beg); | |||
| ofs.write(buf, fb_size); | |||
| ofs.close(); | |||
| return chmod(filename.c_str(), S_IRUSR); | |||
| } | |||
| int TrainSession::Train() { | |||
| // shift kernels to train mode | |||
| train_mode_ = true; | |||
| @@ -539,14 +478,8 @@ int TrainSession::SetLossName(std::string loss_name) { | |||
| } | |||
| } // namespace lite | |||
| session::TrainSession *session::TrainSession::CreateSession(const char *model_buf, size_t size, lite::Context *context, | |||
| session::TrainSession *session::TrainSession::CreateSession(mindspore::lite::Model *model, lite::Context *context, | |||
| bool train_mode) { | |||
| auto model = mindspore::lite::TrainModel::Import(model_buf, size); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "create model for train session failed"; | |||
| return nullptr; | |||
| } | |||
| auto session = new (std::nothrow) lite::TrainSession(); | |||
| if (session == nullptr) { | |||
| delete model; | |||
| @@ -581,15 +514,4 @@ session::TrainSession *session::TrainSession::CreateSession(const char *model_bu | |||
| return session; | |||
| } | |||
| session::TrainSession *session::TrainSession::CreateSession(const std::string &filename, lite::Context *context, | |||
| bool train_mode) { | |||
| size_t size = -1; | |||
| auto buf = lite::ReadFileToBuf(filename, &size); | |||
| if (buf == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return session::TrainSession::CreateSession(buf.get(), size, context, train_mode); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -21,7 +21,6 @@ | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include "include/train/train_session.h" | |||
| #include "src/train/train_model.h" | |||
| #include "src/lite_session.h" | |||
| /* | |||
| @@ -52,10 +51,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: | |||
| int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override; | |||
| int CompileGraph(lite::Model *model) override; | |||
| virtual int CompileTrainGraph(lite::TrainModel *model); | |||
| void *ExportToBuf(char *buf, size_t *len) const override; | |||
| int SaveToFile(const std::string &filename) const override; | |||
| virtual int CompileTrainGraph(lite::Model *model); | |||
| int Train() override; | |||
| int Eval() override; | |||
| @@ -108,7 +104,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: | |||
| virtual void CompileTrainOutputs(); | |||
| virtual void CompileEvalOutputs(); | |||
| TrainModel *model_ = nullptr; | |||
| Model *model_ = nullptr; | |||
| std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_node_map_; | |||
| std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_; | |||
| std::vector<std::string> orig_output_tensor_names_; | |||
| @@ -190,7 +190,7 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char * | |||
| return nullptr; | |||
| } | |||
| auto model = lite::TrainModel::Import(model_buf_head, size_head); | |||
| auto model = lite::Model::Import(model_buf_head, size_head); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "create model for head train session failed"; | |||
| delete session; | |||
| @@ -20,7 +20,6 @@ | |||
| #include <tuple> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include "src/train/train_model.h" | |||
| #include "src/lite_session.h" | |||
| #include "src/train/train_session.h" | |||
| @@ -281,7 +281,6 @@ if(SUPPORT_TRAIN) | |||
| ${LITE_DIR}/src/train/train_populate_parameter_v0.cc | |||
| ${LITE_DIR}/src/train/train_session.cc | |||
| ${LITE_DIR}/src/train/transfer_session.cc | |||
| ${LITE_DIR}/src/train/train_model.cc | |||
| ${LITE_DIR}/src/lite_session.cc | |||
| ) | |||
| else() | |||
| @@ -359,10 +359,14 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| meta_graph.reset(); | |||
| content = nullptr; | |||
| auto *model = mindspore::lite::Model::Import(content, size); | |||
| ASSERT_NE(nullptr, model); | |||
| lite::Context context; | |||
| context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; | |||
| context.thread_num_ = 1; | |||
| auto session = session::TrainSession::CreateSession(content, size, &context); | |||
| auto session = session::TrainSession::CreateSession(model, &context); | |||
| ASSERT_NE(nullptr, session); | |||
| session->Train(); | |||
| session->Train(); // Just double check that calling Train twice does not cause a problem | |||
| @@ -513,7 +517,10 @@ TEST_F(NetworkTest, efficient_net) { | |||
| context->thread_num_ = 1; | |||
| std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms"; | |||
| auto session = session::TrainSession::CreateSession(net, context, false); | |||
| auto *model = mindspore::lite::Model::Import(net.c_str()); | |||
| ASSERT_NE(model, nullptr); | |||
| auto session = session::TrainSession::CreateSession(model, context, false); | |||
| ASSERT_NE(session, nullptr); | |||
| std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin"; | |||
| @@ -530,7 +537,6 @@ TEST_F(NetworkTest, mobileface_net) { | |||
| std::string net = "./test_data/nets/mobilefacenet0924.ms"; | |||
| ReadFile(net.c_str(), &net_size, &buf); | |||
| // auto model = lite::TrainModel::Import(buf, net_size); | |||
| auto model = lite::Model::Import(buf, net_size); | |||
| delete[] buf; | |||
| auto context = new lite::Context; | |||
| @@ -538,7 +544,6 @@ TEST_F(NetworkTest, mobileface_net) { | |||
| context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 1; | |||
| // auto session = session::TrainSession::CreateSession(context); | |||
| auto session = session::LiteSession::CreateSession(context); | |||
| ASSERT_NE(session, nullptr); | |||
| auto ret = session->CompileGraph(model); | |||
| @@ -560,7 +565,10 @@ TEST_F(NetworkTest, setname) { | |||
| lite::Context context; | |||
| context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; | |||
| context.thread_num_ = 1; | |||
| auto session = mindspore::session::TrainSession::CreateSession(net, &context); | |||
| auto *model = mindspore::lite::Model::Import(net.c_str()); | |||
| ASSERT_NE(model, nullptr); | |||
| auto session = mindspore::session::TrainSession::CreateSession(model, &context); | |||
| ASSERT_NE(session, nullptr); | |||
| auto tensors_map = session->GetOutputs(); | |||
| @@ -25,6 +25,7 @@ | |||
| #include "include/context.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "include/version.h" | |||
| #include "include/model.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -326,7 +327,14 @@ int NetTrain::RunExportedNet() { | |||
| } | |||
| context->thread_num_ = flags_->num_threads_; | |||
| session_ = session::TrainSession::CreateSession(flags_->export_file_.c_str(), context.get()); | |||
| auto *model = mindspore::lite::Model::Import(flags_->export_file_.c_str()); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "create model for train session failed"; | |||
| return RET_ERROR; | |||
| } | |||
| session_ = session::TrainSession::CreateSession(model, context.get()); | |||
| if (session_ == nullptr) { | |||
| MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str(); | |||
| std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl; | |||
| @@ -388,7 +396,13 @@ int NetTrain::RunNetTrain() { | |||
| context->device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; | |||
| layer_checksum_ = flags_->layer_checksum_; | |||
| context->thread_num_ = flags_->num_threads_; | |||
| session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get()); | |||
| auto *model = mindspore::lite::Model::Import(flags_->model_file_.c_str()); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "create model for train session failed"; | |||
| return RET_ERROR; | |||
| } | |||
| session_ = session::TrainSession::CreateSession(model, context.get()); | |||
| if (session_ == nullptr) { | |||
| MS_LOG(ERROR) << "RunNetTrain CreateSession failed while running " << model_name.c_str(); | |||
| std::cout << "RunNetTrain CreateSession failed while running " << model_name.c_str() << std::endl; | |||
| @@ -432,7 +446,7 @@ int NetTrain::RunNetTrain() { | |||
| } | |||
| } | |||
| if (!flags_->export_file_.empty()) { | |||
| auto ret = session_->SaveToFile(flags_->export_file_); | |||
| auto ret = Model::Export(model, flags_->export_file_.c_str()); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SaveToFile error"; | |||
| std::cout << "Run SaveToFile error"; | |||