You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

callback_manager.cc 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "minddata/dataset/callback/callback_manager.h"

  #include <iterator>
  #include <memory>
  #include <utility>
  #include <vector>

  #include "minddata/dataset/callback/ds_callback.h"
  #include "minddata/dataset/util/status.h"
  #include "minddata/dataset/engine/datasetops/dataset_op.h"
  20. namespace mindspore {
  21. namespace dataset {
  22. void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
  23. callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
  24. }
  25. Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) {
  26. RETURN_UNEXPECTED_IF_NULL(op);
  27. op_ = op;
  28. // turn the flag on if callback is set
  29. enabled_ = !callbacks_.empty();
  30. // error check for each of the callbacks
  31. for (auto &cb : callbacks_) {
  32. CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0.");
  33. }
  34. return Status::OK();
  35. }
  36. Status CallbackManager::Begin(const CallbackParam &cb_param) {
  37. RETURN_OK_IF_TRUE(!enabled_);
  38. std::vector<size_t> callback_inds;
  39. // go through all callback functions to see if each function is needed
  40. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  41. if (callbacks_[ind]->IsBeginNeeded()) callback_inds.push_back(ind);
  42. }
  43. // return Status::OK() if no begin is needed
  44. RETURN_OK_IF_TRUE(callback_inds.empty());
  45. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  46. // Now do the actual callback
  47. for (size_t ind : callback_inds) {
  48. RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param));
  49. }
  50. return Status::OK();
  51. }
  52. Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
  53. RETURN_OK_IF_TRUE(!enabled_);
  54. std::vector<size_t> callback_inds;
  55. // go through all callback functions to see if each function is needed
  56. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  57. if (callbacks_[ind]->IsEpochBeginNeeded()) callback_inds.push_back(ind);
  58. }
  59. // return Status::OK() if no epoch_begin is needed
  60. RETURN_OK_IF_TRUE(callback_inds.empty());
  61. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  62. // Now do the actual callback
  63. for (size_t ind : callback_inds) {
  64. RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param));
  65. }
  66. return Status::OK();
  67. }
  68. Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
  69. RETURN_OK_IF_TRUE(!enabled_);
  70. std::vector<size_t> callback_inds;
  71. // go through all callback functions to see if each function is needed
  72. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  73. if (callbacks_[ind]->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
  74. callback_inds.push_back(ind);
  75. }
  76. // return Status::OK() if no step_begin is needed
  77. RETURN_OK_IF_TRUE(callback_inds.empty());
  78. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  79. // Now do the actual callback
  80. for (size_t ind : callback_inds) {
  81. RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
  82. }
  83. return Status::OK();
  84. }
  85. Status CallbackManager::End(const CallbackParam &cb_param) {
  86. RETURN_OK_IF_TRUE(!enabled_);
  87. std::vector<size_t> callback_inds;
  88. // go through all callback functions to see if each function is needed
  89. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  90. if (callbacks_[ind]->IsEndNeeded()) callback_inds.push_back(ind);
  91. }
  92. // return Status::OK() if no end is needed
  93. RETURN_OK_IF_TRUE(callback_inds.empty());
  94. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  95. // Now do the actual callback
  96. for (size_t ind : callback_inds) {
  97. RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param));
  98. }
  99. return Status::OK();
  100. }
  101. Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
  102. RETURN_OK_IF_TRUE(!enabled_);
  103. std::vector<size_t> callback_inds;
  104. // go through all callback functions to see if each function is needed
  105. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  106. if (callbacks_[ind]->IsEpochEndNeeded()) callback_inds.push_back(ind);
  107. }
  108. // return Status::OK() if no epoch_end is needed
  109. RETURN_OK_IF_TRUE(callback_inds.empty());
  110. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  111. // Now do the actual callback
  112. for (size_t ind : callback_inds) {
  113. RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param));
  114. }
  115. return Status::OK();
  116. }
  117. Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
  118. RETURN_OK_IF_TRUE(!enabled_);
  119. std::vector<size_t> callback_inds;
  120. // go through all callback functions to see if each function is needed
  121. for (size_t ind = 0; ind < callbacks_.size(); ind++) {
  122. if (callbacks_[ind]->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
  123. callback_inds.push_back(ind);
  124. }
  125. // return Status::OK() if no step_end is needed
  126. RETURN_OK_IF_TRUE(callback_inds.empty());
  127. RETURN_IF_NOT_OK(op_->WaitForWorkers());
  128. // Now do the actual callback
  129. for (size_t ind : callback_inds) {
  130. RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
  131. }
  132. return Status::OK();
  133. }
  134. } // namespace dataset
  135. } // namespace mindspore