You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

callback_manager.cc 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "minddata/dataset/callback/callback_manager.h"

#include <iterator>
#include <memory>
#include <vector>

#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
  20. namespace mindspore {
  21. namespace dataset {
  22. void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
  23. callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
  24. }
  25. Status CallbackManager::Init(DatasetOp *op) {
  26. RETURN_UNEXPECTED_IF_NULL(op);
  27. op_ = op;
  28. // turn the flag on if callback is set
  29. enabled_ = !callbacks_.empty();
  30. // error check for each of the callbacks
  31. for (auto &cb : callbacks_) {
  32. CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0.");
  33. }
  34. return Status::OK();
  35. }
  36. Status CallbackManager::Begin(const CallbackParam &cb_param) {
  37. RETURN_OK_IF_TRUE(!enabled_);
  38. RETURN_UNEXPECTED_IF_NULL(op_);
  39. std::vector<size_t> callback_inds;
  40. // go through all callback functions to see if each function is needed
  41. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  42. if (callbacks_[ind]->IsBeginNeeded()) callback_inds.push_back(ind);
  43. }
  44. // return Status::OK() if no begin is needed
  45. RETURN_OK_IF_TRUE(callback_inds.empty());
  46. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  47. // Now do the actual callback
  48. for (size_t ind : callback_inds) {
  49. RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param));
  50. }
  51. return Status::OK();
  52. }
  53. Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
  54. RETURN_OK_IF_TRUE(!enabled_);
  55. RETURN_UNEXPECTED_IF_NULL(op_);
  56. std::vector<size_t> callback_inds;
  57. // go through all callback functions to see if each function is needed
  58. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  59. if (callbacks_[ind]->IsEpochBeginNeeded()) callback_inds.push_back(ind);
  60. }
  61. // return Status::OK() if no epoch_begin is needed
  62. RETURN_OK_IF_TRUE(callback_inds.empty());
  63. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  64. // Now do the actual callback
  65. for (size_t ind : callback_inds) {
  66. RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param));
  67. }
  68. return Status::OK();
  69. }
  70. Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
  71. RETURN_OK_IF_TRUE(!enabled_);
  72. RETURN_UNEXPECTED_IF_NULL(op_);
  73. std::vector<size_t> callback_inds;
  74. // go through all callback functions to see if each function is needed
  75. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  76. if (callbacks_[ind]->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
  77. callback_inds.push_back(ind);
  78. }
  79. // return Status::OK() if no step_begin is needed
  80. RETURN_OK_IF_TRUE(callback_inds.empty());
  81. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  82. // Now do the actual callback
  83. for (size_t ind : callback_inds) {
  84. RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
  85. }
  86. return Status::OK();
  87. }
  88. Status CallbackManager::End(const CallbackParam &cb_param) {
  89. RETURN_OK_IF_TRUE(!enabled_);
  90. RETURN_UNEXPECTED_IF_NULL(op_);
  91. std::vector<size_t> callback_inds;
  92. // go through all callback functions to see if each function is needed
  93. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  94. if (callbacks_[ind]->IsEndNeeded()) callback_inds.push_back(ind);
  95. }
  96. // return Status::OK() if no end is needed
  97. RETURN_OK_IF_TRUE(callback_inds.empty());
  98. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  99. // Now do the actual callback
  100. for (size_t ind : callback_inds) {
  101. RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param));
  102. }
  103. return Status::OK();
  104. }
  105. Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
  106. RETURN_OK_IF_TRUE(!enabled_);
  107. RETURN_UNEXPECTED_IF_NULL(op_);
  108. std::vector<size_t> callback_inds;
  109. // go through all callback functions to see if each function is needed
  110. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  111. if (callbacks_[ind]->IsEpochEndNeeded()) callback_inds.push_back(ind);
  112. }
  113. // return Status::OK() if no epoch_end is needed
  114. RETURN_OK_IF_TRUE(callback_inds.empty());
  115. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  116. // Now do the actual callback
  117. for (size_t ind : callback_inds) {
  118. RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param));
  119. }
  120. return Status::OK();
  121. }
  122. Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
  123. RETURN_OK_IF_TRUE(!enabled_);
  124. RETURN_UNEXPECTED_IF_NULL(op_);
  125. std::vector<size_t> callback_inds;
  126. // go through all callback functions to see if each function is needed
  127. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  128. if (callbacks_[ind]->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
  129. callback_inds.push_back(ind);
  130. }
  131. // return Status::OK() if no step_end is needed
  132. RETURN_OK_IF_TRUE(callback_inds.empty());
  133. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  134. // Now do the actual callback
  135. for (size_t ind : callback_inds) {
  136. RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
  137. }
  138. return Status::OK();
  139. }
  140. } // namespace dataset
  141. } // namespace mindspore