|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
- #define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
- #include <vector>
- #include <string>
- #include <tuple>
- #include "include/lite_session.h"
- #include "include/errorcode.h"
-
- namespace mindspore {
- namespace session {
-
- /// \brief TrainSession Defines a class that allows training a MindSpore model
- class TrainSession : public session::LiteSession {
- public:
- /// \brief Class destructor
- virtual ~TrainSession() = default;
-
- /// \brief Static method to create a TrainSession object
- ///
- /// \param[in] model A buffer that was read from a MS model file
- /// \param[in] context Defines the context of the session to be created
- /// \param[in] train_mode training mode to initialize Session with
- ///
- /// \return Pointer of MindSpore Lite TrainSession
- static TrainSession *CreateSession(mindspore::lite::Model *model, lite::Context *context, bool train_mode = false);
-
- /// \brief Static method to create a transfer lernning support TrainSession object
- ///
- /// \param[in] model_buf_backbone A buffer that was read from a backbone MS model file
- /// \param[in] size_backbone Length of the backbone net buffer
- /// \param[in] model_buf_head A buffer that was read from a head MS model file
- /// \param[in] size_head Length of the head net buffer
- /// \param[in] context Defines the context of the session to be created
- /// \param[in] train_mode training mode to initialize Session with
- ///
- /// \return Pointer of MindSpore Lite TrainSession
- static TrainSession *CreateTransferSession(const char *model_buf_backbone, size_t size_backbone,
- const char *model_buf_head, size_t size_head, lite::Context *context,
- bool train_mode = false);
-
- /// \brief Static method to create a TrainSession object
- ///
- /// \param[in] filename_backbone Filename to read backbone net flatbuffer from
- /// \param[in] filename_head Filename to read head net flatbuffer from
- /// \param[in] context Defines the context of the session to be created
- /// \param[in] train_mode training mode to initialize Session with
- ///
- /// \return Pointer of MindSpore Lite TrainSession
- static TrainSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
- lite::Context *context, bool train_mode = false);
-
- /// \brief Set model to train mode
- /// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
- virtual int Train() = 0;
-
- /// \brief Check mode of model
- ///
- /// \return boolean indication if model is in train mode
- bool IsTrain() { return train_mode_ == true; }
-
- /// \brief Set model to eval mode
- /// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
- virtual int Eval() = 0;
-
- /// \brief Check mode of model
- ///
- /// \return boolean indication if model is in eval mode
- bool IsEval() { return train_mode_ == false; }
-
- /// \brief Sets the Learning Rate of the training
- ///
- /// \param[in] learning_rate to set
- ///
- /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
- virtual int SetLearningRate(float learning_rate) = 0;
-
- /// \brief Gets the Learning Rate of the training
- ///
- /// \return learning rate. 0.0 if no optimizer was found
- virtual float GetLearningRate() = 0;
-
- /// \brief Setup training with virtual batches
- ///
- /// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable
- /// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration
- /// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration
-
- /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
- virtual int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) = 0;
-
- /// \brief Get output MindSpore Lite MSTensors of Training model prediction
- ///
- /// \return a vector of output tensors (MindSpore Lite MSTensor).
- virtual std::vector<tensor::MSTensor *> GetPredictions() const = 0;
-
- /// \brief Set part of the name that identify a loss kernel
- /// \param[in] loss_name Identifucation name for loss kernels
- /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
- virtual int SetLossName(std::string loss_name) {
- loss_name_ = loss_name;
- return mindspore::lite::RET_OK;
- }
- /// \brief Save model for inference (LiteSession)
- /// \param[in] fb_name pretrained model file name prefix. '.ms' is added as extension.
- /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
- virtual int ExportInference(std::string fb_name) { return mindspore::lite::RET_ERROR; }
-
- protected:
- bool train_mode_ = false;
- std::string get_loss_name() const { return loss_name_; }
-
- private:
- std::string loss_name_ = "_loss_fn";
- };
- } // namespace session
- } // namespace mindspore
- #endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
|