Merge pull request !1983 from JesseKLee/samplertags/v0.5.0-beta
| @@ -263,7 +263,7 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) { | |||||
| Status CelebAOp::operator()() { | Status CelebAOp::operator()() { | ||||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | ||||
| std::unique_ptr<DataBuffer> data_buffer; | std::unique_ptr<DataBuffer> data_buffer; | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&data_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&data_buffer)); | |||||
| RETURN_IF_NOT_OK(AddIOBlock(&data_buffer)); | RETURN_IF_NOT_OK(AddIOBlock(&data_buffer)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -291,7 +291,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) { | |||||
| keys.clear(); | keys.clear(); | ||||
| } | } | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(data_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); | |||||
| } | } | ||||
| if (!keys.empty()) { | if (!keys.empty()) { | ||||
| @@ -313,7 +313,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) { | |||||
| io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | ||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(data_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -100,7 +100,7 @@ CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const | |||||
| Status CifarOp::operator()() { | Status CifarOp::operator()() { | ||||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | ||||
| std::unique_ptr<DataBuffer> sampler_buffer; | std::unique_ptr<DataBuffer> sampler_buffer; | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| while (true) { // each iterator is 1 epoch | while (true) { // each iterator is 1 epoch | ||||
| std::vector<int64_t> keys; | std::vector<int64_t> keys; | ||||
| keys.reserve(rows_per_buffer_); | keys.reserve(rows_per_buffer_); | ||||
| @@ -118,7 +118,7 @@ Status CifarOp::operator()() { | |||||
| keys.clear(); | keys.clear(); | ||||
| } | } | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| } | } | ||||
| if (keys.empty() == false) { | if (keys.empty() == false) { | ||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | ||||
| @@ -139,7 +139,7 @@ Status CifarOp::operator()() { | |||||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | ||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -126,7 +126,7 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) { | |||||
| Status ImageFolderOp::operator()() { | Status ImageFolderOp::operator()() { | ||||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | ||||
| std::unique_ptr<DataBuffer> sampler_buffer; | std::unique_ptr<DataBuffer> sampler_buffer; | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| while (true) { // each iterator is 1 epoch | while (true) { // each iterator is 1 epoch | ||||
| std::vector<int64_t> keys; | std::vector<int64_t> keys; | ||||
| keys.reserve(rows_per_buffer_); | keys.reserve(rows_per_buffer_); | ||||
| @@ -145,7 +145,7 @@ Status ImageFolderOp::operator()() { | |||||
| keys.clear(); | keys.clear(); | ||||
| } | } | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| } | } | ||||
| if (keys.empty() == false) { | if (keys.empty() == false) { | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| @@ -166,7 +166,7 @@ Status ImageFolderOp::operator()() { | |||||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | ||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -88,7 +88,7 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f | |||||
| Status ManifestOp::operator()() { | Status ManifestOp::operator()() { | ||||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | ||||
| std::unique_ptr<DataBuffer> sampler_buffer; | std::unique_ptr<DataBuffer> sampler_buffer; | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| return AddIoBlock(&sampler_buffer); | return AddIoBlock(&sampler_buffer); | ||||
| } | } | ||||
| @@ -110,7 +110,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) { | |||||
| keys.clear(); | keys.clear(); | ||||
| } | } | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); | |||||
| } | } | ||||
| if (keys.empty() == false) { | if (keys.empty() == false) { | ||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | ||||
| @@ -131,7 +131,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) { | |||||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | ||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -98,7 +98,7 @@ Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, st | |||||
| Status MnistOp::operator()() { | Status MnistOp::operator()() { | ||||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | ||||
| std::unique_ptr<DataBuffer> sampler_buffer; | std::unique_ptr<DataBuffer> sampler_buffer; | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| while (true) { // each iterator is 1 epoch | while (true) { // each iterator is 1 epoch | ||||
| std::vector<int64_t> keys; | std::vector<int64_t> keys; | ||||
| keys.reserve(rows_per_buffer_); | keys.reserve(rows_per_buffer_); | ||||
| @@ -109,7 +109,7 @@ Status MnistOp::operator()() { | |||||
| RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't UINT64"); | RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't UINT64"); | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys)); | RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys)); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| } | } | ||||
| if (keys.empty() == false) { | if (keys.empty() == false) { | ||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | ||||
| @@ -130,7 +130,7 @@ Status MnistOp::operator()() { | |||||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | ||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -55,14 +55,14 @@ Status DistributedSampler::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (cnt_ > samples_per_buffer_) { | if (cnt_ > samples_per_buffer_) { | ||||
| RETURN_STATUS_UNEXPECTED("Distributed Sampler Error"); | RETURN_STATUS_UNEXPECTED("Distributed Sampler Error"); | ||||
| } else if (cnt_ == samples_per_buffer_) { | } else if (cnt_ == samples_per_buffer_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone); | (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone); | ||||
| @@ -40,7 +40,7 @@ class DistributedSampler : public Sampler { | |||||
| // @param std::unique_ptr<DataBuffer> * pBuffer | // @param std::unique_ptr<DataBuffer> * pBuffer | ||||
| // @param int32_t workerId | // @param int32_t workerId | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| // Init sampler, called by base class or python | // Init sampler, called by base class or python | ||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| @@ -59,14 +59,14 @@ Status PKSampler::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (next_id_ > num_samples_ || num_samples_ == 0) { | if (next_id_ > num_samples_ || num_samples_ == 0) { | ||||
| RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler"); | RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler"); | ||||
| } else if (next_id_ == num_samples_) { | } else if (next_id_ == num_samples_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); | (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); | ||||
| @@ -41,7 +41,7 @@ class PKSampler : public Sampler { // NOT YET FINISHED | |||||
| // @param std::unique_ptr<DataBuffer pBuffer | // @param std::unique_ptr<DataBuffer pBuffer | ||||
| // @param int32_t workerId | // @param int32_t workerId | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | // first handshake between leaf source op and Sampler. This func will determine the amount of data | ||||
| // in the dataset that we can sample from. | // in the dataset that we can sample from. | ||||
| @@ -23,12 +23,12 @@ namespace dataset { | |||||
| PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) | PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) | ||||
| : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | ||||
| Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (need_to_reset_) { | if (need_to_reset_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||||
| } | } | ||||
| std::shared_ptr<Tensor> sample_ids; | std::shared_ptr<Tensor> sample_ids; | ||||
| @@ -48,7 +48,7 @@ class PythonSampler : public Sampler { | |||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| private: | private: | ||||
| bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() | bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() | ||||
| @@ -31,14 +31,14 @@ RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuff | |||||
| reshuffle_each_epoch_(reshuffle_each_epoch), | reshuffle_each_epoch_(reshuffle_each_epoch), | ||||
| dist(nullptr) {} | dist(nullptr) {} | ||||
| Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (next_id_ > num_samples_) { | if (next_id_ > num_samples_) { | ||||
| RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); | RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); | ||||
| } else if (next_id_ == num_samples_) { | } else if (next_id_ == num_samples_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); | (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); | ||||
| @@ -41,7 +41,7 @@ class RandomSampler : public Sampler { | |||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| // meant to be called by base class or python | // meant to be called by base class or python | ||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| @@ -33,11 +33,7 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { | |||||
| } | } | ||||
| Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) | Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) | ||||
| : DatasetOp(0), | |||||
| num_rows_(0), | |||||
| num_samples_(num_samples), | |||||
| samples_per_buffer_(samples_per_buffer), | |||||
| col_desc_(nullptr) {} | |||||
| : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} | |||||
| Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | ||||
| std::shared_ptr<Sampler> child_sampler; | std::shared_ptr<Sampler> child_sampler; | ||||
| @@ -97,7 +93,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { | |||||
| std::shared_ptr<Tensor> sample_ids; | std::shared_ptr<Tensor> sample_ids; | ||||
| // A call to derived class to get sample ids wrapped inside a buffer | // A call to derived class to get sample ids wrapped inside a buffer | ||||
| RETURN_IF_NOT_OK(GetNextBuffer(&db)); | |||||
| RETURN_IF_NOT_OK(GetNextSample(&db)); | |||||
| // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch | // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch | ||||
| RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0)); | RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0)); | ||||
| // check this buffer is not a ctrl buffer | // check this buffer is not a ctrl buffer | ||||
| @@ -114,7 +110,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { | |||||
| } | } | ||||
| } | } | ||||
| // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch | // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch | ||||
| RETURN_IF_NOT_OK(GetNextBuffer(&db)); | |||||
| RETURN_IF_NOT_OK(GetNextSample(&db)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); | CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); | ||||
| // Reset Sampler since this is the end of the epoch | // Reset Sampler since this is the end of the epoch | ||||
| RETURN_IF_NOT_OK(Reset()); | RETURN_IF_NOT_OK(Reset()); | ||||
| @@ -133,17 +129,7 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // inline op doesn't have it's own consumer, it's assigned from parent | |||||
| int32_t Sampler::num_consumers() const { | |||||
| if (parent_.empty() || parent_[0] == nullptr) { | |||||
| MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0."; | |||||
| return 0; | |||||
| } else { | |||||
| return parent_[0]->num_consumers(); | |||||
| } | |||||
| } | |||||
| Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) { | |||||
| Status Sampler::AddChild(std::shared_ptr<Sampler> child) { | |||||
| if (child == nullptr) { | if (child == nullptr) { | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -182,14 +168,5 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // inline op doesn't have it's own producers, it's assigned from child | |||||
| int32_t Sampler::num_producers() const { | |||||
| if (child_.empty() || child_[0] == nullptr) { | |||||
| MS_LOG(WARNING) << "Sampler with no child, num_producers is 0."; | |||||
| return 0; | |||||
| } else { | |||||
| return child_[0]->num_producers(); | |||||
| } | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -54,7 +54,7 @@ class RandomAccessOp { | |||||
| int64_t num_rows_; | int64_t num_rows_; | ||||
| }; | }; | ||||
| class Sampler : public DatasetOp { | |||||
| class Sampler { | |||||
| public: | public: | ||||
| // Constructor | // Constructor | ||||
| // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 | // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 | ||||
| @@ -70,14 +70,14 @@ class Sampler : public DatasetOp { | |||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0; | |||||
| virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0; | |||||
| // return all ids in one epoch as a numpy array, then call reset | // return all ids in one epoch as a numpy array, then call reset | ||||
| Status GetAllIdsThenReset(py::array *data); | Status GetAllIdsThenReset(py::array *data); | ||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status Reset() override = 0; | |||||
| virtual Status Reset() = 0; | |||||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | // first handshake between leaf source op and Sampler. This func will determine the amount of data | ||||
| // in the dataset that we can sample from. | // in the dataset that we can sample from. | ||||
| @@ -98,26 +98,10 @@ class Sampler : public DatasetOp { | |||||
| // @return status error code | // @return status error code | ||||
| Status SetNumRowsInDataset(int64_t num_rows); | Status SetNumRowsInDataset(int64_t num_rows); | ||||
| // Sampler is an inlined op and has no workers. Producers and consumers are computed. | |||||
| // @return | |||||
| int32_t num_workers() const final { return 0; } | |||||
| // Identify num consumers (inlined op) | |||||
| // @return | |||||
| int32_t num_consumers() const final; | |||||
| // Identify num producers (inlined op) | |||||
| // @return | |||||
| int32_t num_producers() const final; | |||||
| // Not meant to be called! | |||||
| // @return - The error code return | |||||
| Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); } | |||||
| // Adds a sampler to become our child. | // Adds a sampler to become our child. | ||||
| // @param std::shared_ptr<DatasetOp> - The sampler to add as a child. | // @param std::shared_ptr<DatasetOp> - The sampler to add as a child. | ||||
| // @return - The error code returned. | // @return - The error code returned. | ||||
| Status AddChild(std::shared_ptr<DatasetOp> child); | |||||
| Status AddChild(std::shared_ptr<Sampler> child); | |||||
| // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler | // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler | ||||
| // @param std::shared_ptr<Tensor>* sampleIds | // @param std::shared_ptr<Tensor>* sampleIds | ||||
| @@ -125,7 +109,7 @@ class Sampler : public DatasetOp { | |||||
| // @return - The error code returned. | // @return - The error code returned. | ||||
| Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); | Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); | ||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| virtual void Print(std::ostream &out, bool show_all) const; | |||||
| friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { | friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { | ||||
| sampler.Print(out, false); | sampler.Print(out, false); | ||||
| @@ -156,6 +140,7 @@ class Sampler : public DatasetOp { | |||||
| int64_t samples_per_buffer_; | int64_t samples_per_buffer_; | ||||
| std::unique_ptr<ColDescriptor> col_desc_; | std::unique_ptr<ColDescriptor> col_desc_; | ||||
| std::vector<std::shared_ptr<Sampler>> child_; // Child nodes | |||||
| std::unique_ptr<DataBuffer> child_ids_; | std::unique_ptr<DataBuffer> child_ids_; | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -23,14 +23,14 @@ namespace dataset { | |||||
| SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | ||||
| : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} | : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} | ||||
| Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (id_count_ > num_samples_) { | if (id_count_ > num_samples_) { | ||||
| RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); | RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); | ||||
| } else if (id_count_ == num_samples_) { | } else if (id_count_ == num_samples_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone); | (*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone); | ||||
| @@ -47,7 +47,7 @@ class SequentialSampler : public Sampler { | |||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| void Print(std::ostream &out, bool show_all) const override; | void Print(std::ostream &out, bool show_all) const override; | ||||
| @@ -72,13 +72,13 @@ Status SubsetRandomSampler::Reset() { | |||||
| } | } | ||||
| // Get the sample ids. | // Get the sample ids. | ||||
| Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| // All samples have been drawn | // All samples have been drawn | ||||
| if (sample_id_ == num_samples_) { | if (sample_id_ == num_samples_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | ||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); | (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); | ||||
| @@ -49,7 +49,7 @@ class SubsetRandomSampler : public Sampler { | |||||
| // Get the sample ids. | // Get the sample ids. | ||||
| // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | ||||
| // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| private: | private: | ||||
| // A list of indices (already randomized in constructor). | // A list of indices (already randomized in constructor). | ||||
| @@ -95,7 +95,7 @@ Status WeightedRandomSampler::Reset() { | |||||
| } | } | ||||
| // Get the sample ids. | // Get the sample ids. | ||||
| Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (weights_.size() > static_cast<size_t>(num_rows_)) { | if (weights_.size() > static_cast<size_t>(num_rows_)) { | ||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | ||||
| "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); | "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); | ||||
| @@ -109,7 +109,7 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf | |||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | ||||
| } else { | } else { | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||||
| } | } | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); | (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); | ||||
| @@ -51,7 +51,7 @@ class WeightedRandomSampler : public Sampler { | |||||
| // Get the sample ids. | // Get the sample ids. | ||||
| // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | ||||
| // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||||
| private: | private: | ||||
| // A list of weights for each sample. | // A list of weights for each sample. | ||||
| @@ -123,7 +123,7 @@ Status VOCOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std:: | |||||
| Status VOCOp::operator()() { | Status VOCOp::operator()() { | ||||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | ||||
| std::unique_ptr<DataBuffer> sampler_buffer; | std::unique_ptr<DataBuffer> sampler_buffer; | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| while (true) { | while (true) { | ||||
| std::vector<int64_t> keys; | std::vector<int64_t> keys; | ||||
| keys.reserve(rows_per_buffer_); | keys.reserve(rows_per_buffer_); | ||||
| @@ -134,7 +134,7 @@ Status VOCOp::operator()() { | |||||
| RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); | RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); | RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| } | } | ||||
| if (keys.empty() == false) { | if (keys.empty() == false) { | ||||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | ||||
| @@ -155,7 +155,7 @@ Status VOCOp::operator()() { | |||||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK(wp_.Wait()); | RETURN_IF_NOT_OK(wp_.Wait()); | ||||
| wp_.Clear(); | wp_.Clear(); | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -68,7 +68,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { | |||||
| for (int i = 0; i < 6; i++) { | for (int i = 0; i < 6; i++) { | ||||
| std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 3, i % 3, (i < 3 ? false : true)); | std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 3, i % 3, (i < 3 ? false : true)); | ||||
| sampler->HandshakeRandomAccessOp(&mock); | sampler->HandshakeRandomAccessOp(&mock); | ||||
| sampler->GetNextBuffer(&db); | |||||
| sampler->GetNextSample(&db); | |||||
| db->GetTensor(&tensor, 0, 0); | db->GetTensor(&tensor, 0, 0); | ||||
| MS_LOG(DEBUG) << (*tensor); | MS_LOG(DEBUG) << (*tensor); | ||||
| if(i < 3) { // This is added due to std::shuffle() | if(i < 3) { // This is added due to std::shuffle() | ||||
| @@ -90,17 +90,17 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) { | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| std::shared_ptr<Tensor> tensor; | std::shared_ptr<Tensor> tensor; | ||||
| sampler->HandshakeRandomAccessOp(&mock); | sampler->HandshakeRandomAccessOp(&mock); | ||||
| sampler->GetNextBuffer(&db); | |||||
| sampler->GetNextSample(&db); | |||||
| db->GetTensor(&tensor, 0, 0); | db->GetTensor(&tensor, 0, 0); | ||||
| EXPECT_TRUE((*tensor) == (*label1)); | EXPECT_TRUE((*tensor) == (*label1)); | ||||
| sampler->GetNextBuffer(&db); | |||||
| sampler->GetNextSample(&db); | |||||
| db->GetTensor(&tensor, 0, 0); | db->GetTensor(&tensor, 0, 0); | ||||
| EXPECT_TRUE((*tensor) == (*label2)); | EXPECT_TRUE((*tensor) == (*label2)); | ||||
| sampler->Reset(); | sampler->Reset(); | ||||
| sampler->GetNextBuffer(&db); | |||||
| sampler->GetNextSample(&db); | |||||
| db->GetTensor(&tensor, 0, 0); | db->GetTensor(&tensor, 0, 0); | ||||
| EXPECT_TRUE((*tensor) == (*label1)); | EXPECT_TRUE((*tensor) == (*label1)); | ||||
| sampler->GetNextBuffer(&db); | |||||
| sampler->GetNextSample(&db); | |||||
| db->GetTensor(&tensor, 0, 0); | db->GetTensor(&tensor, 0, 0); | ||||
| EXPECT_TRUE((*tensor) == (*label2)); | EXPECT_TRUE((*tensor) == (*label2)); | ||||
| } | } | ||||
| @@ -49,7 +49,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<int64_t> out; | std::vector<int64_t> out; | ||||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | db->PopRow(&row); | ||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | ||||
| @@ -61,7 +61,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { | |||||
| ASSERT_NE(in_set.find(out[i]), in_set.end()); | ASSERT_NE(in_set.find(out[i]), in_set.end()); | ||||
| } | } | ||||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | ASSERT_EQ(db->eoe(), true); | ||||
| } | } | ||||
| @@ -79,7 +79,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<int64_t> out; | std::vector<int64_t> out; | ||||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| int epoch = 0; | int epoch = 0; | ||||
| while (!db->eoe()) { | while (!db->eoe()) { | ||||
| epoch++; | epoch++; | ||||
| @@ -91,7 +91,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { | |||||
| } | } | ||||
| db.reset(); | db.reset(); | ||||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| } | } | ||||
| ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer); | ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer); | ||||
| @@ -111,7 +111,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<int64_t> out; | std::vector<int64_t> out; | ||||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | db->PopRow(&row); | ||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | ||||
| @@ -125,7 +125,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||||
| sampler.Reset(); | sampler.Reset(); | ||||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), false); | ASSERT_EQ(db->eoe(), false); | ||||
| db->PopRow(&row); | db->PopRow(&row); | ||||
| out.clear(); | out.clear(); | ||||
| @@ -139,6 +139,6 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||||
| ASSERT_NE(in_set.find(out[i]), in_set.end()); | ASSERT_NE(in_set.find(out[i]), in_set.end()); | ||||
| } | } | ||||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | ASSERT_EQ(db->eoe(), true); | ||||
| } | } | ||||
| @@ -58,7 +58,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | db->PopRow(&row); | ||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| @@ -69,7 +69,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { | |||||
| ASSERT_EQ(num_samples, out.size()); | ASSERT_EQ(num_samples, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | ASSERT_EQ(db->eoe(), true); | ||||
| } | } | ||||
| @@ -88,7 +88,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | db->PopRow(&row); | ||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| @@ -105,7 +105,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { | |||||
| } | } | ||||
| } | } | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | ASSERT_EQ(db->eoe(), true); | ||||
| } | } | ||||
| @@ -124,7 +124,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| int epoch = 0; | int epoch = 0; | ||||
| while (!db->eoe()) { | while (!db->eoe()) { | ||||
| epoch++; | epoch++; | ||||
| @@ -135,7 +135,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { | |||||
| } | } | ||||
| } | } | ||||
| db.reset(); | db.reset(); | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| } | } | ||||
| ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer); | ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer); | ||||
| @@ -160,7 +160,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| int epoch = 0; | int epoch = 0; | ||||
| while (!db->eoe()) { | while (!db->eoe()) { | ||||
| epoch++; | epoch++; | ||||
| @@ -172,7 +172,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { | |||||
| } | } | ||||
| } | } | ||||
| db.reset(); | db.reset(); | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| } | } | ||||
| // Without replacement, each sample only drawn once. | // Without replacement, each sample only drawn once. | ||||
| @@ -201,7 +201,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | db->PopRow(&row); | ||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| @@ -211,13 +211,13 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||||
| } | } | ||||
| ASSERT_EQ(num_samples, out.size()); | ASSERT_EQ(num_samples, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | ASSERT_EQ(db->eoe(), true); | ||||
| m_sampler.Reset(); | m_sampler.Reset(); | ||||
| out.clear(); | out.clear(); | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | db->PopRow(&row); | ||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| @@ -227,7 +227,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||||
| } | } | ||||
| ASSERT_EQ(num_samples, out.size()); | ASSERT_EQ(num_samples, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | ASSERT_EQ(db->eoe(), true); | ||||
| } | } | ||||
| @@ -246,7 +246,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| std::vector<uint64_t> out; | std::vector<uint64_t> out; | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | db->PopRow(&row); | ||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| @@ -256,7 +256,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||||
| } | } | ||||
| ASSERT_EQ(num_samples, out.size()); | ASSERT_EQ(num_samples, out.size()); | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | ASSERT_EQ(db->eoe(), true); | ||||
| m_sampler.Reset(); | m_sampler.Reset(); | ||||
| @@ -265,7 +265,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||||
| freq.resize(total_samples, 0); | freq.resize(total_samples, 0); | ||||
| MS_LOG(INFO) << "Resetting sampler"; | MS_LOG(INFO) << "Resetting sampler"; | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| db->PopRow(&row); | db->PopRow(&row); | ||||
| for (const auto &t : row) { | for (const auto &t : row) { | ||||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | ||||
| @@ -282,6 +282,6 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||||
| } | } | ||||
| } | } | ||||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||||
| ASSERT_EQ(db->eoe(), true); | ASSERT_EQ(db->eoe(), true); | ||||
| } | } | ||||