You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

shuffle_op.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #if defined(_WIN32) || defined(_WIN64)
  17. #include <stdlib.h>
  18. #endif
  19. #include <securec.h>
  20. #include <algorithm>
  21. #include <chrono>
  22. #include <iomanip>
  23. #include <iostream>
  24. #include <limits>
  25. #include <random>
  26. #include <utility>
  27. #include "dataset/core/config_manager.h"
  28. #include "dataset/engine/datasetops/shuffle_op.h"
  29. #include "dataset/engine/dataset_iterator.h"
  30. #include "dataset/engine/data_buffer.h"
  31. #include "dataset/engine/db_connector.h"
  32. #include "dataset/engine/opt/pass.h"
  33. #include "dataset/util/random.h"
  34. #include "dataset/util/status.h"
  35. #include "utils/log_adapter.h"
  36. namespace mindspore {
  37. namespace dataset {
// Out-of-class definitions for the class-scope constexpr state constants of the
// shuffle buffer state machine (init -> active -> drain). Required when the
// members are odr-used (e.g. bound to a reference) under pre-C++17 rules.
constexpr int32_t ShuffleOp::kShuffleStateInit;
constexpr int32_t ShuffleOp::kShuffleStateActive;
constexpr int32_t ShuffleOp::kShuffleStateDrain;
  41. // Builder constructor. Creates the builder object.
  42. ShuffleOp::Builder::Builder() : build_shuffle_size_(0), build_reshuffle_each_epoch_(true) {
  43. std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  44. build_op_connector_size_ = cfg->op_connector_size();
  45. build_rows_per_buffer_ = cfg->rows_per_buffer();
  46. build_shuffle_seed_ = GetSeed();
  47. }
  48. Status ShuffleOp::Builder::SanityCheck() const {
  49. if (build_shuffle_size_ < 2) {
  50. RETURN_STATUS_UNEXPECTED("Shuffle buffer size must be greater than 1.");
  51. }
  52. return Status::OK();
  53. }
  54. // The builder "build" method creates the final object.
  55. Status ShuffleOp::Builder::Build(std::shared_ptr<ShuffleOp> *ptr) {
  56. RETURN_IF_NOT_OK(SanityCheck());
  57. *ptr = std::make_shared<ShuffleOp>(build_shuffle_size_, build_shuffle_seed_, build_op_connector_size_,
  58. build_reshuffle_each_epoch_, build_rows_per_buffer_);
  59. return Status::OK();
  60. }
// Constructor of the ShuffleOp.
// @param shuffle_size - max number of rows held in the shuffle buffer at once
// @param shuffle_seed - seed for the mt19937_64 rng used to pick rows
// @param op_connector_size - output connector queue depth (passed to PipelineOp)
// @param reset_every_epoch - if false, SelfReset re-seeds the rng so every epoch
//                            uses the identical shuffle order; if true the rng
//                            keeps running across epochs
// @param rows_per_buffer - number of rows packed into each outgoing DataBuffer
ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch,
                     int32_t rows_per_buffer)
    : PipelineOp(op_connector_size),
      shuffle_size_(shuffle_size),
      shuffle_seed_(shuffle_seed),
      reshuffle_each_epoch_(reset_every_epoch),
      rng_(shuffle_seed),
      buffer_counter_(0),
      rows_per_buffer_(rows_per_buffer),
      // The buffer starts empty and grows up to shuffle_size_ during the init fill.
      shuffle_buffer_(std::make_unique<TensorTable>()),
      shuffle_last_row_idx_(0),
      shuffle_buffer_state_(kShuffleStateInit) {}
  74. // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by
  75. // itself rather than waiting for the reset driven from operators above it in the pipeline.
  76. Status ShuffleOp::SelfReset() {
  77. MS_LOG(DEBUG) << "Shuffle operator performing a self-reset.";
  78. // If reshuffle_each_epoch is false, then we always use the same seed for every
  79. // epoch.
  80. // If reshuffle_each_epoch is true, then the first epoch uses the given seed,
  81. // and all subsequent epochs will then keep on using the rng_ without resetting it
  82. if (!reshuffle_each_epoch_) {
  83. rng_ = std::mt19937_64(shuffle_seed_);
  84. }
  85. shuffle_buffer_ = std::make_unique<TensorTable>();
  86. buffer_counter_ = 0;
  87. shuffle_last_row_idx_ = 0;
  88. shuffle_buffer_state_ = kShuffleStateInit;
  89. return Status::OK();
  90. }
  91. // A print method typically used for debugging
  92. void ShuffleOp::Print(std::ostream &out, bool show_all) const {
  93. // Always show the id and name as first line regardless if this summary or detailed print
  94. out << "(" << std::setw(2) << operator_id_ << ") <ShuffleOp>:";
  95. if (!show_all) {
  96. // Call the super class for displaying any common 1-liner info
  97. PipelineOp::Print(out, show_all);
  98. // Then show any custom derived-internal 1-liner info for this op
  99. out << " [shuffle size: " << shuffle_size_ << "]\n";
  100. } else {
  101. // Call the super class for displaying any common detailed info
  102. PipelineOp::Print(out, show_all);
  103. // Then show any custom derived-internal stuff
  104. out << "\nShuffle size: " << shuffle_size_ << "\nRows per buffer: " << rows_per_buffer_
  105. << "\nShuffle buffer state: " << shuffle_buffer_state_ << "\nShuffle seed: " << shuffle_seed_ << "\n\n";
  106. }
  107. }
  108. // Private function to add a new row to the shuffle buffer.
  109. Status ShuffleOp::AddRowToShuffleBuffer(TensorRow new_shuffle_row) {
  110. // If the last slot of our shuffle buffer was not the full size of the shuffle buffer then we are
  111. // filling it during the initial fill codepath and thus growing it's size. In that case, we push
  112. // back the new row to grow our shuffle buffer size by 1.
  113. // If we are already at the full size, then we overwrite the last slot with our row (and the last
  114. // slot better be empty because it should already have been swapped out during the random row
  115. // selection that was done previously!)
  116. if (shuffle_last_row_idx_ < (shuffle_size_ - 1)) {
  117. shuffle_buffer_->push_back(std::move(new_shuffle_row));
  118. shuffle_last_row_idx_ = (shuffle_buffer_->size()) - 1;
  119. } else {
  120. if (!(*shuffle_buffer_)[shuffle_last_row_idx_].empty()) {
  121. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
  122. "Last row of shuffle buffer should not be occupied!");
  123. }
  124. (*shuffle_buffer_)[shuffle_last_row_idx_] = std::move(new_shuffle_row);
  125. }
  126. return Status::OK();
  127. }
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor
// provides the master loop that drives the work: fill the shuffle buffer, then repeatedly
// pick a random row, ship it out, and backfill (or drain) until the buffer is empty,
// once per epoch.
Status ShuffleOp::operator()() {
  std::unique_ptr<TensorQTable> new_buffer_table;  // A tensor table to be used for output.
  // Synchronize with TaskManager once the thread is launched.
  TaskManager::FindMe()->Post();
  // Shuffle op does not have workers, and only consumes from child 0.
  // Create the child iterator to fetch our data from.
  int32_t worker_id = 0;
  int32_t child_idx = 0;
  child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);
  // Main operator loop — one iteration of the outer loop per epoch.
  while (true) {
    // Do an initial populate of the shuffle buffer
    RETURN_IF_NOT_OK(InitShuffleBuffer());
    // This is our main loop exit condition, when the iterator has no more data completely.
    if (child_iterator_->eof_handled()) {
      break;
    }
    // Next, enter into the main execution loop of the shuffle op.
    // When the tail index position of our shuffle buffer goes negative it means that we've
    // fully drained the data from the shuffle buffer and we're done.
    while (shuffle_last_row_idx_ >= 0) {
      // Step 1)
      // Create an output tensor table if one is not created yet.
      if (!new_buffer_table) {
        new_buffer_table = std::make_unique<TensorQTable>();
      }
      // Step 2)
      // Randomly select a slot from our shuffle buffer and copy that row into the output
      // tensor table. We remove the data from the shuffle buffer, leaving that slot
      // in the table as an empty vector. Slots 0..shuffle_last_row_idx_ are the
      // currently-occupied ones, so the modulus keeps the pick in range.
      int64_t random_slot = rng_() % (shuffle_last_row_idx_ + 1);
      new_buffer_table->push_back(std::move((*shuffle_buffer_)[random_slot]));
      // Step 3)
      // If the output tensor table is at the requested size, then create a buffer for it
      // and send this buffer on its way up the pipeline. Special case: if this is the
      // last row (tail index 0) we also send the partial buffer.
      if (new_buffer_table->size() == rows_per_buffer_ || shuffle_last_row_idx_ == 0) {
        auto new_buffer = std::make_unique<DataBuffer>(buffer_counter_, DataBuffer::kDeBFlagNone);
        new_buffer->set_tensor_table(std::move(new_buffer_table));
        buffer_counter_++;
        MS_LOG(DEBUG) << "Shuffle operator sending a buffer to output.";
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(new_buffer)));
      }
      // Step 4)
      // Take the last row from shuffle buffer, and swap it into the row position that was
      // just vacated. This makes the shuffle buffer contiguous, with an empty slot at the
      // tail of the shuffle buffer.
      if (random_slot != shuffle_last_row_idx_) {
        (*shuffle_buffer_)[random_slot] = std::move((*shuffle_buffer_)[shuffle_last_row_idx_]);
      }
      // Step 5)
      // Refill the last slot of the shuffle buffer with the next row from input if we are in the
      // active state.
      // If we are in the draining state, we do not need to fetch another row to replace the one we
      // just drained. An empty fetch here is what flips us from active to drain.
      if (shuffle_buffer_state_ == kShuffleStateActive) {
        TensorRow new_row;
        RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
        if (!new_row.empty()) {
          RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row)));
        } else {
          shuffle_buffer_state_ = kShuffleStateDrain;
        }
      }
      // If we are draining, reposition (decrement) our tail index in the shuffle buffer since we
      // just drained a row from it. Note: this also runs in the same pass that set the drain
      // state above, so no row is left behind.
      if (shuffle_buffer_state_ == kShuffleStateDrain) {
        shuffle_last_row_idx_--;
      }
    }
    // Since we overloaded the EoeReceived function, we are responsible to flow the EOE up the
    // pipeline manually now that we are done draining the shuffle buffer.
    MS_LOG(DEBUG) << "Shuffle operator sending EOE.";
    auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
    // Do not wait for any reset to be flown down from operators above us.
    // Instead, manually update ourselves and then go reloop to start fetching from child operator
    // right away. Any Reset() from the parent will still perform common reset actions.
    RETURN_IF_NOT_OK(this->SelfReset());
  }
  return Status::OK();
}
  213. // Private function populate the shuffle buffer initially by fetching from the child output
  214. // connector until the shuffle buffer is full (or there is no more data coming).
  215. Status ShuffleOp::InitShuffleBuffer() {
  216. MS_LOG(DEBUG) << "Shuffle operator initializing the shuffle buffer.";
  217. // The first phase of this operator is to read incoming buffers and then drain those
  218. // rows from the buffers, putting them into our own local table of tensors (the shuffle
  219. // buffer).
  220. // This shuffle buffer initialization phase stops when we've either filled up the
  221. // shuffle buffer to it's max size, or the dataset below us is not providing any more
  222. // rows.
  223. if (shuffle_buffer_state_ != kShuffleStateInit) {
  224. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
  225. "Invalid shuffle buffer state (SHUFFLE_STATE_INIT expected)");
  226. }
  227. // Before we drop into the fetching loop, call the fetch once for the first time
  228. // to fill the first row and grab the first buffer.
  229. TensorRow new_row;
  230. RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  231. if (child_iterator_->eof_handled()) {
  232. MS_LOG(DEBUG) << "Shuffle operator init picked up EOF. No more epochs.";
  233. return Status::OK();
  234. }
  235. if (new_row.empty()) {
  236. RETURN_STATUS_UNEXPECTED("Unable to fetch a single row for shuffle buffer.");
  237. }
  238. // Now fill the rest of the shuffle buffer until we are unable to get the next row or we reached
  239. // the desired shuffle buffer size.
  240. while (!new_row.empty() && shuffle_buffer_->size() < static_cast<size_t>(shuffle_size_ - 1)) {
  241. // Add the previously fetched row
  242. RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row)));
  243. // Fetch the next row
  244. RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  245. }
  246. // If we quit the loop due to being at the shuffle size, still need to add the last row here.
  247. if (!new_row.empty()) {
  248. RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row)));
  249. shuffle_buffer_state_ = kShuffleStateActive; // Transition to the active state
  250. } else {
  251. // If init phase doesn't have more rows, then skip the active state and jump straight to the
  252. // shuffle buffer draining state
  253. shuffle_buffer_state_ = kShuffleStateDrain;
  254. }
  255. MS_LOG(DEBUG) << "Shuffle operator finished intializing the shuffle buffer.";
  256. return Status::OK();
  257. }
// EOE handler override. Only marks this op idle; it deliberately does not propagate
// the EOE upward, because operator() sends its own EOE buffer after the shuffle
// buffer has been fully drained (see the "sending EOE" step in the functor).
Status ShuffleOp::EoeReceived(int32_t worker_id) {
  state_ = OpState::kDeOpIdle;
  return Status::OK();
}
// Visitor accept method for NodePass (tree-optimization visitor pattern).
// @param p - the pass visiting this node
// @param modified - out flag the pass sets if it changed the tree
Status ShuffleOp::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer to the concrete ShuffleOp type, then call the visitor.
  return p->RunOnNode(shared_from_base<ShuffleOp>(), modified);
}
  267. } // namespace dataset
  268. } // namespace mindspore