You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

repeat_op.cc 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <utility>
  18. #include "dataset/engine/execution_tree.h"
  19. #include "dataset/engine/datasetops/repeat_op.h"
  20. #include "dataset/engine/data_buffer.h"
  21. #include "dataset/engine/db_connector.h"
  22. #include "utils/log_adapter.h"
  23. namespace mindspore {
  24. namespace dataset {
  25. // Builder constructor. Creates the builder object.
  26. RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {}
  27. Status RepeatOp::Builder::SanityCheck() const {
  28. if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) {
  29. std::string err_msg("Repeat count must be > 0 or -1.");
  30. RETURN_STATUS_UNEXPECTED(err_msg);
  31. }
  32. return Status::OK();
  33. }
  34. // The builder "build" method creates the final object.
  35. Status RepeatOp::Builder::Build(std::shared_ptr<RepeatOp> *ptr) {
  36. RETURN_IF_NOT_OK(SanityCheck());
  37. *ptr = std::make_shared<RepeatOp>(build_max_repeats_);
  38. return Status::OK();
  39. }
  40. // Constructor of the RepeatOp.
  41. RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {}
  42. // Destructor
  43. RepeatOp::~RepeatOp() {}
  44. // A print method typically used for debugging
  45. void RepeatOp::Print(std::ostream &out, bool show_all) const {
  46. // Call base class printer first
  47. PipelineOp::Print(out, show_all);
  48. // Then display our own stuff
  49. out << "RepeatOp:"
  50. << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_
  51. << "\nLeaf Nodes in my execution path:";
  52. if (!eoe_ops_.empty()) {
  53. out << "\n";
  54. for (size_t i = 0; i < eoe_ops_.size(); i++) {
  55. out << " Operator: " << eoe_ops_[i]->id() << "\n";
  56. }
  57. } else {
  58. out << " kNone.";
  59. }
  60. out << "\n-------------------------\n\n"; // End the display with this line
  61. }
  62. // Base-class override for executing specific RepeatOp configurations. This code will be called
  63. // during the execution tree prepare phase when it is visiting this operator.
  64. Status RepeatOp::PrepareNodePostAction() {
  65. // Run any common code from super class first before adding our own specific logic
  66. RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
  67. std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack();
  68. while (leaf_op != nullptr) {
  69. // Track the leaf operators that are under this repeat op.
  70. eoe_ops_.push_back(leaf_op);
  71. leaf_op = tree_->PopFromRepeatStack();
  72. }
  73. // Push ourselves to the stack in case one of our ascendants is repeat too.
  74. tree_->AddToRepeatStack(shared_from_this());
  75. return Status::OK();
  76. }
  77. // Base-class override for setting specific RepeatOp configurations. This code will be called
  78. // during the execution tree prepare phase BEFORE traversing down to child operators.
  79. uint32_t RepeatOp::PrepareFlags() const { return ExecutionTree::kDePrepRepeat; }
  80. // This function returns the buffer that is at the top of our output connector. The caller is
  81. // typically our parent node, when the parent is asking us to provide the next buffer of data.
  82. // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
  83. // a buffer from our child.
  84. // This function sets the `retryIfEoe` flag when popping from the child connector. This way,
  85. // this function will retry to pop the connector again and will get the non-EOE buffer if any.
  86. Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  87. if (child_.empty()) {
  88. RETURN_STATUS_UNEXPECTED("RepeatOp can't be the leaf node.");
  89. }
  90. std::unique_ptr<DataBuffer> buf;
  91. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
  92. // Loop until non EOE is received
  93. while (buf->eoe()) {
  94. RETURN_IF_NOT_OK(EoeReceived(worker_id));
  95. if (state_ == OpState::kDeOpIdle) {
  96. *p_buffer = std::move(buf);
  97. return Status::OK();
  98. }
  99. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
  100. }
  101. // Check if the last buf is next eof
  102. if (buf->eof()) {
  103. RETURN_IF_NOT_OK(EofReceived(worker_id));
  104. }
  105. *p_buffer = std::move(buf);
  106. return Status::OK();
  107. }
  108. // Base-class override for handling cases when an eoe is received.
  109. Status RepeatOp::EoeReceived(int32_t worker_id) {
  110. repeat_count_++;
  111. MS_LOG(INFO) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
  112. bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
  113. bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
  114. // If we've reached the requested repeat count, then flag the eoe nodes
  115. // to tell them they've got one more epoch to perform. When they reach the end
  116. // of the last epoch, they quit rather than loop again. This happens in two cases:
  117. // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
  118. // 2- We are not repeated
  119. if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
  120. for (auto &eoe_op : eoe_ops_) {
  121. eoe_op->set_control_flag(kDeOpLastRepeat);
  122. }
  123. }
  124. if (repeat_count_ == max_repeats_) {
  125. repeat_count_ = 0;
  126. state_ = OpState::kDeOpIdle;
  127. return Status::OK();
  128. }
  129. // base-class ResetSubtree
  130. return (DatasetOp::ResetSubtree());
  131. }
  132. // Class functor operator () override.
  133. // Most dataset ops operate by launching a thread (see ExecutionTree).
  134. // However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the
  135. // functor since this op runs inlined inside another operator. The function is overloaded to
  136. // ensure that it is not called by mistake (it will generate an error).
  137. Status RepeatOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. RepeatOp is an inlined operator."); }
  138. // Base-class override for handling cases when an eof is received.
  139. Status RepeatOp::EofReceived(int32_t worker_id) {
  140. MS_LOG(INFO) << "Repeat operator EOF received, do nothing now.";
  141. return Status::OK();
  142. }
  143. int32_t RepeatOp::num_consumers() const {
  144. if (parent_.empty()) {
  145. MS_LOG(INFO) << "Repeat operator, no parent node, assuming it's root and returning 1.";
  146. return 1;
  147. } else if (parent_[0] == nullptr) {
  148. MS_LOG(INFO) << "Repeat operator, pointer to the first parent is null. Returning 0.";
  149. return 0;
  150. } else {
  151. return parent_[0]->num_consumers();
  152. }
  153. }
  154. int32_t RepeatOp::num_producers() const {
  155. if (child_.empty() || child_[0] == nullptr) {
  156. MS_LOG(INFO) << "Repeat operator, pointer to child node is null. Returning 0.";
  157. return 0;
  158. } else {
  159. return child_[0]->num_producers();
  160. }
  161. }
  162. } // namespace dataset
  163. } // namespace mindspore