You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

take_op.cc 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <utility>
  17. #include "common/utils.h"
  18. #include "dataset/engine/data_buffer.h"
  19. #include "dataset/engine/datasetops/take_op.h"
  20. #include "dataset/engine/db_connector.h"
  21. #include "dataset/engine/execution_tree.h"
  22. namespace mindspore {
  23. namespace dataset {
  24. // Builder constructor. Creates the builder object.
  25. TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {}
  26. Status TakeOp::Builder::SanityCheck() const {
  27. if (build_max_takes_ <= 0) {
  28. std::string err_msg("Take count must be greater than 0.");
  29. RETURN_STATUS_UNEXPECTED(err_msg);
  30. }
  31. return Status::OK();
  32. }
  33. // The builder "build" method creates the final object.
  34. Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
  35. RETURN_IF_NOT_OK(SanityCheck());
  36. *ptr = std::make_shared<TakeOp>(build_max_takes_);
  37. return Status::OK();
  38. }
  39. // Constructor of the TakeOp.
  40. TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}
  41. // A print method typically used for debugging
  42. void TakeOp::Print(std::ostream &out, bool show_all) const {
  43. // Call base class printer first
  44. PipelineOp::Print(out, show_all);
  45. // Then display our own stuff
  46. out << "TakeOp:"
  47. << "\nCurrent take count: " << take_count_ << "\nMax take count: " << max_takes_;
  48. }
  49. // This function will be call muti times to returns the buffer, when meet required max take count or meet
  50. // EOF buffer then this will stop.
  51. Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  52. if (child_.empty()) {
  53. RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node.");
  54. }
  55. std::unique_ptr<DataBuffer> buf;
  56. bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat);
  57. if (take_count_ == max_takes_) {
  58. if (state_ == OpState::kDeOpRunning) {
  59. MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer.";
  60. auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  61. *p_buffer = std::move(eoe_buffer);
  62. state_ = OpState::kDeOpIdle;
  63. // Reset the count and drain
  64. if (!last_repeat) {
  65. take_count_ = 0;
  66. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
  67. while (!buf->eoe() && !buf->eof()) {
  68. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
  69. }
  70. }
  71. } else if (state_ == OpState::kDeOpIdle) {
  72. MS_LOG(DEBUG) << "Meet max count and push-back eof buffer.";
  73. auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  74. *p_buffer = std::move(eof_buffer);
  75. take_count_ = 0;
  76. } else {
  77. MS_LOG(WARNING) << "Invalid OpState: " << state_;
  78. }
  79. return Status::OK();
  80. }
  81. RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
  82. // Loop until non EOE is received
  83. if (buf->eoe()) {
  84. take_count_ = 0;
  85. *p_buffer = std::move(buf);
  86. return Status::OK();
  87. }
  88. // Check if the last buf is next eof
  89. if (buf->eof()) {
  90. *p_buffer = std::move(buf);
  91. return Status::OK();
  92. }
  93. // Get buffer and push back when take_count is still small
  94. if (take_count_ < max_takes_) {
  95. RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer));
  96. }
  97. return Status::OK();
  98. }
  99. // Function FillBuffer mainly prepare the buffer for returning
  100. Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer) {
  101. int32_t buffer_size = (*buffer)->NumRows();
  102. if (take_count_ + buffer_size < max_takes_) {
  103. *data_buffer = std::move(*buffer);
  104. take_count_ = take_count_ + buffer_size;
  105. } else {
  106. MS_LOG(DEBUG) << "In last buffer: Push one buffer.";
  107. std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
  108. while (take_count_ < max_takes_) {
  109. TensorRow new_row;
  110. RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row));
  111. take_count_++;
  112. new_tensor_table->push_back(new_row);
  113. }
  114. (*buffer)->set_tensor_table(std::move(new_tensor_table));
  115. *data_buffer = std::move(*buffer);
  116. }
  117. return Status::OK();
  118. }
  119. // Class functor operator () override.
  120. // Most dataset ops operate by launching a thread (see ExecutionTree).
  121. // However, the TakeOp is defined as a inlined operator, so it is invalid to launch the
  122. // functor since this op runs inlined inside another operator. The function is overloaded to
  123. // ensure that it is not called by mistake (it will generate an error).
  124. Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }
  125. Status TakeOp::PrepareNodePostAction() {
  126. RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
  127. tree_->AddToRepeatStack(shared_from_this());
  128. return Status::OK();
  129. }
  130. } // namespace dataset
  131. } // namespace mindspore