You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_buffer.h 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_BUFFER_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_BUFFER_H_
  18. #include <iostream>
  19. #include <memory>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include "minddata/dataset/util/allocator.h"
  24. #include "minddata/dataset/util/status.h"
  25. #include "minddata/dataset/core/constants.h"
  26. #include "minddata/dataset/core/tensor.h"
  27. #include "minddata/dataset/core/tensor_row.h"
  28. namespace mindspore {
  29. namespace dataset {
  30. /// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between
  31. /// connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format
  32. /// where n TensorRows may consist of m tensors (columns).
  33. class DataBuffer {
  34. public:
  35. // Buffer flags
  36. enum BufferFlags : uint32_t {
  37. kDeBFlagNone = 0,
  38. kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg
  39. kDeBFlagEOE = 1u << 1, // The buffer is an eoe end-of-epoch msg
  40. kDeBFlagWait = 1u << 2, // The buffer is an control signal for workers to suspend operations
  41. kDeBFlagQuit = 1u << 3 // The buffer is a control signal for workers to quit
  42. };
  43. // Name: Constructor #1
  44. // Description: This is the main constructor that is used for making a buffer
  45. DataBuffer(int32_t id, BufferFlags flags);
  46. /// \brief default destructor
  47. ~DataBuffer() = default;
  48. /// \brief A method for debug printing of the buffer
  49. /// \param[inout] out The stream to write to
  50. /// \param[in] show_all A boolean to toggle between details and summary printing
  51. void Print(std::ostream &out, bool show_all) const;
  52. // Provide stream operator for displaying it
  53. friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) {
  54. cb.Print(out, false);
  55. return out;
  56. }
  57. // Convenience getter functions for flag checking
  58. bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); }
  59. bool eoe() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOE)); }
  60. bool wait() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagWait)); }
  61. bool quit() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagQuit)); }
  62. // Simple getter funcs
  63. int32_t id() const { return buffer_id_; }
  64. void set_id(int32_t id) { buffer_id_ = id; }
  65. int32_t NumRows() const { return ((tensor_table_) ? tensor_table_->size() : 0); }
  66. int32_t NumCols() const {
  67. return (tensor_table_ == nullptr || tensor_table_->empty()) ? 0 : tensor_table_->at(0).size();
  68. }
  69. BufferFlags buffer_flags() const { return buffer_flags_; }
  70. // Remove me!! Callers should fetch rows via pop
  71. Status GetTensor(std::shared_ptr<Tensor> *, int32_t row_id, int32_t col_id) const;
  72. // Remove me!! Callers should drain rows via pop.
  73. Status GetRow(int32_t row_id, TensorRow *) const;
  74. // Get a row from the TensorTable
  75. Status PopRow(TensorRow *);
  76. Status SliceOff(int64_t number_of_rows);
  77. // Replacing mTensorTable, the unique_ptr assignment will release the old TensorTable.
  78. void set_tensor_table(std::unique_ptr<TensorQTable> new_table) { tensor_table_ = std::move(new_table); }
  79. void set_flag(BufferFlags in_flag) {
  80. buffer_flags_ = static_cast<BufferFlags>(static_cast<uint32_t>(buffer_flags_) | static_cast<uint32_t>(in_flag));
  81. }
  82. void Shuffle() {} // does nothing right now. possibly remove later
  83. protected:
  84. int32_t buffer_id_; // An id for the buffer.
  85. std::unique_ptr<TensorQTable> tensor_table_; // A table (row major) of Tensors
  86. BufferFlags buffer_flags_; // bit mask for various buffer properties
  87. };
  88. } // namespace dataset
  89. } // namespace mindspore
  90. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_BUFFER_H_