You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

repeat_op.cc 7.7 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iomanip>
  17. #include <iostream>
  18. #include <utility>
  19. #include "dataset/engine/execution_tree.h"
  20. #include "dataset/engine/datasetops/repeat_op.h"
  21. #include "dataset/engine/data_buffer.h"
  22. #include "dataset/engine/db_connector.h"
  23. #include "dataset/engine/opt/pass.h"
  24. #include "utils/log_adapter.h"
  25. namespace mindspore {
  26. namespace dataset {
  27. // Builder constructor. Creates the builder object.
  28. RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {}
  29. Status RepeatOp::Builder::SanityCheck() const {
  30. if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) {
  31. std::string err_msg("Repeat count must be > 0 or -1.");
  32. RETURN_STATUS_UNEXPECTED(err_msg);
  33. }
  34. return Status::OK();
  35. }
  36. // The builder "build" method creates the final object.
  37. Status RepeatOp::Builder::Build(std::shared_ptr<RepeatOp> *ptr) {
  38. RETURN_IF_NOT_OK(SanityCheck());
  39. *ptr = std::make_shared<RepeatOp>(build_max_repeats_);
  40. return Status::OK();
  41. }
  42. // Constructor of the RepeatOp.
  43. RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {}
  44. // Destructor
  45. RepeatOp::~RepeatOp() {}
  46. // A print method typically used for debugging
  47. void RepeatOp::Print(std::ostream &out, bool show_all) const {
  48. // Always show the id and name as first line regardless if this summary or detailed print
  49. out << "(" << std::setw(2) << operator_id_ << ") <RepeatOp>:";
  50. if (!show_all) {
  51. // Call the super class for displaying any common 1-liner info
  52. PipelineOp::Print(out, show_all);
  53. // Then show any custom derived-internal 1-liner info for this op
  54. out << " [repeats: " << max_repeats_ << "]\n";
  55. } else {
  56. // Call the super class for displaying any common detailed info
  57. PipelineOp::Print(out, show_all);
  58. // Then show any custom derived-internal stuff
  59. out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_
  60. << "\nLeaf Nodes in execution path:";
  61. if (!eoe_ops_.empty()) {
  62. for (size_t i = 0; i < eoe_ops_.size(); i++) {
  63. out << "\n Operator: " << eoe_ops_[i]->id();
  64. }
  65. } else {
  66. out << " None.";
  67. }
  68. out << "\n\n";
  69. }
  70. }
  71. // Base-class override for executing specific RepeatOp configurations. This code will be called
  72. // during the execution tree prepare phase when it is visiting this operator.
  73. Status RepeatOp::PrepareNodePostAction() {
  74. // Run any common code from super class first before adding our own specific logic
  75. RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
  76. std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromEOEOpStack();
  77. while (leaf_op != nullptr) {
  78. // Track the leaf operators that are under this repeat op.
  79. eoe_ops_.push_back(leaf_op);
  80. leaf_op = tree_->PopFromEOEOpStack();
  81. }
  82. // Push ourselves to the stack in case one of our ascendants is repeat too.
  83. tree_->AddToEOEOpStack(shared_from_this());
  84. return Status::OK();
  85. }
  86. // Base-class override for setting specific RepeatOp configurations. This code will be called
  87. // during the execution tree prepare phase BEFORE traversing down to child operators.
  88. uint32_t RepeatOp::PrepareFlags() const { return ExecutionTree::kDePrepRepeat; }
  89. // This function returns the buffer that is at the top of our output connector. The caller is
  90. // typically our parent node, when the parent is asking us to provide the next buffer of data.
  91. // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
  92. // a buffer from our child.
  93. // This function sets the `retryIfEoe` flag when popping from the child connector. This way,
  94. // this function will retry to pop the connector again and will get the non-EOE buffer if any.
  95. Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  96. if (child_.empty()) {
  97. RETURN_STATUS_UNEXPECTED("RepeatOp can't be the leaf node.");
  98. }
  99. std::unique_ptr<DataBuffer> buf;
  100. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
  101. // Loop until non EOE is received
  102. while (buf->eoe()) {
  103. RETURN_IF_NOT_OK(EoeReceived(worker_id));
  104. if (state_ == OpState::kDeOpIdle) {
  105. *p_buffer = std::move(buf);
  106. return Status::OK();
  107. }
  108. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
  109. }
  110. // Check if the last buf is next eof
  111. if (buf->eof()) {
  112. RETURN_IF_NOT_OK(EofReceived(worker_id));
  113. }
  114. *p_buffer = std::move(buf);
  115. return Status::OK();
  116. }
  117. // Base-class override for handling cases when an eoe is received.
  118. Status RepeatOp::EoeReceived(int32_t worker_id) {
  119. repeat_count_++;
  120. MS_LOG(DEBUG) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
  121. bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
  122. bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
  123. // If we've reached the requested repeat count, then flag the eoe nodes
  124. // to tell them they've got one more epoch to perform. When they reach the end
  125. // of the last epoch, they quit rather than loop again. This happens in two cases:
  126. // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
  127. // 2- We are not repeated
  128. if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
  129. for (auto &eoe_op : eoe_ops_) {
  130. eoe_op->set_control_flag(kDeOpLastRepeat);
  131. }
  132. }
  133. if (repeat_count_ == max_repeats_) {
  134. repeat_count_ = 0;
  135. state_ = OpState::kDeOpIdle;
  136. return Status::OK();
  137. }
  138. // base-class ResetSubtree
  139. return (DatasetOp::ResetSubtree());
  140. }
  141. // Class functor operator () override.
  142. // Most dataset ops operate by launching a thread (see ExecutionTree).
  143. // However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the
  144. // functor since this op runs inlined inside another operator. The function is overloaded to
  145. // ensure that it is not called by mistake (it will generate an error).
  146. Status RepeatOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. RepeatOp is an inlined operator."); }
  147. // Base-class override for handling cases when an eof is received.
  148. Status RepeatOp::EofReceived(int32_t worker_id) {
  149. MS_LOG(DEBUG) << "Repeat operator EOF received, do nothing now.";
  150. return Status::OK();
  151. }
  152. int32_t RepeatOp::num_consumers() const {
  153. if (parent_.empty()) {
  154. MS_LOG(DEBUG) << "Repeat operator, no parent node, assuming it's root and returning 1.";
  155. return 1;
  156. } else if (parent_[0] == nullptr) {
  157. MS_LOG(DEBUG) << "Repeat operator, pointer to the first parent is null. Returning 0.";
  158. return 0;
  159. } else {
  160. return parent_[0]->num_consumers();
  161. }
  162. }
  163. int32_t RepeatOp::num_producers() const {
  164. if (child_.empty() || child_[0] == nullptr) {
  165. MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
  166. return 0;
  167. } else {
  168. return child_[0]->num_producers();
  169. }
  170. }
  171. // Visitor accept method for NodePass
  172. Status RepeatOp::Accept(NodePass *p, bool *modified) {
  173. // Downcast shared pointer then call visitor
  174. return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
  175. }
  176. } // namespace dataset
  177. } // namespace mindspore