Merge pull request !19433 from zhengjun10/fix_issuetags/v1.4.0
| @@ -29,6 +29,7 @@ | |||
| #include "include/train/ckpt_saver.h" | |||
| #include "include/train/lr_scheduler.h" | |||
| #include "include/train/accuracy_metrics.h" | |||
| #include "include/train/train_session.h" | |||
| #include "include/train/classification_train_accuracy_monitor.h" | |||
| #include "src/utils.h" | |||
| #include "include/dataset/datasets.h" | |||
| @@ -143,7 +144,7 @@ void NetRunner::InitAndFigureInputs() { | |||
| context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; | |||
| context.thread_num_ = 2; | |||
| session_ = mindspore::session::LiteSession::CreateTrainSession(ms_file_, &context, true); | |||
| session_ = mindspore::session::TrainSession::CreateTrainSession(ms_file_, &context, true); | |||
| MS_ASSERT(session_ != nullptr); | |||
| session_->SetupVirtualBatch(virtual_batch_); | |||
| @@ -25,6 +25,7 @@ | |||
| #include <iostream> | |||
| #include "include/context.h" | |||
| #include "include/lite_session.h" | |||
| #include "include/train/train_session.h" | |||
| #include "src/utils.h" | |||
| static unsigned int seed = time(NULL); | |||
| @@ -77,7 +78,7 @@ void NetRunner::InitAndFigureInputs() { | |||
| context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = enable_fp16_; | |||
| context.thread_num_ = 1; | |||
| session_ = mindspore::session::LiteSession::CreateTransferSession(ms_backbone_file_, ms_head_file_, &context); | |||
| session_ = mindspore::session::TrainSession::CreateTransferSession(ms_backbone_file_, ms_head_file_, &context); | |||
| MS_ASSERT(session_ != nullptr); | |||
| auto inputs = session_->GetInputs(); | |||
| @@ -128,29 +128,6 @@ class MS_API LiteSession { | |||
| /// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h. | |||
| virtual int Resize(const Vector<tensor::MSTensor *> &inputs, const Vector<Vector<int>> &dims) = 0; | |||
| /// \brief Static method to create a TrainSession object | |||
| /// | |||
| /// \param[in] filename name of flatbuffer that holds the flatbuffer | |||
| /// \param[in] context Defines the context of the session to be created | |||
| /// \param[in] train_mode training mode to initialize Session with | |||
| /// \param[in] cfg training configuration, set to null for default configuration | |||
| /// | |||
| /// \return Pointer of MindSpore LiteSession | |||
| static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context, | |||
| bool train_mode = false, const lite::TrainCfg *cfg = nullptr); | |||
| /// \brief Static method to create a TransferSession object | |||
| /// | |||
| /// \param[in] filename_backbone Filename to read backbone net flatbuffer from | |||
| /// \param[in] filename_head Filename to read head net flatbuffer from | |||
| /// \param[in] context Defines the context of the session to be created | |||
| /// \param[in] train_mode training mode to initialize Session with | |||
| /// | |||
| /// \return Pointer of MindSpore LiteSession | |||
| static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head, | |||
| const lite::Context *context, bool train_mode = false, | |||
| const lite::TrainCfg *cfg = nullptr); | |||
| /// \brief Set model to train mode | |||
| /// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h | |||
| virtual int Train() { return mindspore::lite::RET_ERROR; } | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_ | |||
| #define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_ | |||
| #include <string> | |||
| #include "include/lite_session.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| class TrainSession { | |||
| public: | |||
| /// \brief Static method to create a TransferSession object | |||
| /// | |||
| /// \param[in] filename_backbone Filename to read backbone net flatbuffer from | |||
| /// \param[in] filename_head Filename to read head net flatbuffer from | |||
| /// \param[in] context Defines the context of the session to be created | |||
| /// \param[in] train_mode training mode to initialize Session with | |||
| /// | |||
| /// \return Pointer of MindSpore LiteSession | |||
| static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head, | |||
| const lite::Context *context, bool train_mode = false, | |||
| const lite::TrainCfg *cfg = nullptr); | |||
| /// \brief Static method to create a TrainSession object | |||
| /// | |||
| /// \param[in] filename name of flatbuffer that holds the flatbuffer | |||
| /// \param[in] context Defines the context of the session to be created | |||
| /// \param[in] train_mode training mode to initialize Session with | |||
| /// \param[in] cfg training configuration, set to null for default configuration | |||
| /// | |||
| /// \return Pointer of MindSpore LiteSession | |||
| static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context, | |||
| bool train_mode = false, const lite::TrainCfg *cfg = nullptr); | |||
| }; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_ | |||
| @@ -16,7 +16,7 @@ | |||
| #include <jni.h> | |||
| #include "common/ms_log.h" | |||
| #include "include/lite_session.h" | |||
| #include "include/train/train_session.h" | |||
| #include "include/train/train_cfg.h" | |||
| #include "include/errorcode.h" | |||
| @@ -32,7 +32,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTra | |||
| } | |||
| auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer); | |||
| auto session = mindspore::session::LiteSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE), | |||
| auto session = mindspore::session::TrainSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE), | |||
| lite_context_ptr, train_mode, nullptr); | |||
| if (session == nullptr) { | |||
| MS_LOGE("CreateTrainSession failed"); | |||
| @@ -741,8 +741,8 @@ int TrainSession::UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &featu | |||
| } | |||
| } // namespace lite | |||
| session::LiteSession *session::LiteSession::CreateTrainSession(const std::string &fn, const lite::Context *context, | |||
| bool train_mode, const lite::TrainCfg *cfg) { | |||
| session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context, | |||
| bool train_mode, const lite::TrainCfg *cfg) { | |||
| auto session = std::make_unique<lite::TrainSession>(); | |||
| if (session == nullptr) { | |||
| MS_LOG(ERROR) << "create session failed"; | |||
| @@ -22,6 +22,7 @@ | |||
| #include <memory> | |||
| #include <map> | |||
| #include "include/train/train_cfg.h" | |||
| #include "include/train/train_session.h" | |||
| #include "src/lite_session.h" | |||
| /* | |||
| @@ -290,10 +290,10 @@ static session::LiteSession *CreateTransferSessionInt(const char *model_buf_back | |||
| return session; | |||
| } | |||
| session::LiteSession *session::LiteSession::CreateTransferSession(const std::string &filename_backbone, | |||
| const std::string &filename_head, | |||
| const lite::Context *ctxt, bool train_mode, | |||
| const lite::TrainCfg *cfg) { | |||
| session::LiteSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone, | |||
| const std::string &filename_head, | |||
| const lite::Context *ctxt, bool train_mode, | |||
| const lite::TrainCfg *cfg) { | |||
| size_t size_head = 0; | |||
| size_t size_backbone = 0; | |||
| std::string filename = filename_head; | |||
| @@ -28,6 +28,7 @@ | |||
| #include "include/context.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/train/train_cfg.h" | |||
| #include "include/train/train_session.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/file_utils.h" | |||
| #include "src/kernel_registry.h" | |||
| @@ -102,7 +103,7 @@ TEST_F(NetworkTest, efficient_net) { | |||
| context->thread_num_ = 1; | |||
| std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms"; | |||
| auto session = session::LiteSession::CreateTrainSession(net, context, false); | |||
| auto session = session::TrainSession::CreateTrainSession(net, context, false); | |||
| ASSERT_NE(session, nullptr); | |||
| std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin"; | |||
| @@ -150,7 +151,7 @@ TEST_F(NetworkTest, noname) { | |||
| lite::TrainCfg cfg; | |||
| cfg.loss_name_ = "nhwc"; | |||
| auto session = mindspore::session::LiteSession::CreateTrainSession(net, &context, true, &cfg); | |||
| auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &cfg); | |||
| ASSERT_NE(session, nullptr); | |||
| auto tensors_map = session->GetOutputs(); | |||
| auto tensor_names = session->GetOutputTensorNames(); | |||
| @@ -170,7 +171,7 @@ TEST_F(NetworkTest, setname) { | |||
| lite::TrainCfg train_cfg; | |||
| train_cfg.loss_name_ = "nhwc"; | |||
| auto session = mindspore::session::LiteSession::CreateTrainSession(net, &context, true, &train_cfg); | |||
| auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &train_cfg); | |||
| ASSERT_NE(session, nullptr); | |||
| auto tensors_map = session->GetOutputs(); | |||
| @@ -30,6 +30,7 @@ | |||
| #include "include/version.h" | |||
| #include "include/model.h" | |||
| #include "include/train/train_cfg.h" | |||
| #include "include/train/train_session.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -338,7 +339,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string | |||
| MS_LOG(INFO) << "CreateTransferSession from models files" << filename << " and " << bb_filename; | |||
| std::cout << "CreateTranferSession from model file " << filename << " and " << bb_filename << std::endl; | |||
| session = std::unique_ptr<session::LiteSession>( | |||
| session::LiteSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg)); | |||
| session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg)); | |||
| if (session == nullptr) { | |||
| MS_LOG(ERROR) << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str(); | |||
| std::cout << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str() << std::endl; | |||
| @@ -349,7 +350,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string | |||
| MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str(); | |||
| std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl; | |||
| session = std::unique_ptr<session::LiteSession>( | |||
| session::LiteSession::CreateTrainSession(filename, &context, true, &train_cfg)); | |||
| session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg)); | |||
| if (session == nullptr) { | |||
| MS_LOG(ERROR) << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str(); | |||
| std::cout << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str() << std::endl; | |||