You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

data_buffer.h 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_DATA_BUFFER_H_
  17. #define DATASET_ENGINE_DATA_BUFFER_H_
  18. #include <iostream>
  19. #include <map>
  20. #include <memory>
  21. #include <string>
  22. #include <unordered_map>
  23. #include <utility>
  24. #include <vector>
  25. #include "dataset/util/allocator.h"
  26. #include "dataset/util/status.h"
  27. #include "dataset/core/constants.h"
  28. #include "dataset/core/tensor.h"
  29. namespace mindspore {
  30. namespace dataset {
  31. // Forward declares
  32. class StorageClient;
  33. // The DataBuffer class is a base class that will represent the data for n values based
  34. // on a unique row id for each row of data.
  35. // There can be different types of DataBuffers to abstract over how the data is stored
  36. // in memory and acquired from storage.
  37. // Each buffer holds a range of consecutive row id's.
  38. class DataBuffer {
  39. public:
  40. // Buffer flags
  41. enum BufferFlags : uint32_t {
  42. kDeBFlagNone = 0,
  43. kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg
  44. kDeBFlagEOE = 1u << 1 // The buffer is an eoe end-of-epoch msg
  45. };
  46. // Name: Constructor #1
  47. // Description: This is the main constructor that is used for making a buffer
  48. DataBuffer(int32_t id, BufferFlags flags);
  49. // Destructor
  50. virtual ~DataBuffer();
  51. // Name: CreateDataBuffer()
  52. // Description: A factory method to create the appropriate type of derived class
  53. // buffer. Returns the base class reference for DataBuffer.
  54. static Status CreateDataBuffer(
  55. int32_t id, // In: The id for the new buffer
  56. std::shared_ptr<StorageClient>, // In: The StorageClient is used to choose the buffer type to create
  57. std::unique_ptr<DataBuffer> *);
  58. // Name: print()
  59. // Description: A function that prints info about the DataBuffer (base class version)
  60. virtual void Print(std::ostream &out, // In: The output stream to print to
  61. bool show_all) const; // In: T/F if it should show everything
  62. // Provide stream operator for displaying it
  63. friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) {
  64. cb.Print(out, false);
  65. return out;
  66. }
  67. // Name: load()
  68. // Description: populates the DataBuffer with data based on it's id
  69. virtual Status Load();
  70. // Convenience getter functions for flag checking
  71. bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); }
  72. bool eoe() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOE)); }
  73. // Simple getter funcs
  74. int32_t id() const { return buffer_id_; }
  75. void set_id(int32_t id) { buffer_id_ = id; }
  76. int32_t NumRows() const { return ((tensor_table_) ? tensor_table_->size() : 0); }
  77. int32_t NumCols() const {
  78. return (tensor_table_ == nullptr || tensor_table_->empty()) ? 0 : tensor_table_->at(0).size();
  79. }
  80. BufferFlags buffer_flags() const { return buffer_flags_; }
  81. // Remove me!! Callers should fetch rows via pop
  82. Status GetTensor(std::shared_ptr<Tensor> *, int32_t row_id, int32_t col_id) const;
  83. // Remove me!! Callers should drain rows via pop.
  84. Status GetRow(int32_t row_id, TensorRow *) const;
  85. // Get a row from the TensorTable
  86. Status PopRow(TensorRow *);
  87. Status SliceOff(int64_t number_of_rows);
  88. // Return a mapping from col names to col id.
  89. std::unordered_map<std::string, int32_t> column_name_map() const { return column_name_map_; }
  90. // Update the column name to index mapping.
  91. void set_column_name_map(const std::unordered_map<std::string, int32_t> &new_col_name_map) {
  92. column_name_map_ = new_col_name_map;
  93. }
  94. // Replacing mTensorTable, the unique_ptr assignment will release the old TensorTable.
  95. void set_tensor_table(std::unique_ptr<TensorQTable> new_table) { tensor_table_ = std::move(new_table); }
  96. void set_flag(BufferFlags in_flag) {
  97. buffer_flags_ = static_cast<BufferFlags>(static_cast<uint32_t>(buffer_flags_) | static_cast<uint32_t>(in_flag));
  98. }
  99. void Shuffle() {} // does nothing right now. possibly remove later
  100. // ***** column_name_map_ manipulation methods *****
  101. // Append Column to mColumnNameMap
  102. Status AppendColumn(const std::string &name, const int32_t &old_id) const { // does nothing right now
  103. return Status::OK();
  104. }
  105. protected:
  106. int32_t buffer_id_; // An id for the buffer.
  107. std::unique_ptr<TensorQTable> tensor_table_; // A table (row major) of Tensors
  108. BufferFlags buffer_flags_; // bit mask for various buffer properties
  109. std::unordered_map<std::string, int32_t> column_name_map_; // A mapping between column index to column name.
  110. };
  111. } // namespace dataset
  112. } // namespace mindspore
  113. #endif // DATASET_ENGINE_DATA_BUFFER_H_