You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

callback_manager.cc 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/callback/callback_manager.h"
  17. #include "minddata/dataset/callback/ds_callback.h"
  18. #include "minddata/dataset/util/status.h"
  19. #include "minddata/dataset/engine/datasetops/dataset_op.h"
  20. namespace mindspore {
  21. namespace dataset {
  22. void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
  23. callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
  24. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  25. callbacks.push_back(callbacks[ind]);
  26. if (callbacks[ind]->IsBeginNeeded()) begin_indices_.push_back(ind);
  27. if (callbacks[ind]->IsEndNeeded()) end_indices_.push_back(ind);
  28. if (callbacks[ind]->IsEpochBeginNeeded()) epoch_begin_indices_.push_back(ind);
  29. if (callbacks[ind]->IsEpochEndNeeded()) epoch_end_indices_.push_back(ind);
  30. if (callbacks[ind]->IsNStepBeginNeeded()) step_begin_indices_.push_back(ind);
  31. if (callbacks[ind]->IsNStepEndNeeded()) step_end_indices_.push_back(ind);
  32. }
  33. }
  34. Status CallbackManager::Init(DatasetOp *op) {
  35. RETURN_UNEXPECTED_IF_NULL(op);
  36. op_ = op;
  37. // turn the flag on if callback is set
  38. enabled_ = !callbacks_.empty();
  39. // error check for each of the callbacks
  40. for (auto &cb : callbacks_) {
  41. CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0.");
  42. }
  43. return Status::OK();
  44. }
  45. Status CallbackManager::Begin(const CallbackParam &cb_param) {
  46. RETURN_OK_IF_TRUE(!enabled_);
  47. RETURN_UNEXPECTED_IF_NULL(op_);
  48. // Now do the actual callback
  49. for (size_t ind : begin_indices_) {
  50. RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param));
  51. }
  52. return Status::OK();
  53. }
  54. Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
  55. RETURN_OK_IF_TRUE(!enabled_);
  56. RETURN_UNEXPECTED_IF_NULL(op_);
  57. // only wait if there are callbacks to call
  58. if (epoch_begin_indices_.size() > 0) {
  59. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  60. }
  61. // Now do the actual callback
  62. for (size_t ind : epoch_begin_indices_) {
  63. RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param));
  64. }
  65. return Status::OK();
  66. }
  67. Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
  68. RETURN_OK_IF_TRUE(!enabled_);
  69. RETURN_UNEXPECTED_IF_NULL(op_);
  70. // Now do the actual callback
  71. for (size_t ind : step_begin_indices_) {
  72. if ((cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
  73. RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
  74. }
  75. return Status::OK();
  76. }
  77. Status CallbackManager::End(const CallbackParam &cb_param) {
  78. RETURN_OK_IF_TRUE(!enabled_);
  79. RETURN_UNEXPECTED_IF_NULL(op_);
  80. // return Status::OK() if no end is needed
  81. RETURN_OK_IF_TRUE(end_indices_.empty());
  82. // Now do the actual callback
  83. for (size_t ind : end_indices_) {
  84. RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param));
  85. }
  86. return Status::OK();
  87. }
  88. Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
  89. RETURN_OK_IF_TRUE(!enabled_);
  90. RETURN_UNEXPECTED_IF_NULL(op_);
  91. // Now do the actual callback
  92. for (size_t ind : epoch_end_indices_) {
  93. RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param));
  94. }
  95. return Status::OK();
  96. }
  97. Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
  98. RETURN_OK_IF_TRUE(!enabled_);
  99. RETURN_UNEXPECTED_IF_NULL(op_);
  100. // Now do the actual callback
  101. for (size_t ind : step_end_indices_) {
  102. if ((cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
  103. RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
  104. }
  105. return Status::OK();
  106. }
  107. } // namespace dataset
  108. } // namespace mindspore