You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

shuffle_op.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #if defined(_WIN32) || defined(_WIN64)
  17. #include <stdlib.h>
  18. #endif
  19. #include <securec.h>
  20. #include <algorithm>
  21. #include <chrono>
  22. #include <iostream>
  23. #include <limits>
  24. #include <random>
  25. #include <utility>
  26. #include "dataset/core/config_manager.h"
  27. #include "dataset/engine/datasetops/shuffle_op.h"
  28. #include "dataset/engine/dataset_iterator.h"
  29. #include "dataset/engine/data_buffer.h"
  30. #include "dataset/engine/db_connector.h"
  31. #include "dataset/util/random.h"
  32. #include "dataset/util/status.h"
  33. #include "utils/log_adapter.h"
  34. namespace mindspore {
  35. namespace dataset {
// Out-of-class definitions of the three shuffle buffer state constants used by this operator
// (init, active and drain).
// NOTE(review): presumably these are declared as static constexpr members in shuffle_op.h and are
// defined here so that they can be odr-used by pre-C++17 compilers - confirm against the header.
constexpr int32_t ShuffleOp::kShuffleStateInit;
constexpr int32_t ShuffleOp::kShuffleStateActive;
constexpr int32_t ShuffleOp::kShuffleStateDrain;
  39. // Builder constructor. Creates the builder object.
  40. ShuffleOp::Builder::Builder() : build_shuffle_size_(0), build_reshuffle_each_epoch_(true) {
  41. std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  42. build_op_connector_size_ = cfg->op_connector_size();
  43. build_rows_per_buffer_ = cfg->rows_per_buffer();
  44. build_shuffle_seed_ = GetSeed();
  45. }
  46. Status ShuffleOp::Builder::SanityCheck() const {
  47. if (build_shuffle_size_ < 2) {
  48. RETURN_STATUS_UNEXPECTED("Shuffle buffer size must be greater than 1.");
  49. }
  50. return Status::OK();
  51. }
  52. // The builder "build" method creates the final object.
  53. Status ShuffleOp::Builder::Build(std::shared_ptr<ShuffleOp> *ptr) {
  54. RETURN_IF_NOT_OK(SanityCheck());
  55. *ptr = std::make_shared<ShuffleOp>(build_shuffle_size_, build_shuffle_seed_, build_op_connector_size_,
  56. build_reshuffle_each_epoch_, build_rows_per_buffer_);
  57. return Status::OK();
  58. }
  59. // Constructor of the ShuffleOp
  60. ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch,
  61. int32_t rows_per_buffer)
  62. : PipelineOp(op_connector_size),
  63. shuffle_size_(shuffle_size),
  64. shuffle_seed_(shuffle_seed),
  65. reshuffle_each_epoch_(reset_every_epoch),
  66. rng_(shuffle_seed),
  67. buffer_counter_(0),
  68. rows_per_buffer_(rows_per_buffer),
  69. shuffle_buffer_(std::make_unique<TensorTable>()),
  70. shuffle_last_row_idx_(0),
  71. shuffle_buffer_state_(kShuffleStateInit) {}
  72. // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by
  73. // itself rather than waiting for the reset driven from operators above it in the pipeline.
  74. Status ShuffleOp::SelfReset() {
  75. MS_LOG(INFO) << "Shuffle operator performing a self-reset.";
  76. // If ReshuffleEachEpoch is false, then we always use the same seed for every
  77. // epoch.
  78. // If ReshuffleEachEpoch is true, then the first epoch uses the given seed,
  79. // and all subsequent epochs will then reset the seed based on random device.
  80. if (!reshuffle_each_epoch_) {
  81. rng_ = std::mt19937_64(shuffle_seed_);
  82. } else {
  83. #if defined(_WIN32) || defined(_WIN64)
  84. unsigned int number;
  85. rand_s(&number);
  86. std::mt19937 random_device{static_cast<uint32_t>(number)};
  87. #else
  88. std::random_device random_device("/dev/urandom");
  89. #endif
  90. std::uniform_int_distribution<int32_t> distribution(0, std::numeric_limits<int32_t>::max());
  91. shuffle_seed_ = distribution(random_device);
  92. rng_ = std::mt19937_64(shuffle_seed_);
  93. }
  94. shuffle_buffer_ = std::make_unique<TensorTable>();
  95. buffer_counter_ = 0;
  96. shuffle_last_row_idx_ = 0;
  97. shuffle_buffer_state_ = kShuffleStateInit;
  98. return Status::OK();
  99. }
  100. // A print method typically used for debugging
  101. void ShuffleOp::Print(std::ostream &out, bool show_all) const {
  102. // Call base class printer first
  103. PipelineOp::Print(out, show_all);
  104. // Then display our own stuff
  105. out << "ShuffleOp:\n Shuffle size: " << shuffle_size_ << "\n rows_per_buffer_: " << rows_per_buffer_
  106. << "\n shuffle_buffer_state_: " << shuffle_buffer_state_ << "\n shuffle_seed_: " << shuffle_seed_;
  107. out << "\n-------------------------\n\n"; // End the display with this line
  108. }
  109. // Private function to add a new row to the shuffle buffer.
  110. Status ShuffleOp::AddRowToShuffleBuffer(TensorRow new_shuffle_row) {
  111. // If the last slot of our shuffle buffer was not the full size of the shuffle buffer then we are
  112. // filling it during the initial fill codepath and thus growing it's size. In that case, we push
  113. // back the new row to grow our shuffle buffer size by 1.
  114. // If we are already at the full size, then we overwrite the last slot with our row (and the last
  115. // slot better be empty because it should already have been swapped out during the random row
  116. // selection that was done previously!)
  117. if (shuffle_last_row_idx_ < (shuffle_size_ - 1)) {
  118. shuffle_buffer_->push_back(std::move(new_shuffle_row));
  119. shuffle_last_row_idx_ = (shuffle_buffer_->size()) - 1;
  120. } else {
  121. if (!(*shuffle_buffer_)[shuffle_last_row_idx_].empty()) {
  122. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
  123. "Last row of shuffle buffer should not be occupied!");
  124. }
  125. (*shuffle_buffer_)[shuffle_last_row_idx_] = std::move(new_shuffle_row);
  126. }
  127. return Status::OK();
  128. }
  129. // Class functor operator () override.
  130. // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  131. // provide the master loop that drives the logic for performing the work
  132. Status ShuffleOp::operator()() {
  133. std::unique_ptr<TensorQTable> new_buffer_table; // A tensor table to be used for output.
  134. // Synchronize with TaskManager once the thread is launched.
  135. TaskManager::FindMe()->Post();
  136. // Shuffle op does not have workers, and only consumes from child 0.
  137. // Create the child iterator to fetch our data from.
  138. int32_t worker_id = 0;
  139. int32_t child_idx = 0;
  140. child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);
  141. // Main operator loop
  142. while (true) {
  143. // Do an initial populate of the shuffle buffer
  144. RETURN_IF_NOT_OK(InitShuffleBuffer());
  145. // This is our main loop exit condition, when the iterator has no more data completely.
  146. if (child_iterator_->eof_handled()) {
  147. break;
  148. }
  149. // Next, enter into the main execution loop of the shuffle op.
  150. // When the tail index position of our shuffle buffer goes negative it means that we've
  151. // fully drained the data from the shuffle buffer and we're done.
  152. while (shuffle_last_row_idx_ >= 0) {
  153. // Step 1)
  154. // Create an output tensor table if one is not created yet.
  155. if (!new_buffer_table) {
  156. new_buffer_table = std::make_unique<TensorQTable>();
  157. }
  158. // Step 2)
  159. // Randomly select a slot from our shuffle buffer and copy that row into the output
  160. // tensor table. We remove the data from the shuffle buffer, leaving that slot
  161. // in the table as an empty vector
  162. int64_t random_slot = rng_() % (shuffle_last_row_idx_ + 1);
  163. new_buffer_table->push_back(std::move((*shuffle_buffer_)[random_slot]));
  164. // Step 3)
  165. // If the output tensor table is at the requested size, then create a buffer for it
  166. // and send this buffer on it's way up the pipeline. Special case is if this is the
  167. // last row then we also send it.
  168. if (new_buffer_table->size() == rows_per_buffer_ || shuffle_last_row_idx_ == 0) {
  169. auto new_buffer = std::make_unique<DataBuffer>(buffer_counter_, DataBuffer::kDeBFlagNone);
  170. new_buffer->set_tensor_table(std::move(new_buffer_table));
  171. new_buffer->set_column_name_map(column_name_map_);
  172. buffer_counter_++;
  173. MS_LOG(DEBUG) << "Shuffle operator sending a buffer to output.";
  174. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(new_buffer)));
  175. }
  176. // Step 4)
  177. // Take the last row from shuffle buffer, and swap it into the row position that was
  178. // just vacated. This makes the shuffle buffer contiguous, with an empty slot at the
  179. // tail of the shuffle buffer.
  180. if (random_slot != shuffle_last_row_idx_) {
  181. (*shuffle_buffer_)[random_slot] = std::move((*shuffle_buffer_)[shuffle_last_row_idx_]);
  182. }
  183. // Step 5)
  184. // Refill the last slot of the shuffle buffer with the next row from input if we are in the
  185. // active state.
  186. // If we are in the draining state, we do not need to fetch another row to replace the one we
  187. // just drained.
  188. if (shuffle_buffer_state_ == kShuffleStateActive) {
  189. TensorRow new_row;
  190. RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  191. if (!new_row.empty()) {
  192. RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row)));
  193. } else {
  194. shuffle_buffer_state_ = kShuffleStateDrain;
  195. }
  196. }
  197. // If we are draining, reposition (decrement) our tail index in the shuffle buffer since we
  198. // just drained a row from it.
  199. if (shuffle_buffer_state_ == kShuffleStateDrain) {
  200. shuffle_last_row_idx_--;
  201. }
  202. }
  203. // Since we overloaded eoeReceived function, we are responsible to flow the EOE up the
  204. // pipepline manually now that we are done draining the shuffle buffer
  205. MS_LOG(INFO) << "Shuffle operator sending EOE.";
  206. auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  207. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
  208. // Do not wait for any reset to be flown down from operators above us.
  209. // Instead, manually update ourselves and then go reloop to start fetching from child operator
  210. // right away. Any Reset() from the parent will still perform common reset actions.
  211. RETURN_IF_NOT_OK(this->SelfReset());
  212. }
  213. return Status::OK();
  214. }
  215. // Private function populate the shuffle buffer initially by fetching from the child output
  216. // connector until the shuffle buffer is full (or there is no more data coming).
  217. Status ShuffleOp::InitShuffleBuffer() {
  218. MS_LOG(INFO) << "Shuffle operator initializing the shuffle buffer.";
  219. // The first phase of this operator is to read incoming buffers and then drain those
  220. // rows from the buffers, putting them into our own local table of tensors (the shuffle
  221. // buffer).
  222. // This shuffle buffer initialization phase stops when we've either filled up the
  223. // shuffle buffer to it's max size, or the dataset below us is not providing any more
  224. // rows.
  225. if (shuffle_buffer_state_ != kShuffleStateInit) {
  226. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
  227. "Invalid shuffle buffer state (SHUFFLE_STATE_INIT expected)");
  228. }
  229. // Before we drop into the fetching loop, call the fetch once for the first time
  230. // to fill the first row and grab the first buffer.
  231. TensorRow new_row;
  232. RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  233. if (child_iterator_->eof_handled()) {
  234. MS_LOG(INFO) << "Shuffle operator init picked up EOF. No more epochs.";
  235. return Status::OK();
  236. }
  237. if (new_row.empty()) {
  238. RETURN_STATUS_UNEXPECTED("Unable to fetch a single row for shuffle buffer.");
  239. }
  240. // Take a copy of the column name mapping. We'll use this when constructing output buffers later.
  241. column_name_map_ = child_iterator_->col_name_id_map();
  242. // Now fill the rest of the shuffle buffer until we are unable to get the next row or we reached
  243. // the desired shuffle buffer size.
  244. while (!new_row.empty() && shuffle_buffer_->size() < static_cast<size_t>(shuffle_size_ - 1)) {
  245. // Add the previously fetched row
  246. RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row)));
  247. // Fetch the next row
  248. RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  249. }
  250. // If we quit the loop due to being at the shuffle size, still need to add the last row here.
  251. if (!new_row.empty()) {
  252. RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row)));
  253. shuffle_buffer_state_ = kShuffleStateActive; // Transition to the active state
  254. } else {
  255. // If init phase doesn't have more rows, then skip the active state and jump straight to the
  256. // shuffle buffer draining state
  257. shuffle_buffer_state_ = kShuffleStateDrain;
  258. }
  259. MS_LOG(INFO) << "Shuffle operator finished intializing the shuffle buffer.";
  260. return Status::OK();
  261. }
  262. Status ShuffleOp::EoeReceived(int32_t worker_id) {
  263. state_ = OpState::kDeOpIdle;
  264. return Status::OK();
  265. }
  266. } // namespace dataset
  267. } // namespace mindspore