You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

callback.h 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
  17. #define MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
  #include <cstddef>
  #include <cstdint>
  #include <memory>
  #include <string>
  #include <vector>
  #include "include/api/data_type.h"
  #include "include/api/dual_abi_helper.h"
  // Marks public API symbols as exported (Windows) / default-visible (other platforms).
  // Guarded so that an MS_API already defined elsewhere (another header or the compiler
  // command line) is not redefined with a different replacement list, which is ill-formed.
  #ifndef MS_API
  #ifdef _WIN32
  #define MS_API __declspec(dllexport)
  #else
  #define MS_API __attribute__((visibility("default")))
  #endif
  #endif
  29. namespace mindspore {
  class Model;
  class ModelImpl;
  class CallbackImpl;

  /// \brief Snapshot of the training-loop state passed to TrainCallBack hooks.
  struct TrainCallBackData {
    /// \brief Builds the callback data.
    ///
    /// \param[in] train_mode training-mode flag for the current execution
    /// \param[in] epoch the current training epoch (starts at 0); expected to be non-negative,
    ///            because it is stored as unsigned int (a negative value wraps around)
    /// \param[in] step the current step within the epoch; same non-negative expectation as epoch
    /// \param[in] model pointer to the Model object; stored as-is, not owned by this struct
    TrainCallBackData(bool train_mode, int epoch, int step, Model *model)
        : train_mode_(train_mode),
          // The parameters are int but the members are unsigned: make the conversion explicit.
          epoch_(static_cast<unsigned int>(epoch)),
          step_(static_cast<unsigned int>(step)),
          model_(model) {}

    bool train_mode_ = false; /**< training-mode flag for the current execution */
    unsigned int epoch_ = 0;  /**< the current training epoch (starts at 0) */
    unsigned int step_ = 0;   /**< the current step within the epoch */
    Model *model_ = nullptr;  /**< pointer to the Model object (not owned) */
  };
  /// \brief Value returned by TrainCallBack::EpochEnd to steer the training loop.
  ///
  /// The underlying type is fixed to 32 bits; std::uint32_t comes from <cstdint>.
  enum CallbackRetValue : std::uint32_t {
    kContinue = 0,                /**< continue training */
    kStopTraining = 1,            /**< stop training (e.g., due to achieved accuracy) */
    kExit = 2,                    /**< exit training (due to an error of some sort) */
    kUnknownRetValue = 0xFFFFFFFF /**< presumably marks an unrecognized value -- confirm */
  };
  47. class TrainCallBack {
  48. public:
  49. virtual ~TrainCallBack() = default;
  50. /// \brief This method is called once before the network executing
  51. ///
  52. /// \param[in] cb_data info about current execution
  53. virtual void Begin(const TrainCallBackData &cb_data) {}
  54. /// \brief This method is called once following the network execution
  55. ///
  56. /// \param[in] cb_data info about current execution
  57. virtual void End(const TrainCallBackData &cb_data) {}
  58. /// \brief This method is called at the beginning of each epoch
  59. ///
  60. /// \param[in] cb_data info about current execution
  61. virtual void EpochBegin(const TrainCallBackData &cb_data) {}
  62. /// \brief This method is called after the run of each epoch
  63. ///
  64. /// \param[in] cb_data info about current execution
  65. ///
  66. /// \return indication if to continue in the train loop:
  67. /// RET_CONTINUE -- continue training
  68. /// RET_STOP_TRAINING -- stop training (e.g., due to achieved accuracy)
  69. /// RET_EXIT -- Exit training (due to error of some sort)
  70. virtual CallbackRetValue EpochEnd(const TrainCallBackData &cb_data) { return kContinue; }
  71. /// \brief This method is called at the beginning of each step
  72. ///
  73. /// \param[in] cb_data info about current execution
  74. virtual void StepBegin(const TrainCallBackData &cb_data) {}
  75. /// \brief This method is called after each step is ran
  76. ///
  77. /// \param[in] cb_data info about current execution
  78. virtual void StepEnd(const TrainCallBackData &cb_data) {}
  79. protected:
  80. friend class Model;
  81. friend class ModelImpl;
  82. CallbackImpl* callback_impl_ = nullptr;
  83. };
  84. } // namespace mindspore
  85. #endif // MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H