From: @hfarahat Reviewed-by: @pandoublefeng, @robingrosman Signed-off-by: @pandoublefeng pull/14639/MERGE
| @@ -13,7 +13,6 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" | |||||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | ||||
| set(SRC_FILES_LIST | set(SRC_FILES_LIST | ||||
| execution_tree.cc | execution_tree.cc | ||||
| data_buffer.cc | |||||
| data_schema.cc | data_schema.cc | ||||
| dataset_iterator.cc | dataset_iterator.cc | ||||
| tree_adapter.cc | tree_adapter.cc | ||||
| @@ -34,7 +34,7 @@ | |||||
| #else | #else | ||||
| #include "minddata/dataset/engine/cache/stub/cache_grpc_client.h" | #include "minddata/dataset/engine/cache/stub/cache_grpc_client.h" | ||||
| #endif | #endif | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/util/lock.h" | #include "minddata/dataset/util/lock.h" | ||||
| #include "minddata/dataset/util/cond_var.h" | #include "minddata/dataset/util/cond_var.h" | ||||
| #include "minddata/dataset/util/queue_map.h" | #include "minddata/dataset/util/queue_map.h" | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include <iomanip> | #include <iomanip> | ||||
| #include <sstream> | #include <sstream> | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| #include "minddata/dataset/util/services.h" | #include "minddata/dataset/util/services.h" | ||||
| @@ -115,7 +115,6 @@ Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, | |||||
| Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); } | Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); } | ||||
| Status ToDevice::Send() { | Status ToDevice::Send() { | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| RETURN_IF_NOT_OK(tree_adapter_->Launch()); | RETURN_IF_NOT_OK(tree_adapter_->Launch()); | ||||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | ||||
| @@ -1,89 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/util/allocator.h" | |||||
| #include "minddata/dataset/core/global_context.h" | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // Name: Constructor #1 | |||||
| // Description: This is the main constructor that is used for making a buffer | |||||
| DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {} | |||||
| // A method for debug printing of the buffer | |||||
| void DataBuffer::Print(std::ostream &out, bool show_all) const { | |||||
| out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n"; | |||||
| // If the column counts are set then it means that data has been set into | |||||
| // the tensor table. Display the tensor table here. | |||||
| if (this->NumCols() > 0) { | |||||
| out << "Tensor table:\n"; | |||||
| for (int32_t row = 0; row < DataBuffer::NumRows(); ++row) { | |||||
| out << "Row # : " << row << "\n"; | |||||
| TensorRow currRow = (*tensor_table_)[row]; | |||||
| for (int32_t col = 0; col < this->NumCols(); ++col) { | |||||
| out << "Column #: " << col << "\n"; // Should add the column name here as well? | |||||
| // Call the tensor display | |||||
| out << *(currRow[col]) << "\n"; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // Remove me!! Callers should fetch rows via pop | |||||
| Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32_t col_id) const { | |||||
| if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) { | |||||
| *ptr = (tensor_table_->at(row_id)).at(col_id); | |||||
| } else { | |||||
| std::string err_msg = | |||||
| "indices for mTensorTable out of range: (" + std::to_string(row_id) + "," + std::to_string(col_id) + ")."; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Remove me!! Callers should fetch rows via pop | |||||
| Status DataBuffer::GetRow(int32_t row_id, TensorRow *ptr) const { | |||||
| if (tensor_table_ && !tensor_table_->empty() && row_id < tensor_table_->size()) { | |||||
| *ptr = tensor_table_->at(row_id); | |||||
| } else { | |||||
| std::string err_msg = "rowId for mTensorTable out of range: " + std::to_string(row_id); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status DataBuffer::PopRow(TensorRow *ptr) { | |||||
| if (tensor_table_ && !tensor_table_->empty()) { | |||||
| *ptr = std::move(tensor_table_->front()); | |||||
| tensor_table_->pop_front(); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status DataBuffer::SliceOff(int64_t number_of_rows) { | |||||
| while (number_of_rows > 0) { | |||||
| tensor_table_->pop_back(); | |||||
| number_of_rows--; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -1,114 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_BUFFER_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_BUFFER_H_ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/util/allocator.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/include/constants.h" | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/core/tensor_row.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between | |||||
| /// connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format | |||||
| /// where n TensorRows may consist of m tensors (columns). | |||||
| class DataBuffer { | |||||
| public: | |||||
| // Buffer flags | |||||
| enum BufferFlags : uint32_t { | |||||
| kDeBFlagNone = 0, | |||||
| kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg | |||||
| kDeBFlagEOE = 1u << 1, // The buffer is an eoe end-of-epoch msg | |||||
| kDeBFlagWait = 1u << 2, // The buffer is an control signal for workers to suspend operations | |||||
| kDeBFlagQuit = 1u << 3 // The buffer is a control signal for workers to quit | |||||
| }; | |||||
| // Name: Constructor #1 | |||||
| // Description: This is the main constructor that is used for making a buffer | |||||
| DataBuffer(int32_t id, BufferFlags flags); | |||||
| /// \brief default destructor | |||||
| ~DataBuffer() = default; | |||||
| /// \brief A method for debug printing of the buffer | |||||
| /// \param[in/out] out The stream to write to | |||||
| /// \param[in] show_all A boolean to toggle between details and summary printing | |||||
| void Print(std::ostream &out, bool show_all) const; | |||||
| // Provide stream operator for displaying it | |||||
| friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) { | |||||
| cb.Print(out, false); | |||||
| return out; | |||||
| } | |||||
| // Convenience getter functions for flag checking | |||||
| bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); } | |||||
| bool eoe() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOE)); } | |||||
| bool wait() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagWait)); } | |||||
| bool quit() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagQuit)); } | |||||
| // Simple getter funcs | |||||
| int32_t id() const { return buffer_id_; } | |||||
| void set_id(int32_t id) { buffer_id_ = id; } | |||||
| int32_t NumRows() const { return ((tensor_table_) ? tensor_table_->size() : 0); } | |||||
| int32_t NumCols() const { | |||||
| return (tensor_table_ == nullptr || tensor_table_->empty()) ? 0 : tensor_table_->at(0).size(); | |||||
| } | |||||
| BufferFlags buffer_flags() const { return buffer_flags_; } | |||||
| // Remove me!! Callers should fetch rows via pop | |||||
| Status GetTensor(std::shared_ptr<Tensor> *, int32_t row_id, int32_t col_id) const; | |||||
| // Remove me!! Callers should drain rows via pop. | |||||
| Status GetRow(int32_t row_id, TensorRow *) const; | |||||
| // Get a row from the TensorTable | |||||
| Status PopRow(TensorRow *); | |||||
| Status SliceOff(int64_t number_of_rows); | |||||
| // Replacing mTensorTable, the unique_ptr assignment will release the old TensorTable. | |||||
| void set_tensor_table(std::unique_ptr<TensorQTable> new_table) { tensor_table_ = std::move(new_table); } | |||||
| void set_flag(BufferFlags in_flag) { | |||||
| buffer_flags_ = static_cast<BufferFlags>(static_cast<uint32_t>(buffer_flags_) | static_cast<uint32_t>(in_flag)); | |||||
| } | |||||
| void Shuffle() {} // does nothing right now. possibly remove later | |||||
| protected: | |||||
| int32_t buffer_id_; // An id for the buffer. | |||||
| std::unique_ptr<TensorQTable> tensor_table_; // A table (row major) of Tensors | |||||
| BufferFlags buffer_flags_; // bit mask for various buffer properties | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATA_BUFFER_H_ | |||||
| @@ -19,7 +19,7 @@ | |||||
| #include "minddata/dataset/core/data_type.h" | #include "minddata/dataset/core/data_type.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/core/tensor_shape.h" | #include "minddata/dataset/core/tensor_shape.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | #include "minddata/dataset/engine/datasetops/dataset_op.h" | ||||
| @@ -17,7 +17,7 @@ | |||||
| #include <iomanip> | #include <iomanip> | ||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| @@ -28,7 +28,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Forward declare | // Forward declare | ||||
| class DataBuffer; | |||||
| class ExecutionTree; | class ExecutionTree; | ||||
| // BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has | // BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has | ||||
| @@ -21,7 +21,7 @@ | |||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| #include "minddata/dataset/core/pybind_support.h" | #include "minddata/dataset/core/pybind_support.h" | ||||
| #endif | #endif | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| #include "minddata/dataset/kernels/data/data_utils.h" | #include "minddata/dataset/kernels/data/data_utils.h" | ||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| @@ -34,7 +34,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class DataBuffer; | |||||
| using PadInfo = std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>>; | using PadInfo = std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>>; | ||||
| @@ -32,7 +32,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class DataBuffer; | |||||
| class BucketBatchByLengthOp : public PipelineOp { | class BucketBatchByLengthOp : public PipelineOp { | ||||
| public: | public: | ||||
| @@ -94,11 +94,9 @@ Status CacheBase::FetchSamplesToWorkers() { | |||||
| keys.reserve(1); | keys.reserve(1); | ||||
| std::vector<row_id_type> prefetch_keys; | std::vector<row_id_type> prefetch_keys; | ||||
| prefetch_keys.reserve(prefetch_size_); | prefetch_keys.reserve(prefetch_size_); | ||||
| std::unique_ptr<DataBuffer> sampler_buffer; | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| while (!sampler_buffer->eoe()) { | |||||
| TensorRow sample_row; | |||||
| RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); | |||||
| TensorRow sample_row; | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row)); | |||||
| while (!sample_row.eoe()) { | |||||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | std::shared_ptr<Tensor> sample_ids = sample_row[0]; | ||||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | ||||
| ++row_cnt_; | ++row_cnt_; | ||||
| @@ -115,7 +113,7 @@ Status CacheBase::FetchSamplesToWorkers() { | |||||
| prefetch_keys.clear(); | prefetch_keys.clear(); | ||||
| } | } | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row)); | |||||
| } | } | ||||
| // Deal with any partial keys left. | // Deal with any partial keys left. | ||||
| if (!prefetch_keys.empty()) { | if (!prefetch_keys.empty()) { | ||||
| @@ -95,7 +95,7 @@ void CacheLookupOp::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status CacheLookupOp::GetNextSample(TensorRow *out) { | |||||
| std::vector<row_id_type> cache_miss; | std::vector<row_id_type> cache_miss; | ||||
| RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss)); | RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss)); | ||||
| // Ignore the case we have no cache miss, we can't return empty samples. | // Ignore the case we have no cache miss, we can't return empty samples. | ||||
| @@ -104,19 +104,16 @@ Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| } | } | ||||
| // Special code for eoe | // Special code for eoe | ||||
| if (cache_miss.at(0) == eoe_row_id) { | if (cache_miss.at(0) == eoe_row_id) { | ||||
| *out_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| *out = std::move(TensorRow(TensorRow::kFlagEOE)); | |||||
| } else { | } else { | ||||
| std::shared_ptr<Tensor> sample_ts; | std::shared_ptr<Tensor> sample_ts; | ||||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size())); | RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size())); | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone); | |||||
| auto idPtr = sample_ts->begin<int64_t>(); | auto idPtr = sample_ts->begin<int64_t>(); | ||||
| for (auto i = 0; i < cache_miss.size(); ++i) { | for (auto i = 0; i < cache_miss.size(); ++i) { | ||||
| *idPtr = cache_miss.at(i); | *idPtr = cache_miss.at(i); | ||||
| ++idPtr; | ++idPtr; | ||||
| } | } | ||||
| TensorRow row; | |||||
| row.push_back(sample_ts); | |||||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||||
| *out = {sample_ts}; | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -96,7 +96,7 @@ class CacheLookupOp : public CacheBase, public SamplerRT { | |||||
| Status ResetSampler() override; | Status ResetSampler() override; | ||||
| Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; | Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; | ||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(TensorRow *out) override; | |||||
| void Print(std::ostream &out, bool show_all) const override; | void Print(std::ostream &out, bool show_all) const override; | ||||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | void SamplerPrint(std::ostream &out, bool show_all) const override; | ||||
| bool AllowCacheMiss() override { return true; } | bool AllowCacheMiss() override { return true; } | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | #include "minddata/dataset/engine/datasetops/repeat_op.h" | ||||
| #include "minddata/dataset/engine/dataset_iterator.h" | #include "minddata/dataset/engine/dataset_iterator.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/util/log_adapter.h" | #include "minddata/dataset/util/log_adapter.h" | ||||
| #include "minddata/dataset/util/task_manager.h" | #include "minddata/dataset/util/task_manager.h" | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| @@ -26,7 +26,7 @@ | |||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | #include "minddata/dataset/engine/datasetops/device_queue_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| #include "utils/system/crc32c.h" | #include "utils/system/crc32c.h" | ||||
| @@ -59,8 +59,6 @@ constexpr char kZipOp[] = "ZipOp"; | |||||
| // Forward declare | // Forward declare | ||||
| class ExecutionTree; | class ExecutionTree; | ||||
| class DataBuffer; | |||||
| class NodePass; | class NodePass; | ||||
| class SamplerRT; | class SamplerRT; | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <memory> | #include <memory> | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/dataset_iterator.h" | #include "minddata/dataset/engine/dataset_iterator.h" | ||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| #include "minddata/dataset/util/task_manager.h" | #include "minddata/dataset/util/task_manager.h" | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/util/log_adapter.h" | #include "minddata/dataset/util/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/kernels/tensor_op.h" | #include "minddata/dataset/kernels/tensor_op.h" | ||||
| #include "minddata/dataset/util/log_adapter.h" | #include "minddata/dataset/util/log_adapter.h" | ||||
| #include "minddata/dataset/util/task_manager.h" | #include "minddata/dataset/util/task_manager.h" | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h" | #include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h" | ||||
| #include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h" | #include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h" | ||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| @@ -34,7 +34,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Forward declare | // Forward declare | ||||
| class DataBuffer; | |||||
| class ExecutionTree; | class ExecutionTree; | ||||
| // MapOp class implements the Map operator. It will apply a list of operations to each record specified by column names. | // MapOp class implements the Map operator. It will apply a list of operations to each record specified by column names. | ||||
| @@ -30,8 +30,6 @@ namespace dataset { | |||||
| constexpr int32_t kEndOfActions = -1; | constexpr int32_t kEndOfActions = -1; | ||||
| // Forward declares | // Forward declares | ||||
| class DataBuffer; | |||||
| class DbConnector; | class DbConnector; | ||||
| // A ParallelOp provides a multi-threaded DatasetOp | // A ParallelOp provides a multi-threaded DatasetOp | ||||
| @@ -26,8 +26,6 @@ namespace dataset { | |||||
| // forward declare | // forward declare | ||||
| class ExecutionTree; | class ExecutionTree; | ||||
| class DataBuffer; | |||||
| class PipelineOp : public DatasetOp { | class PipelineOp : public DatasetOp { | ||||
| public: | public: | ||||
| // Constructor | // Constructor | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/util/log_adapter.h" | #include "minddata/dataset/util/log_adapter.h" | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| #include "minddata/dataset/util/log_adapter.h" | #include "minddata/dataset/util/log_adapter.h" | ||||
| @@ -27,9 +27,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // forward declare | |||||
| class DataBuffer; | |||||
| class RenameOp : public PipelineOp { | class RenameOp : public PipelineOp { | ||||
| public: | public: | ||||
| // The nested builder class inside of the RenameOp is used to help manage all of | // The nested builder class inside of the RenameOp is used to help manage all of | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | #include "minddata/dataset/engine/datasetops/repeat_op.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/util/log_adapter.h" | #include "minddata/dataset/util/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/engine/datasetops/shuffle_op.h" | #include "minddata/dataset/engine/datasetops/shuffle_op.h" | ||||
| #include "minddata/dataset/engine/dataset_iterator.h" | #include "minddata/dataset/engine/dataset_iterator.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| #include "minddata/dataset/util/log_adapter.h" | #include "minddata/dataset/util/log_adapter.h" | ||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| @@ -37,8 +37,6 @@ class ExecutionTree; | |||||
| class DbConnector; | class DbConnector; | ||||
| class DataBuffer; | |||||
| class ShuffleOp : public PipelineOp { | class ShuffleOp : public PipelineOp { | ||||
| // Shuffle buffer state flags | // Shuffle buffer state flags | ||||
| // | // | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/dataset_iterator.h" | #include "minddata/dataset/engine/dataset_iterator.h" | ||||
| #include "minddata/dataset/engine/datasetops/skip_op.h" | #include "minddata/dataset/engine/datasetops/skip_op.h" | ||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| @@ -487,10 +487,8 @@ Status AlbumOp::GetNextRowPullMode(TensorRow *row) { | |||||
| if (image_rows_.empty()) PrescanEntry(); | if (image_rows_.empty()) PrescanEntry(); | ||||
| if (sample_ids_ == nullptr) { | if (sample_ids_ == nullptr) { | ||||
| RETURN_IF_NOT_OK(this->InitSampler()); | RETURN_IF_NOT_OK(this->InitSampler()); | ||||
| std::unique_ptr<DataBuffer> sample_buffer; | |||||
| TensorRow sample_row; | TensorRow sample_row; | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_buffer)); | |||||
| RETURN_IF_NOT_OK(sample_buffer->PopRow(&sample_row)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row)); | |||||
| sample_ids_ = sample_row[0]; | sample_ids_ = sample_row[0]; | ||||
| } | } | ||||
| if (curr_row_ + 1 > sample_ids_->Size()) { | if (curr_row_ + 1 > sample_ids_->Size()) { | ||||
| @@ -27,7 +27,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | ||||
| @@ -16,7 +16,7 @@ | |||||
| #include "minddata/dataset/engine/datasetops/source/generator_op.h" | #include "minddata/dataset/engine/datasetops/source/generator_op.h" | ||||
| #include <iomanip> | #include <iomanip> | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/util/task_manager.h" | #include "minddata/dataset/util/task_manager.h" | ||||
| @@ -219,12 +219,10 @@ Status GeneratorOp::operator()() { | |||||
| if (eoe) { | if (eoe) { | ||||
| // Push out EOE upon StopIteration exception from generator | // Push out EOE upon StopIteration exception from generator | ||||
| MS_LOG(DEBUG) << "Generator operator sends out EOE."; | MS_LOG(DEBUG) << "Generator operator sends out EOE."; | ||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(out_connector_->SendEOE()); | RETURN_IF_NOT_OK(out_connector_->SendEOE()); | ||||
| if (IsLastIteration()) { | if (IsLastIteration()) { | ||||
| // If last repeat or not repeated, push out EOF and exit master loop | // If last repeat or not repeated, push out EOF and exit master loop | ||||
| MS_LOG(DEBUG) << "Generator operator sends out EOF."; | MS_LOG(DEBUG) << "Generator operator sends out EOF."; | ||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||||
| RETURN_IF_NOT_OK(out_connector_->SendEOF()); | RETURN_IF_NOT_OK(out_connector_->SendEOF()); | ||||
| MS_LOG(DEBUG) << "Generator operator main execution loop complete."; | MS_LOG(DEBUG) << "Generator operator main execution loop complete."; | ||||
| eof = true; | eof = true; | ||||
| @@ -26,7 +26,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | ||||
| @@ -30,13 +30,11 @@ MappableLeafOp::MappableLeafOp(int32_t num_wkrs, int32_t queue_size, std::shared | |||||
| // Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work | // Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work | ||||
| Status MappableLeafOp::operator()() { | Status MappableLeafOp::operator()() { | ||||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | ||||
| std::unique_ptr<DataBuffer> sampler_buffer; | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| TensorRow sample_row; | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row)); | |||||
| int64_t row_cnt = 0; | int64_t row_cnt = 0; | ||||
| while (true) { // each iteration is 1 epoch, breaks when IsLastIteration() is true | while (true) { // each iteration is 1 epoch, breaks when IsLastIteration() is true | ||||
| while (sampler_buffer->eoe() == false) { | |||||
| TensorRow sample_row; | |||||
| RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); | |||||
| while (sample_row.eoe() == false) { | |||||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | std::shared_ptr<Tensor> sample_ids = sample_row[0]; | ||||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) { | for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) { | ||||
| if ((*itr) >= num_rows_) { | if ((*itr) >= num_rows_) { | ||||
| @@ -46,7 +44,7 @@ Status MappableLeafOp::operator()() { | |||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[row_cnt++ % num_workers_]->Add(std::make_unique<IOBlock>(*itr, IOBlock::kDeIoBlockNone))); | io_block_queues_[row_cnt++ % num_workers_]->Add(std::make_unique<IOBlock>(*itr, IOBlock::kDeIoBlockNone))); | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row)); | |||||
| } | } | ||||
| if (IsLastIteration()) { | if (IsLastIteration()) { | ||||
| std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe); | std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe); | ||||
| @@ -71,7 +69,7 @@ Status MappableLeafOp::operator()() { | |||||
| // If not the last repeat, self-reset and go to loop again. | // If not the last repeat, self-reset and go to loop again. | ||||
| if (!IsLastIteration()) { | if (!IsLastIteration()) { | ||||
| RETURN_IF_NOT_OK(Reset()); | RETURN_IF_NOT_OK(Reset()); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row)); | |||||
| } | } | ||||
| UpdateRepeatAndEpochCounter(); | UpdateRepeatAndEpochCounter(); | ||||
| } | } | ||||
| @@ -90,7 +88,7 @@ Status MappableLeafOp::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ | |||||
| // contains the main logic of pulling a IOBlock from IOBlockQueue, load a row and push the row to out_connector_ | |||||
| // IMPORTANT: 1 IOBlock produces 1 row | // IMPORTANT: 1 IOBlock produces 1 row | ||||
| Status MappableLeafOp::WorkerEntry(int32_t worker_id) { | Status MappableLeafOp::WorkerEntry(int32_t worker_id) { | ||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| @@ -26,7 +26,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | #include "minddata/dataset/engine/datasetops/dataset_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | ||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| @@ -42,7 +42,6 @@ namespace dataset { | |||||
| // Forward declares | // Forward declares | ||||
| template <typename T> | template <typename T> | ||||
| class Queue; | class Queue; | ||||
| class DataBuffer; | |||||
| using mindrecord::ShardOperator; | using mindrecord::ShardOperator; | ||||
| using mindrecord::ShardReader; | using mindrecord::ShardReader; | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | ||||
| @@ -19,7 +19,6 @@ | |||||
| #include <limits> | #include <limits> | ||||
| #include <memory> | #include <memory> | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -63,15 +62,15 @@ Status DistributedSamplerRT::InitSampler() { | |||||
| if (offset_ != -1 || !even_dist_) { | if (offset_ != -1 || !even_dist_) { | ||||
| if (offset_ == -1) offset_ = 0; | if (offset_ == -1) offset_ = 0; | ||||
| samples_per_buffer_ = (num_rows_ + offset_) / num_devices_; | |||||
| samples_per_tensor_ = (num_rows_ + offset_) / num_devices_; | |||||
| int64_t remainder = (num_rows_ + offset_) % num_devices_; | int64_t remainder = (num_rows_ + offset_) % num_devices_; | ||||
| if (device_id_ < remainder) samples_per_buffer_++; | |||||
| if (device_id_ < offset_) samples_per_buffer_--; | |||||
| if (device_id_ < remainder) samples_per_tensor_++; | |||||
| if (device_id_ < offset_) samples_per_tensor_--; | |||||
| } else { | } else { | ||||
| offset_ = 0; | offset_ = 0; | ||||
| samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices) | |||||
| samples_per_tensor_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices) | |||||
| } | } | ||||
| samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_; | |||||
| samples_per_tensor_ = num_samples_ < samples_per_tensor_ ? num_samples_ : samples_per_tensor_; | |||||
| if (shuffle_) { | if (shuffle_) { | ||||
| shuffle_vec_.reserve(num_rows_); | shuffle_vec_.reserve(num_rows_); | ||||
| for (int64_t i = 0; i < num_rows_; i++) { | for (int64_t i = 0; i < num_rows_; i++) { | ||||
| @@ -79,51 +78,48 @@ Status DistributedSamplerRT::InitSampler() { | |||||
| } | } | ||||
| std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_); | std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_); | ||||
| } | } | ||||
| if (!samples_per_buffer_) non_empty_ = false; | |||||
| if (!samples_per_tensor_) non_empty_ = false; | |||||
| is_initialized = true; | is_initialized = true; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DistributedSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (cnt_ > samples_per_buffer_) { | |||||
| Status DistributedSamplerRT::GetNextSample(TensorRow *out) { | |||||
| if (cnt_ > samples_per_tensor_) { | |||||
| RETURN_STATUS_UNEXPECTED( | RETURN_STATUS_UNEXPECTED( | ||||
| "Number of samples(cnt) that have already been filled in to buffer should be less than or " | "Number of samples(cnt) that have already been filled in to buffer should be less than or " | ||||
| "equal to samples_per_buffer, but got cnt: " + | "equal to samples_per_buffer, but got cnt: " + | ||||
| std::to_string(cnt_) + ", samples_per_buffer: " + std::to_string(samples_per_buffer_)); | |||||
| } else if (cnt_ == samples_per_buffer_ && (non_empty_ || !even_dist_)) { | |||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| if (!samples_per_buffer_) { | |||||
| std::to_string(cnt_) + ", samples_per_buffer: " + std::to_string(samples_per_tensor_)); | |||||
| } else if (cnt_ == samples_per_tensor_ && (non_empty_ || !even_dist_)) { | |||||
| (*out) = TensorRow(TensorRow::kFlagEOE); | |||||
| if (!samples_per_tensor_) { | |||||
| non_empty_ = false; | non_empty_ = false; | ||||
| } | } | ||||
| } else if (!samples_per_buffer_ && !non_empty_) { | |||||
| } else if (!samples_per_tensor_ && !non_empty_) { | |||||
| // If the buffer is empty, we add samples with subscript 0 in the current dataset. | // If the buffer is empty, we add samples with subscript 0 in the current dataset. | ||||
| // This step is to make up for the solution that the code default buffer is not empty before. | // This step is to make up for the solution that the code default buffer is not empty before. | ||||
| // We will remove this value in the concat phase | // We will remove this value in the concat phase | ||||
| non_empty_ = true; | non_empty_ = true; | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone); | |||||
| std::shared_ptr<Tensor> sample_ids; | std::shared_ptr<Tensor> sample_ids; | ||||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, 1)); | RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, 1)); | ||||
| auto id_ptr = sample_ids->begin<int64_t>(); | auto id_ptr = sample_ids->begin<int64_t>(); | ||||
| // add index 0 | // add index 0 | ||||
| *id_ptr = 0; | *id_ptr = 0; | ||||
| TensorRow row(1, sample_ids); | |||||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||||
| (*out) = {sample_ids}; | |||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | ||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone); | |||||
| std::shared_ptr<Tensor> sample_ids; | std::shared_ptr<Tensor> sample_ids; | ||||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_)); | |||||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_tensor_)); | |||||
| auto id_ptr = sample_ids->begin<int64_t>(); | auto id_ptr = sample_ids->begin<int64_t>(); | ||||
| bool flag_add_1 = false; | bool flag_add_1 = false; | ||||
| while (cnt_ < samples_per_buffer_ && id_ptr != sample_ids->end<int64_t>()) { | |||||
| while (cnt_ < samples_per_tensor_ && id_ptr != sample_ids->end<int64_t>()) { | |||||
| int64_t middle_value = num_devices_ * cnt_ + device_id_ - offset_; | int64_t middle_value = num_devices_ * cnt_ + device_id_ - offset_; | ||||
| // if index < 0, we move back one place | // if index < 0, we move back one place | ||||
| if (middle_value < 0) { | if (middle_value < 0) { | ||||
| samples_per_buffer_++; | |||||
| samples_per_tensor_++; | |||||
| cnt_++; | cnt_++; | ||||
| flag_add_1 = true; | flag_add_1 = true; | ||||
| middle_value = num_devices_ * cnt_ + device_id_ - offset_; | middle_value = num_devices_ * cnt_ + device_id_ - offset_; | ||||
| @@ -145,17 +141,16 @@ Status DistributedSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buff | |||||
| // If 1 was added before, we will cut off 1 here | // If 1 was added before, we will cut off 1 here | ||||
| if (flag_add_1) { | if (flag_add_1) { | ||||
| samples_per_buffer_--; | |||||
| samples_per_tensor_--; | |||||
| cnt_--; | cnt_--; | ||||
| } | } | ||||
| TensorRow row(1, sample_ids); | |||||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||||
| (*out) = {sample_ids}; | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DistributedSamplerRT::ResetSampler() { | Status DistributedSamplerRT::ResetSampler() { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_tensor_, "ERROR Reset() called early/late"); | |||||
| cnt_ = 0; | cnt_ = 0; | ||||
| if (shuffle_ == true) { | if (shuffle_ == true) { | ||||
| @@ -50,7 +50,7 @@ class DistributedSamplerRT : public SamplerRT { | |||||
| /// \param std::unique_ptr<DataBuffer> * pBuffer | /// \param std::unique_ptr<DataBuffer> * pBuffer | ||||
| /// \param int32_t workerId | /// \param int32_t workerId | ||||
| /// \return Status code | /// \return Status code | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(TensorRow *out) override; | |||||
| /// Init sampler, called by base class or python | /// Init sampler, called by base class or python | ||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| @@ -52,7 +52,7 @@ Status PKSamplerRT::InitSampler() { | |||||
| num_samples_ = num_rows_; | num_samples_ = num_rows_; | ||||
| } | } | ||||
| samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; | |||||
| samples_per_tensor_ = (samples_per_tensor_ > num_samples_) ? num_samples_ : samples_per_tensor_; | |||||
| if (shuffle_ == true) { | if (shuffle_ == true) { | ||||
| std::shuffle(labels_.begin(), labels_.end(), rnd_); | std::shuffle(labels_.begin(), labels_.end(), rnd_); | ||||
| } else { | } else { | ||||
| @@ -65,19 +65,18 @@ Status PKSamplerRT::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PKSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status PKSamplerRT::GetNextSample(TensorRow *out) { | |||||
| if (next_id_ > num_samples_ || num_samples_ == 0) { | if (next_id_ > num_samples_ || num_samples_ == 0) { | ||||
| RETURN_STATUS_UNEXPECTED("Index must be less than or equal to num_samples, but got: " + std::to_string(next_id_)); | RETURN_STATUS_UNEXPECTED("Index must be less than or equal to num_samples, but got: " + std::to_string(next_id_)); | ||||
| } else if (next_id_ == num_samples_) { | } else if (next_id_ == num_samples_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| (*out) = TensorRow(TensorRow::kFlagEOE); | |||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | ||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); | |||||
| std::shared_ptr<Tensor> sample_ids; | std::shared_ptr<Tensor> sample_ids; | ||||
| int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; | |||||
| int64_t last_id = (samples_per_tensor_ + next_id_ > num_samples_) ? num_samples_ : samples_per_tensor_ + next_id_; | |||||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_)); | RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_)); | ||||
| auto id_ptr = sample_ids->begin<int64_t>(); | auto id_ptr = sample_ids->begin<int64_t>(); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(samples_per_class_ != 0, "samples cannot be zero."); | CHECK_FAIL_RETURN_UNEXPECTED(samples_per_class_ != 0, "samples cannot be zero."); | ||||
| @@ -95,8 +94,7 @@ Status PKSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| id_ptr++; | id_ptr++; | ||||
| } | } | ||||
| TensorRow row(1, sample_ids); | |||||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||||
| (*out) = {sample_ids}; | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -41,7 +41,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED | |||||
| // @param std::unique_ptr<DataBuffer pBuffer | // @param std::unique_ptr<DataBuffer pBuffer | ||||
| // @param int32_t workerId | // @param int32_t workerId | ||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(TensorRow *out) override; | |||||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | // first handshake between leaf source op and Sampler. This func will determine the amount of data | ||||
| // in the dataset that we can sample from. | // in the dataset that we can sample from. | ||||
| @@ -23,9 +23,9 @@ namespace dataset { | |||||
| PythonSamplerRT::PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) | PythonSamplerRT::PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) | ||||
| : SamplerRT(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | : SamplerRT(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | ||||
| Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status PythonSamplerRT::GetNextSample(TensorRow *out) { | |||||
| if (need_to_reset_) { | if (need_to_reset_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| (*out) = TensorRow(TensorRow::kFlagEOE); | |||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | ||||
| @@ -34,7 +34,6 @@ Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| std::shared_ptr<Tensor> sample_ids; | std::shared_ptr<Tensor> sample_ids; | ||||
| { | { | ||||
| py::gil_scoped_acquire gil_acquire; | py::gil_scoped_acquire gil_acquire; | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone); | |||||
| if (Py_IsInitialized() == 0) { | if (Py_IsInitialized() == 0) { | ||||
| return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized"); | return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized"); | ||||
| } | } | ||||
| @@ -57,8 +56,7 @@ Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| "Invalid data, python sampler iterator should return an integer index."); | "Invalid data, python sampler iterator should return an integer index."); | ||||
| } | } | ||||
| } | } | ||||
| TensorRow row(1, sample_ids); | |||||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||||
| (*out) = {sample_ids}; | |||||
| need_to_reset_ = true; | need_to_reset_ = true; | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -48,7 +48,7 @@ class PythonSamplerRT : public SamplerRT { | |||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(TensorRow *out) override; | |||||
| // Printer for debugging purposes. | // Printer for debugging purposes. | ||||
| // @param out - output stream to write to | // @param out - output stream to write to | ||||
| @@ -31,19 +31,18 @@ RandomSamplerRT::RandomSamplerRT(int64_t num_samples, bool replacement, bool res | |||||
| dist(nullptr), | dist(nullptr), | ||||
| reshuffle_each_epoch_(reshuffle_each_epoch) {} | reshuffle_each_epoch_(reshuffle_each_epoch) {} | ||||
| Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status RandomSamplerRT::GetNextSample(TensorRow *out) { | |||||
| if (next_id_ > num_samples_) { | if (next_id_ > num_samples_) { | ||||
| RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); | RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); | ||||
| } else if (next_id_ == num_samples_) { | } else if (next_id_ == num_samples_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| (*out) = TensorRow(TensorRow::kFlagEOE); | |||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | ||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); | |||||
| std::shared_ptr<Tensor> sampleIds; | std::shared_ptr<Tensor> sampleIds; | ||||
| int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_); | |||||
| int64_t last_id = std::min(samples_per_tensor_ + next_id_, num_samples_); | |||||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_)); | RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_)); | ||||
| auto id_ptr = sampleIds->begin<int64_t>(); | auto id_ptr = sampleIds->begin<int64_t>(); | ||||
| @@ -62,8 +61,7 @@ Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| *(id_ptr + static_cast<ptrdiff_t>(i)) = sampled_id; | *(id_ptr + static_cast<ptrdiff_t>(i)) = sampled_id; | ||||
| } | } | ||||
| next_id_ = last_id; | next_id_ = last_id; | ||||
| TensorRow row(1, sampleIds); | |||||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||||
| (*out) = {sampleIds}; | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -81,7 +79,7 @@ Status RandomSamplerRT::InitSampler() { | |||||
| num_samples_ > 0 && num_rows_ > 0, | num_samples_ > 0 && num_rows_ > 0, | ||||
| "Invalid parameter, num_samples & num_rows must be greater than 0, but got num_samples: " + | "Invalid parameter, num_samples & num_rows must be greater than 0, but got num_samples: " + | ||||
| std::to_string(num_samples_) + ", num_rows: " + std::to_string(num_rows_)); | std::to_string(num_samples_) + ", num_rows: " + std::to_string(num_rows_)); | ||||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | |||||
| samples_per_tensor_ = samples_per_tensor_ > num_samples_ ? num_samples_ : samples_per_tensor_; | |||||
| rnd_.seed(seed_); | rnd_.seed(seed_); | ||||
| if (!replacement_) { | if (!replacement_) { | ||||
| @@ -41,7 +41,7 @@ class RandomSamplerRT : public SamplerRT { | |||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(TensorRow *out) override; | |||||
| // meant to be called by base class or python | // meant to be called by base class or python | ||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| @@ -36,7 +36,7 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { | |||||
| SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer) | SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer) | ||||
| : num_rows_(0), | : num_rows_(0), | ||||
| num_samples_(num_samples), | num_samples_(num_samples), | ||||
| samples_per_buffer_(samples_per_buffer), | |||||
| samples_per_tensor_(samples_per_buffer), | |||||
| col_desc_(nullptr), | col_desc_(nullptr), | ||||
| is_initialized(false) {} | is_initialized(false) {} | ||||
| @@ -91,22 +91,19 @@ void SamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| Status SamplerRT::GetAllIdsThenReset(py::array *data) { | Status SamplerRT::GetAllIdsThenReset(py::array *data) { | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| std::shared_ptr<Tensor> sample_ids; | std::shared_ptr<Tensor> sample_ids; | ||||
| TensorRow sample_row; | TensorRow sample_row; | ||||
| // A call to derived class to get sample ids wrapped inside a buffer | |||||
| RETURN_IF_NOT_OK(GetNextSample(&db)); | |||||
| // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch | |||||
| RETURN_IF_NOT_OK(db->GetRow(0, &sample_row)); | |||||
| // Get the only tensor inside the row that contains the actual SampleIds for the entire epoch | |||||
| RETURN_IF_NOT_OK(GetNextSample(&sample_row)); | |||||
| sample_ids = sample_row[0]; | sample_ids = sample_row[0]; | ||||
| // check this buffer is not a ctrl buffer | // check this buffer is not a ctrl buffer | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(sample_row.Flags() == TensorRow::kFlagNone, "ERROR ctrl row received"); | |||||
| // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch | // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch | ||||
| RETURN_IF_NOT_OK(GetNextSample(&db)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); | |||||
| RETURN_IF_NOT_OK(GetNextSample(&sample_row)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(sample_row.eoe(), "ERROR Non EOE received"); | |||||
| // Reset Sampler since this is the end of the epoch | // Reset Sampler since this is the end of the epoch | ||||
| RETURN_IF_NOT_OK(ResetSampler()); | RETURN_IF_NOT_OK(ResetSampler()); | ||||
| @@ -178,13 +175,11 @@ Status SamplerRT::AddChild(std::shared_ptr<SamplerRT> child) { | |||||
| bool SamplerRT::HasChildSampler() { return !child_.empty(); } | bool SamplerRT::HasChildSampler() { return !child_.empty(); } | ||||
| Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { | Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { | ||||
| if (child_ids_ == nullptr) { | |||||
| if (child_ids_.empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!"); | RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!"); | ||||
| } | } | ||||
| TensorRow sample_row; | |||||
| RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row)); | |||||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | |||||
| std::shared_ptr<Tensor> sample_ids = child_ids_[0]; | |||||
| RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id})); | RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id})); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | #include "minddata/dataset/engine/datasetops/dataset_op.h" | ||||
| @@ -66,7 +66,7 @@ class SamplerRT { | |||||
| // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| SamplerRT(int64_t num_samples, int64_t samples_per_buffer); | SamplerRT(int64_t num_samples, int64_t samples_per_buffer); | ||||
| SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_buffer_) {} | |||||
| SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_tensor_) {} | |||||
| // default destructor | // default destructor | ||||
| ~SamplerRT() = default; | ~SamplerRT() = default; | ||||
| @@ -76,7 +76,7 @@ class SamplerRT { | |||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0; | |||||
| virtual Status GetNextSample(TensorRow *out) = 0; | |||||
| // This function only called by python layer. Not needed by Android. | // This function only called by python layer. Not needed by Android. | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| @@ -170,10 +170,10 @@ class SamplerRT { | |||||
| int64_t num_samples_; | int64_t num_samples_; | ||||
| bool is_initialized; | bool is_initialized; | ||||
| int64_t samples_per_buffer_; | |||||
| int64_t samples_per_tensor_; | |||||
| std::unique_ptr<ColDescriptor> col_desc_; | std::unique_ptr<ColDescriptor> col_desc_; | ||||
| std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes | std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes | ||||
| std::unique_ptr<DataBuffer> child_ids_; | |||||
| TensorRow child_ids_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,23 +24,22 @@ namespace dataset { | |||||
| SequentialSamplerRT::SequentialSamplerRT(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | SequentialSamplerRT::SequentialSamplerRT(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | ||||
| : SamplerRT(num_samples, samples_per_buffer), current_id_(start_index), start_index_(start_index), id_count_(0) {} | : SamplerRT(num_samples, samples_per_buffer), current_id_(start_index), start_index_(start_index), id_count_(0) {} | ||||
| Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status SequentialSamplerRT::GetNextSample(TensorRow *out) { | |||||
| if (id_count_ > num_samples_) { | if (id_count_ > num_samples_) { | ||||
| RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); | RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); | ||||
| } else if (id_count_ == num_samples_) { | } else if (id_count_ == num_samples_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| (*out) = TensorRow(TensorRow::kFlagEOE); | |||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | ||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone); | |||||
| std::shared_ptr<Tensor> sampleIds; | std::shared_ptr<Tensor> sampleIds; | ||||
| // Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for | // Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for | ||||
| // samples per buffer though. | // samples per buffer though. | ||||
| int64_t remaining_ids = num_samples_ - id_count_; | int64_t remaining_ids = num_samples_ - id_count_; | ||||
| int64_t num_elements = std::min(remaining_ids, samples_per_buffer_); | |||||
| int64_t num_elements = std::min(remaining_ids, samples_per_tensor_); | |||||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements)); | RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements)); | ||||
| auto idPtr = sampleIds->begin<int64_t>(); | auto idPtr = sampleIds->begin<int64_t>(); | ||||
| @@ -57,8 +56,7 @@ Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe | |||||
| id_count_ += num_elements; // Count the packed ids towards our overall sample count | id_count_ += num_elements; // Count the packed ids towards our overall sample count | ||||
| TensorRow row(1, sampleIds); | |||||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||||
| (*out) = {sampleIds}; | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -83,9 +81,9 @@ Status SequentialSamplerRT::InitSampler() { | |||||
| num_samples_ = available_row_count; | num_samples_ = available_row_count; | ||||
| } | } | ||||
| CHECK_FAIL_RETURN_UNEXPECTED( | CHECK_FAIL_RETURN_UNEXPECTED( | ||||
| (num_samples_ > 0 && samples_per_buffer_ > 0) || num_samples_ == 0, | |||||
| "Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_buffer_)); | |||||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | |||||
| (num_samples_ > 0 && samples_per_tensor_ > 0) || num_samples_ == 0, | |||||
| "Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_tensor_)); | |||||
| samples_per_tensor_ = samples_per_tensor_ > num_samples_ ? num_samples_ : samples_per_tensor_; | |||||
| is_initialized = true; | is_initialized = true; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -47,7 +47,7 @@ class SequentialSamplerRT : public SamplerRT { | |||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(TensorRow *out) override; | |||||
| /// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers | /// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers | ||||
| /// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's smaller than num_rows, | /// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's smaller than num_rows, | ||||
| @@ -39,8 +39,8 @@ Status SubsetSamplerRT::InitSampler() { | |||||
| num_samples_ = static_cast<int64_t>(indices_.size()); | num_samples_ = static_cast<int64_t>(indices_.size()); | ||||
| } | } | ||||
| if (samples_per_buffer_ > num_samples_) { | |||||
| samples_per_buffer_ = num_samples_; | |||||
| if (samples_per_tensor_ > num_samples_) { | |||||
| samples_per_tensor_ = num_samples_; | |||||
| } | } | ||||
| is_initialized = true; | is_initialized = true; | ||||
| @@ -61,19 +61,18 @@ Status SubsetSamplerRT::ResetSampler() { | |||||
| } | } | ||||
| // Get the sample ids. | // Get the sample ids. | ||||
| Status SubsetSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status SubsetSamplerRT::GetNextSample(TensorRow *out) { | |||||
| // All samples have been drawn | // All samples have been drawn | ||||
| if (sample_id_ == num_samples_) { | if (sample_id_ == num_samples_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | |||||
| (*out) = TensorRow(TensorRow::kFlagEOE); | |||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | ||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); | |||||
| std::shared_ptr<Tensor> outputIds; | std::shared_ptr<Tensor> outputIds; | ||||
| int64_t last_id = sample_id_ + samples_per_buffer_; | |||||
| int64_t last_id = sample_id_ + samples_per_tensor_; | |||||
| // Handling the return all samples at once, and when last draw is not a full batch. | // Handling the return all samples at once, and when last draw is not a full batch. | ||||
| if (last_id > num_samples_) { | if (last_id > num_samples_) { | ||||
| last_id = num_samples_; | last_id = num_samples_; | ||||
| @@ -101,8 +100,7 @@ Status SubsetSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| sample_id_++; | sample_id_++; | ||||
| } | } | ||||
| // Create a TensorTable from that single tensor and push into DataBuffer | |||||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, TensorRow(1, outputIds))); | |||||
| (*out) = {outputIds}; | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -47,9 +47,9 @@ class SubsetSamplerRT : public SamplerRT { | |||||
| Status ResetSampler() override; | Status ResetSampler() override; | ||||
| /// Get the sample ids. | /// Get the sample ids. | ||||
| /// \param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | |||||
| /// \param[out] out The address of a TensorRow where the sample ids will be placed. | |||||
| /// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | /// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(TensorRow *out) override; | |||||
| /// Printer for debugging purposes. | /// Printer for debugging purposes. | ||||
| /// \param out - output stream to write to | /// \param out - output stream to write to | ||||
| @@ -49,9 +49,9 @@ Status WeightedRandomSamplerRT::InitSampler() { | |||||
| num_rows_ > 0 && num_samples_, | num_rows_ > 0 && num_samples_, | ||||
| "Invalid parameter, num_samples and num_rows must be greater than 0, but got num_rows: " + | "Invalid parameter, num_samples and num_rows must be greater than 0, but got num_rows: " + | ||||
| std::to_string(num_rows_) + ", num_samples: " + std::to_string(num_samples_)); | std::to_string(num_rows_) + ", num_samples: " + std::to_string(num_samples_)); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(samples_per_tensor_ > 0, | |||||
| "Invalid parameter, samples_per_buffer must be greater than 0, but got " + | "Invalid parameter, samples_per_buffer must be greater than 0, but got " + | ||||
| std::to_string(samples_per_buffer_) + ".\n"); | |||||
| std::to_string(samples_per_tensor_) + ".\n"); | |||||
| if (weights_.size() > static_cast<size_t>(num_rows_)) { | if (weights_.size() > static_cast<size_t>(num_rows_)) { | ||||
| return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, | return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, | ||||
| @@ -69,7 +69,7 @@ Status WeightedRandomSamplerRT::InitSampler() { | |||||
| // Initialize random generator with seed from config manager | // Initialize random generator with seed from config manager | ||||
| rand_gen_.seed(GetSeed()); | rand_gen_.seed(GetSeed()); | ||||
| samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; | |||||
| samples_per_tensor_ = (samples_per_tensor_ > num_samples_) ? num_samples_ : samples_per_tensor_; | |||||
| if (!replacement_) { | if (!replacement_) { | ||||
| exp_dist_ = std::make_unique<std::exponential_distribution<>>(1); | exp_dist_ = std::make_unique<std::exponential_distribution<>>(1); | ||||
| @@ -117,7 +117,7 @@ Status WeightedRandomSamplerRT::ResetSampler() { | |||||
| } | } | ||||
| // Get the sample ids. | // Get the sample ids. | ||||
| Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status WeightedRandomSamplerRT::GetNextSample(TensorRow *out) { | |||||
| if (weights_.size() > static_cast<size_t>(num_rows_)) { | if (weights_.size() > static_cast<size_t>(num_rows_)) { | ||||
| return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, | return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, | ||||
| "Invalid parameter, size of sample weights must be less than or equal to num of data, " | "Invalid parameter, size of sample weights must be less than or equal to num of data, " | ||||
| @@ -133,16 +133,15 @@ Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_b | |||||
| } | } | ||||
| if (sample_id_ == num_samples_) { | if (sample_id_ == num_samples_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | |||||
| (*out) = TensorRow(TensorRow::kFlagEOE); | |||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | ||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); | |||||
| std::shared_ptr<Tensor> outputIds; | std::shared_ptr<Tensor> outputIds; | ||||
| int64_t last_id = sample_id_ + samples_per_buffer_; | |||||
| int64_t last_id = sample_id_ + samples_per_tensor_; | |||||
| // Handling the return all samples at once, and when last draw is not a full batch. | // Handling the return all samples at once, and when last draw is not a full batch. | ||||
| if (last_id > num_samples_) { | if (last_id > num_samples_) { | ||||
| last_id = num_samples_; | last_id = num_samples_; | ||||
| @@ -178,8 +177,7 @@ Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_b | |||||
| sample_id_++; | sample_id_++; | ||||
| } | } | ||||
| // Create a TensorTable from that single tensor and push into DataBuffer | |||||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, TensorRow(1, outputIds))); | |||||
| (*out) = {outputIds}; | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -51,7 +51,7 @@ class WeightedRandomSamplerRT : public SamplerRT { | |||||
| // Get the sample ids. | // Get the sample ids. | ||||
| // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | ||||
| // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(TensorRow *out) override; | |||||
| // Printer for debugging purposes. | // Printer for debugging purposes. | ||||
| // @param out - output stream to write to | // @param out - output stream to write to | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include "./tinyxml2.h" | #include "./tinyxml2.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/dataset_iterator.h" | #include "minddata/dataset/engine/dataset_iterator.h" | ||||
| #include "minddata/dataset/engine/datasetops/take_op.h" | #include "minddata/dataset/engine/datasetops/take_op.h" | ||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <iomanip> | #include <iomanip> | ||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| @@ -29,8 +29,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // forward declare | |||||
| class DataBuffer; | |||||
| class ZipOp : public PipelineOp { | class ZipOp : public PipelineOp { | ||||
| public: | public: | ||||
| @@ -18,8 +18,9 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/core/tensor_row.h" | |||||
| #include "minddata/dataset/engine/connector.h" | #include "minddata/dataset/engine/connector.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/connector.h" | #include "minddata/dataset/engine/connector.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| @@ -181,7 +181,6 @@ if(BUILD_MINDDATA STREQUAL "full") | |||||
| ${MINDDATA_DIR}/engine/datasetops/source/sampler/weighted_random_sampler.cc | ${MINDDATA_DIR}/engine/datasetops/source/sampler/weighted_random_sampler.cc | ||||
| ${MINDDATA_DIR}/engine/runtime_context.cc | ${MINDDATA_DIR}/engine/runtime_context.cc | ||||
| ${MINDDATA_DIR}/engine/tree_adapter.cc | ${MINDDATA_DIR}/engine/tree_adapter.cc | ||||
| ${MINDDATA_DIR}/engine/data_buffer.cc | |||||
| ${MINDDATA_DIR}/engine/execution_tree.cc | ${MINDDATA_DIR}/engine/execution_tree.cc | ||||
| ${MINDDATA_DIR}/engine/dataset_iterator.cc | ${MINDDATA_DIR}/engine/dataset_iterator.cc | ||||
| ${MINDDATA_DIR}/core/tensor_row.cc | ${MINDDATA_DIR}/core/tensor_row.cc | ||||
| @@ -27,7 +27,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/util/path.h" | #include "minddata/dataset/util/path.h" | ||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| @@ -91,69 +91,69 @@ class AlbumOp { | |||||
| /// \brief Load image to tensor | /// \brief Load image to tensor | ||||
| /// \param[in] image_file Image name of file | /// \param[in] image_file Image name of file | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] Tensor to push to | |||||
| /// \param[in,out] Tensor to push to | |||||
| /// \return Status The error code returned | /// \return Status The error code returned | ||||
| Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorPtr *tensor); | Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorPtr *tensor); | ||||
| /// \brief Load vector of ints to tensor, append tensor to tensor | /// \brief Load vector of ints to tensor, append tensor to tensor | ||||
| /// \param[in] json_obj Json object containing multi-dimensional label | /// \param[in] json_obj Json object containing multi-dimensional label | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] Tensor to push to | |||||
| /// \param[in,out] Tensor to push to | |||||
| /// \return Status The error code returned | /// \return Status The error code returned | ||||
| Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | ||||
| /// \brief Load vector of floats to tensor, append tensor to tensor | /// \brief Load vector of floats to tensor, append tensor to tensor | ||||
| /// \param[in] json_obj Json object containing array data | /// \param[in] json_obj Json object containing array data | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] Tensor to push to | |||||
| /// \param[in,out] Tensor to push to | |||||
| /// \return Status The error code returned | /// \return Status The error code returned | ||||
| Status LoadFloatArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | Status LoadFloatArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | ||||
| /// \brief Load string array into a tensor, append tensor to tensor | /// \brief Load string array into a tensor, append tensor to tensor | ||||
| /// \param[in] json_obj Json object containing string tensor | /// \param[in] json_obj Json object containing string tensor | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] Tensor to push to | |||||
| /// \param[in,out] Tensor to push to | |||||
| /// \return Status The error code returned | /// \return Status The error code returned | ||||
| Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | ||||
| /// \brief Load string into a tensor, append tensor to tensor | /// \brief Load string into a tensor, append tensor to tensor | ||||
| /// \param[in] json_obj Json object containing string tensor | /// \param[in] json_obj Json object containing string tensor | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] Tensor to push to | |||||
| /// \param[in,out] Tensor to push to | |||||
| /// \return Status The error code returned | /// \return Status The error code returned | ||||
| Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | ||||
| /// \brief Load float value to tensor | /// \brief Load float value to tensor | ||||
| /// \param[in] json_obj Json object containing float | /// \param[in] json_obj Json object containing float | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] Tensor to push to | |||||
| /// \param[in,out] Tensor to push to | |||||
| /// \return Status The error code returned | /// \return Status The error code returned | ||||
| Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | ||||
| /// \brief Load int value to tensor | /// \brief Load int value to tensor | ||||
| /// \param[in] json_obj Json object containing int | /// \param[in] json_obj Json object containing int | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] Tensor to push to | |||||
| /// \param[in,out] Tensor to push to | |||||
| /// \return Status The error code returned | /// \return Status The error code returned | ||||
| Status LoadIntTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | Status LoadIntTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorPtr *tensor); | ||||
| /// \brief Load emtpy tensor to tensor | |||||
| /// \brief Load empty tensor to tensor | |||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] Tensor to push to | |||||
| /// \param[in,out] Tensor to push to | |||||
| /// \return Status The error code returned | /// \return Status The error code returned | ||||
| Status LoadEmptyTensor(uint32_t col_num, TensorPtr *tensor); | Status LoadEmptyTensor(uint32_t col_num, TensorPtr *tensor); | ||||
| /// \brief Load id from file name to tensor | /// \brief Load id from file name to tensor | ||||
| /// \param[in] file The file name to get ID from | /// \param[in] file The file name to get ID from | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] Tensor to push to | |||||
| /// \param[in,out] Tensor to push to | |||||
| /// \return Status The error code returned | /// \return Status The error code returned | ||||
| Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorPtr *tensor); | Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorPtr *tensor); | ||||
| /// \brief Load a tensor according to a json file | /// \brief Load a tensor according to a json file | ||||
| /// \param[in] row_id_type row_id - id for this tensor row | /// \param[in] row_id_type row_id - id for this tensor row | ||||
| /// \param[in] ImageColumns file Json file location | /// \param[in] ImageColumns file Json file location | ||||
| /// \param[inout] TensorRow Json content stored into a tensor row | |||||
| /// \param[in,out] TensorRow Json content stored into a tensor row | |||||
| /// \return Status The error code returned | /// \return Status The error code returned | ||||
| Status LoadTensorRow(row_id_type row_id, const std::string &file, | Status LoadTensorRow(row_id_type row_id, const std::string &file, | ||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> *map_row); | std::unordered_map<std::string, std::shared_ptr<Tensor>> *map_row); | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -52,11 +52,9 @@ TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(num_samples); | DummyRandomAccessOp dummyRandomAccessOp(num_samples); | ||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -65,8 +63,8 @@ TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) { | |||||
| ASSERT_EQ(4, out.size()); | ASSERT_EQ(4, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) { | TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) { | ||||
| @@ -78,11 +76,10 @@ TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(num_samples); | DummyRandomAccessOp dummyRandomAccessOp(num_samples); | ||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -91,8 +88,8 @@ TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) { | |||||
| ASSERT_EQ(3, out.size()); | ASSERT_EQ(3, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| TEST_F(MindDataTestDistributedSampler, TestThreeShards) { | TEST_F(MindDataTestDistributedSampler, TestThreeShards) { | ||||
| @@ -104,11 +101,10 @@ TEST_F(MindDataTestDistributedSampler, TestThreeShards) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(num_samples); | DummyRandomAccessOp dummyRandomAccessOp(num_samples); | ||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -117,7 +113,6 @@ TEST_F(MindDataTestDistributedSampler, TestThreeShards) { | |||||
| ASSERT_EQ(0, out.size()); | ASSERT_EQ(0, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include "minddata/dataset/engine/datasetops/rename_op.h" | #include "minddata/dataset/engine/datasetops/rename_op.h" | ||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "gtest/gtest.h" | #include "gtest/gtest.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -39,7 +39,7 @@ class MindDataTestStandAloneSampler : public UT::DatasetOpTesting { | |||||
| protected: | protected: | ||||
| class MockStorageOp : public RandomAccessOp { | class MockStorageOp : public RandomAccessOp { | ||||
| public: | public: | ||||
| MockStorageOp(int64_t val){ | |||||
| MockStorageOp(int64_t val) { | |||||
| // row count is in base class as protected member | // row count is in base class as protected member | ||||
| // GetNumRowsInDataset does not need an override, the default from base class is fine. | // GetNumRowsInDataset does not need an override, the default from base class is fine. | ||||
| num_rows_ = val; | num_rows_ = val; | ||||
| @@ -57,17 +57,17 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { | |||||
| row.push_back(t); | row.push_back(t); | ||||
| } | } | ||||
| MockStorageOp mock(20); | MockStorageOp mock(20); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| std::shared_ptr<Tensor> tensor; | std::shared_ptr<Tensor> tensor; | ||||
| int64_t num_samples = 0; | int64_t num_samples = 0; | ||||
| TensorRow sample_row; | |||||
| for (int i = 0; i < 6; i++) { | for (int i = 0; i < 6; i++) { | ||||
| std::shared_ptr<SamplerRT> sampler = | std::shared_ptr<SamplerRT> sampler = | ||||
| std::make_shared<DistributedSamplerRT>(num_samples, 3, i % 3, (i < 3 ? false : true)); | std::make_shared<DistributedSamplerRT>(num_samples, 3, i % 3, (i < 3 ? false : true)); | ||||
| sampler->HandshakeRandomAccessOp(&mock); | sampler->HandshakeRandomAccessOp(&mock); | ||||
| sampler->GetNextSample(&db); | |||||
| db->GetTensor(&tensor, 0, 0); | |||||
| sampler->GetNextSample(&sample_row); | |||||
| tensor = sample_row[0]; | |||||
| MS_LOG(DEBUG) << (*tensor); | MS_LOG(DEBUG) << (*tensor); | ||||
| if(i < 3) { // This is added due to std::shuffle() | |||||
| if (i < 3) { // This is added due to std::shuffle() | |||||
| EXPECT_TRUE((*tensor) == (*row[i])); | EXPECT_TRUE((*tensor) == (*row[i])); | ||||
| } | } | ||||
| } | } | ||||
| @@ -83,20 +83,21 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) { | |||||
| int64_t num_samples = 0; | int64_t num_samples = 0; | ||||
| int64_t start_index = 0; | int64_t start_index = 0; | ||||
| std::shared_ptr<SamplerRT> sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index, 3); | std::shared_ptr<SamplerRT> sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index, 3); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| std::shared_ptr<Tensor> tensor; | std::shared_ptr<Tensor> tensor; | ||||
| TensorRow sample_row; | |||||
| sampler->HandshakeRandomAccessOp(&mock); | sampler->HandshakeRandomAccessOp(&mock); | ||||
| sampler->GetNextSample(&db); | |||||
| db->GetTensor(&tensor, 0, 0); | |||||
| sampler->GetNextSample(&sample_row); | |||||
| tensor = sample_row[0]; | |||||
| EXPECT_TRUE((*tensor) == (*label1)); | EXPECT_TRUE((*tensor) == (*label1)); | ||||
| sampler->GetNextSample(&db); | |||||
| db->GetTensor(&tensor, 0, 0); | |||||
| sampler->GetNextSample(&sample_row); | |||||
| tensor = sample_row[0]; | |||||
| EXPECT_TRUE((*tensor) == (*label2)); | EXPECT_TRUE((*tensor) == (*label2)); | ||||
| sampler->ResetSampler(); | sampler->ResetSampler(); | ||||
| sampler->GetNextSample(&db); | |||||
| db->GetTensor(&tensor, 0, 0); | |||||
| sampler->GetNextSample(&sample_row); | |||||
| tensor = sample_row[0]; | |||||
| EXPECT_TRUE((*tensor) == (*label1)); | EXPECT_TRUE((*tensor) == (*label1)); | ||||
| sampler->GetNextSample(&db); | |||||
| db->GetTensor(&tensor, 0, 0); | |||||
| sampler->GetNextSample(&sample_row); | |||||
| tensor = sample_row[0]; | |||||
| EXPECT_TRUE((*tensor) == (*label2)); | EXPECT_TRUE((*tensor) == (*label2)); | ||||
| } | } | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | ||||
| @@ -46,11 +46,10 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(5); | DummyRandomAccessOp dummyRandomAccessOp(5); | ||||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<int64_t> out; | std::vector<int64_t> out; | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -61,8 +60,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { | |||||
| ASSERT_NE(in_set.find(out[i]), in_set.end()); | ASSERT_NE(in_set.find(out[i]), in_set.end()); | ||||
| } | } | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { | TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { | ||||
| @@ -75,23 +74,20 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | DummyRandomAccessOp dummyRandomAccessOp(total_samples); | ||||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<int64_t> out; | std::vector<int64_t> out; | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| int epoch = 0; | int epoch = 0; | ||||
| while (!db->eoe()) { | |||||
| while (!row.eoe()) { | |||||
| epoch++; | epoch++; | ||||
| db->PopRow(&row); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| } | } | ||||
| } | } | ||||
| db.reset(); | |||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| } | } | ||||
| ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer); | ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer); | ||||
| @@ -107,12 +103,10 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(5); | DummyRandomAccessOp dummyRandomAccessOp(5); | ||||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<int64_t> out; | std::vector<int64_t> out; | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -125,9 +119,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||||
| sampler.ResetSampler(); | sampler.ResetSampler(); | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), false); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), false); | |||||
| out.clear(); | out.clear(); | ||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | ||||
| @@ -139,6 +132,6 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||||
| ASSERT_NE(in_set.find(out[i]), in_set.end()); | ASSERT_NE(in_set.find(out[i]), in_set.end()); | ||||
| } | } | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h" | ||||
| @@ -46,11 +46,10 @@ TEST_F(MindDataTestSubsetSampler, TestAllAtOnce) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(5); | DummyRandomAccessOp dummyRandomAccessOp(5); | ||||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<int64_t> out; | std::vector<int64_t> out; | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -61,8 +60,8 @@ TEST_F(MindDataTestSubsetSampler, TestAllAtOnce) { | |||||
| ASSERT_EQ(in[i], out[i]); | ASSERT_EQ(in[i], out[i]); | ||||
| } | } | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| TEST_F(MindDataTestSubsetSampler, TestGetNextBuffer) { | TEST_F(MindDataTestSubsetSampler, TestGetNextBuffer) { | ||||
| @@ -75,23 +74,21 @@ TEST_F(MindDataTestSubsetSampler, TestGetNextBuffer) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | DummyRandomAccessOp dummyRandomAccessOp(total_samples); | ||||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<int64_t> out; | std::vector<int64_t> out; | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| int epoch = 0; | int epoch = 0; | ||||
| while (!db->eoe()) { | |||||
| while (!row.eoe()) { | |||||
| epoch++; | epoch++; | ||||
| db->PopRow(&row); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| } | } | ||||
| } | } | ||||
| db.reset(); | |||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| } | } | ||||
| ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer); | ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer); | ||||
| @@ -107,12 +104,11 @@ TEST_F(MindDataTestSubsetSampler, TestReset) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(5); | DummyRandomAccessOp dummyRandomAccessOp(5); | ||||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<int64_t> out; | std::vector<int64_t> out; | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -125,9 +121,9 @@ TEST_F(MindDataTestSubsetSampler, TestReset) { | |||||
| sampler.ResetSampler(); | sampler.ResetSampler(); | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), false); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), false); | |||||
| out.clear(); | out.clear(); | ||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | ||||
| @@ -139,6 +135,6 @@ TEST_F(MindDataTestSubsetSampler, TestReset) { | |||||
| ASSERT_EQ(in[i], out[i]); | ASSERT_EQ(in[i], out[i]); | ||||
| } | } | ||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include "minddata/dataset/include/constants.h" | #include "minddata/dataset/include/constants.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -55,11 +55,10 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | DummyRandomAccessOp dummyRandomAccessOp(total_samples); | ||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -69,8 +68,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { | |||||
| ASSERT_EQ(num_samples, out.size()); | ASSERT_EQ(num_samples, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { | TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { | ||||
| @@ -85,11 +84,10 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | DummyRandomAccessOp dummyRandomAccessOp(total_samples); | ||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -105,8 +103,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { | |||||
| } | } | ||||
| } | } | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { | TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { | ||||
| @@ -121,21 +119,20 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | DummyRandomAccessOp dummyRandomAccessOp(total_samples); | ||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| int epoch = 0; | int epoch = 0; | ||||
| while (!db->eoe()) { | |||||
| while (!row.eoe()) { | |||||
| epoch++; | epoch++; | ||||
| db->PopRow(&row); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| } | } | ||||
| } | } | ||||
| db.reset(); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| } | } | ||||
| ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer); | ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer); | ||||
| @@ -157,22 +154,21 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | DummyRandomAccessOp dummyRandomAccessOp(total_samples); | ||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| int epoch = 0; | int epoch = 0; | ||||
| while (!db->eoe()) { | |||||
| while (!row.eoe()) { | |||||
| epoch++; | epoch++; | ||||
| db->PopRow(&row); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| freq[*it]++; | freq[*it]++; | ||||
| } | } | ||||
| } | } | ||||
| db.reset(); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| } | } | ||||
| // Without replacement, each sample only drawn once. | // Without replacement, each sample only drawn once. | ||||
| @@ -198,11 +194,10 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | DummyRandomAccessOp dummyRandomAccessOp(total_samples); | ||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -211,14 +206,14 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||||
| } | } | ||||
| ASSERT_EQ(num_samples, out.size()); | ASSERT_EQ(num_samples, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| m_sampler.ResetSampler(); | m_sampler.ResetSampler(); | ||||
| out.clear(); | out.clear(); | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -227,8 +222,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||||
| } | } | ||||
| ASSERT_EQ(num_samples, out.size()); | ASSERT_EQ(num_samples, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | ||||
| @@ -243,11 +238,10 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | DummyRandomAccessOp dummyRandomAccessOp(total_samples); | ||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | ||||
| std::unique_ptr<DataBuffer> db; | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -256,8 +250,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||||
| } | } | ||||
| ASSERT_EQ(num_samples, out.size()); | ASSERT_EQ(num_samples, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| m_sampler.ResetSampler(); | m_sampler.ResetSampler(); | ||||
| out.clear(); | out.clear(); | ||||
| @@ -265,8 +259,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||||
| freq.resize(total_samples, 0); | freq.resize(total_samples, 0); | ||||
| MS_LOG(INFO) << "Resetting sampler"; | MS_LOG(INFO) << "Resetting sampler"; | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| out.push_back(*it); | out.push_back(*it); | ||||
| @@ -282,6 +276,6 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||||
| } | } | ||||
| } | } | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); | |||||
| ASSERT_EQ(row.eoe(), true); | |||||
| } | } | ||||
| @@ -28,7 +28,7 @@ | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "gtest/gtest.h" | #include "gtest/gtest.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -760,11 +760,11 @@ def test_cache_map_parameter_check(): | |||||
| with pytest.raises(TypeError) as info: | with pytest.raises(TypeError) as info: | ||||
| ds.DatasetCache(session_id="1", size=0) | ds.DatasetCache(session_id="1", size=0) | ||||
| assert "Argument session_id with value 1 is not of type (<class 'int'>,)" in str(info.value) | |||||
| assert "Argument session_id with value 1 is not of type [<class 'int'>]" in str(info.value) | |||||
| with pytest.raises(TypeError) as info: | with pytest.raises(TypeError) as info: | ||||
| ds.DatasetCache(session_id=None, size=0) | ds.DatasetCache(session_id=None, size=0) | ||||
| assert "Argument session_id with value None is not of type (<class 'int'>,)" in str(info.value) | |||||
| assert "Argument session_id with value None is not of type [<class 'int'>]" in str(info.value) | |||||
| with pytest.raises(ValueError) as info: | with pytest.raises(ValueError) as info: | ||||
| ds.DatasetCache(session_id=1, size=-1) | ds.DatasetCache(session_id=1, size=-1) | ||||
| @@ -772,19 +772,19 @@ def test_cache_map_parameter_check(): | |||||
| with pytest.raises(TypeError) as info: | with pytest.raises(TypeError) as info: | ||||
| ds.DatasetCache(session_id=1, size="1") | ds.DatasetCache(session_id=1, size="1") | ||||
| assert "Argument size with value 1 is not of type (<class 'int'>,)" in str(info.value) | |||||
| assert "Argument size with value 1 is not of type [<class 'int'>]" in str(info.value) | |||||
| with pytest.raises(TypeError) as info: | with pytest.raises(TypeError) as info: | ||||
| ds.DatasetCache(session_id=1, size=None) | ds.DatasetCache(session_id=1, size=None) | ||||
| assert "Argument size with value None is not of type (<class 'int'>,)" in str(info.value) | |||||
| assert "Argument size with value None is not of type [<class 'int'>]" in str(info.value) | |||||
| with pytest.raises(TypeError) as info: | with pytest.raises(TypeError) as info: | ||||
| ds.DatasetCache(session_id=1, size=0, spilling="illegal") | ds.DatasetCache(session_id=1, size=0, spilling="illegal") | ||||
| assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value) | |||||
| assert "Argument spilling with value illegal is not of type [<class 'bool'>]" in str(info.value) | |||||
| with pytest.raises(TypeError) as err: | with pytest.raises(TypeError) as err: | ||||
| ds.DatasetCache(session_id=1, size=0, hostname=50052) | ds.DatasetCache(session_id=1, size=0, hostname=50052) | ||||
| assert "Argument hostname with value 50052 is not of type (<class 'str'>,)" in str(err.value) | |||||
| assert "Argument hostname with value 50052 is not of type [<class 'str'>]" in str(err.value) | |||||
| with pytest.raises(RuntimeError) as err: | with pytest.raises(RuntimeError) as err: | ||||
| ds.DatasetCache(session_id=1, size=0, hostname="illegal") | ds.DatasetCache(session_id=1, size=0, hostname="illegal") | ||||
| @@ -796,11 +796,11 @@ def test_cache_map_parameter_check(): | |||||
| with pytest.raises(TypeError) as info: | with pytest.raises(TypeError) as info: | ||||
| ds.DatasetCache(session_id=1, size=0, port="illegal") | ds.DatasetCache(session_id=1, size=0, port="illegal") | ||||
| assert "Argument port with value illegal is not of type (<class 'int'>,)" in str(info.value) | |||||
| assert "Argument port with value illegal is not of type [<class 'int'>]" in str(info.value) | |||||
| with pytest.raises(TypeError) as info: | with pytest.raises(TypeError) as info: | ||||
| ds.DatasetCache(session_id=1, size=0, port="50052") | ds.DatasetCache(session_id=1, size=0, port="50052") | ||||
| assert "Argument port with value 50052 is not of type (<class 'int'>,)" in str(info.value) | |||||
| assert "Argument port with value 50052 is not of type [<class 'int'>]" in str(info.value) | |||||
| with pytest.raises(ValueError) as err: | with pytest.raises(ValueError) as err: | ||||
| ds.DatasetCache(session_id=1, size=0, port=0) | ds.DatasetCache(session_id=1, size=0, port=0) | ||||