Merge pull request !1983 from JesseKLee/samplertags/v0.5.0-beta
| @@ -263,7 +263,7 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) { | |||
| Status CelebAOp::operator()() { | |||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | |||
| std::unique_ptr<DataBuffer> data_buffer; | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&data_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&data_buffer)); | |||
| RETURN_IF_NOT_OK(AddIOBlock(&data_buffer)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -291,7 +291,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) { | |||
| keys.clear(); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(data_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); | |||
| } | |||
| if (!keys.empty()) { | |||
| @@ -313,7 +313,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) { | |||
| io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(data_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); | |||
| } | |||
| } | |||
| } | |||
| @@ -100,7 +100,7 @@ CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const | |||
| Status CifarOp::operator()() { | |||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | |||
| std::unique_ptr<DataBuffer> sampler_buffer; | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| while (true) { // each iterator is 1 epoch | |||
| std::vector<int64_t> keys; | |||
| keys.reserve(rows_per_buffer_); | |||
| @@ -118,7 +118,7 @@ Status CifarOp::operator()() { | |||
| keys.clear(); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| if (keys.empty() == false) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | |||
| @@ -139,7 +139,7 @@ Status CifarOp::operator()() { | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| } | |||
| } | |||
| @@ -126,7 +126,7 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) { | |||
| Status ImageFolderOp::operator()() { | |||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | |||
| std::unique_ptr<DataBuffer> sampler_buffer; | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| while (true) { // each iterator is 1 epoch | |||
| std::vector<int64_t> keys; | |||
| keys.reserve(rows_per_buffer_); | |||
| @@ -145,7 +145,7 @@ Status ImageFolderOp::operator()() { | |||
| keys.clear(); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| if (keys.empty() == false) { | |||
| RETURN_IF_NOT_OK( | |||
| @@ -166,7 +166,7 @@ Status ImageFolderOp::operator()() { | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| } | |||
| } | |||
| @@ -88,7 +88,7 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f | |||
| Status ManifestOp::operator()() { | |||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | |||
| std::unique_ptr<DataBuffer> sampler_buffer; | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| return AddIoBlock(&sampler_buffer); | |||
| } | |||
| @@ -110,7 +110,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) { | |||
| keys.clear(); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); | |||
| } | |||
| if (keys.empty() == false) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | |||
| @@ -131,7 +131,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) { | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); | |||
| } | |||
| } | |||
| } | |||
| @@ -98,7 +98,7 @@ Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, st | |||
| Status MnistOp::operator()() { | |||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | |||
| std::unique_ptr<DataBuffer> sampler_buffer; | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| while (true) { // each iterator is 1 epoch | |||
| std::vector<int64_t> keys; | |||
| keys.reserve(rows_per_buffer_); | |||
| @@ -109,7 +109,7 @@ Status MnistOp::operator()() { | |||
| RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't UINT64"); | |||
| } | |||
| RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| if (keys.empty() == false) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | |||
| @@ -130,7 +130,7 @@ Status MnistOp::operator()() { | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| } | |||
| } | |||
| @@ -55,14 +55,14 @@ Status DistributedSampler::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (cnt_ > samples_per_buffer_) { | |||
| RETURN_STATUS_UNEXPECTED("Distributed Sampler Error"); | |||
| } else if (cnt_ == samples_per_buffer_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||
| } | |||
| (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone); | |||
| @@ -40,7 +40,7 @@ class DistributedSampler : public Sampler { | |||
| // @param std::unique_ptr<DataBuffer> * pBuffer | |||
| // @param int32_t workerId | |||
| // @return - The error code return | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| // Init sampler, called by base class or python | |||
| Status InitSampler() override; | |||
| @@ -59,14 +59,14 @@ Status PKSampler::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (next_id_ > num_samples_ || num_samples_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler"); | |||
| } else if (next_id_ == num_samples_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||
| } | |||
| (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); | |||
| @@ -41,7 +41,7 @@ class PKSampler : public Sampler { // NOT YET FINISHED | |||
| // @param std::unique_ptr<DataBuffer pBuffer | |||
| // @param int32_t workerId | |||
| // @return - The error code return | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | |||
| // in the dataset that we can sample from. | |||
| @@ -23,12 +23,12 @@ namespace dataset { | |||
| PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | |||
| Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (need_to_reset_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||
| } | |||
| std::shared_ptr<Tensor> sample_ids; | |||
| @@ -48,7 +48,7 @@ class PythonSampler : public Sampler { | |||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | |||
| // @param int32_t workerId - not meant to be used | |||
| // @return - The error code return | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| private: | |||
| bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() | |||
| @@ -31,14 +31,14 @@ RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuff | |||
| reshuffle_each_epoch_(reshuffle_each_epoch), | |||
| dist(nullptr) {} | |||
| Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (next_id_ > num_samples_) { | |||
| RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); | |||
| } else if (next_id_ == num_samples_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||
| } | |||
| (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); | |||
| @@ -41,7 +41,7 @@ class RandomSampler : public Sampler { | |||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | |||
| // @param int32_t workerId - not meant to be used | |||
| // @return - The error code return | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| // meant to be called by base class or python | |||
| Status InitSampler() override; | |||
| @@ -33,11 +33,7 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { | |||
| } | |||
| Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) | |||
| : DatasetOp(0), | |||
| num_rows_(0), | |||
| num_samples_(num_samples), | |||
| samples_per_buffer_(samples_per_buffer), | |||
| col_desc_(nullptr) {} | |||
| : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} | |||
| Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| std::shared_ptr<Sampler> child_sampler; | |||
| @@ -97,7 +93,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { | |||
| std::shared_ptr<Tensor> sample_ids; | |||
| // A call to derived class to get sample ids wrapped inside a buffer | |||
| RETURN_IF_NOT_OK(GetNextBuffer(&db)); | |||
| RETURN_IF_NOT_OK(GetNextSample(&db)); | |||
| // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch | |||
| RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0)); | |||
| // check this buffer is not a ctrl buffer | |||
| @@ -114,7 +110,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { | |||
| } | |||
| } | |||
| // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch | |||
| RETURN_IF_NOT_OK(GetNextBuffer(&db)); | |||
| RETURN_IF_NOT_OK(GetNextSample(&db)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); | |||
| // Reset Sampler since this is the end of the epoch | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| @@ -133,17 +129,7 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) { | |||
| return Status::OK(); | |||
| } | |||
| // inline op doesn't have it's own consumer, it's assigned from parent | |||
| int32_t Sampler::num_consumers() const { | |||
| if (parent_.empty() || parent_[0] == nullptr) { | |||
| MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0."; | |||
| return 0; | |||
| } else { | |||
| return parent_[0]->num_consumers(); | |||
| } | |||
| } | |||
| Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) { | |||
| Status Sampler::AddChild(std::shared_ptr<Sampler> child) { | |||
| if (child == nullptr) { | |||
| return Status::OK(); | |||
| } | |||
| @@ -182,14 +168,5 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { | |||
| return Status::OK(); | |||
| } | |||
| // inline op doesn't have it's own producers, it's assigned from child | |||
| int32_t Sampler::num_producers() const { | |||
| if (child_.empty() || child_[0] == nullptr) { | |||
| MS_LOG(WARNING) << "Sampler with no child, num_producers is 0."; | |||
| return 0; | |||
| } else { | |||
| return child_[0]->num_producers(); | |||
| } | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -54,7 +54,7 @@ class RandomAccessOp { | |||
| int64_t num_rows_; | |||
| }; | |||
| class Sampler : public DatasetOp { | |||
| class Sampler { | |||
| public: | |||
| // Constructor | |||
| // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 | |||
| @@ -70,14 +70,14 @@ class Sampler : public DatasetOp { | |||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | |||
| // @param int32_t workerId - not meant to be used | |||
| // @return - The error code return | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0; | |||
| virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0; | |||
| // return all ids in one epoch as a numpy array, then call reset | |||
| Status GetAllIdsThenReset(py::array *data); | |||
| // for next epoch of sampleIds | |||
| // @return - The error code return | |||
| Status Reset() override = 0; | |||
| virtual Status Reset() = 0; | |||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | |||
| // in the dataset that we can sample from. | |||
| @@ -98,26 +98,10 @@ class Sampler : public DatasetOp { | |||
| // @return status error code | |||
| Status SetNumRowsInDataset(int64_t num_rows); | |||
| // Sampler is an inlined op and has no workers. Producers and consumers are computed. | |||
| // @return | |||
| int32_t num_workers() const final { return 0; } | |||
| // Identify num consumers (inlined op) | |||
| // @return | |||
| int32_t num_consumers() const final; | |||
| // Identify num producers (inlined op) | |||
| // @return | |||
| int32_t num_producers() const final; | |||
| // Not meant to be called! | |||
| // @return - The error code return | |||
| Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); } | |||
| // Adds a sampler to become our child. | |||
| // @param std::shared_ptr<DatasetOp> - The sampler to add as a child. | |||
| // @return - The error code returned. | |||
| Status AddChild(std::shared_ptr<DatasetOp> child); | |||
| Status AddChild(std::shared_ptr<Sampler> child); | |||
| // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler | |||
| // @param std::shared_ptr<Tensor>* sampleIds | |||
| @@ -125,7 +109,7 @@ class Sampler : public DatasetOp { | |||
| // @return - The error code returned. | |||
| Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| virtual void Print(std::ostream &out, bool show_all) const; | |||
| friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { | |||
| sampler.Print(out, false); | |||
| @@ -156,6 +140,7 @@ class Sampler : public DatasetOp { | |||
| int64_t samples_per_buffer_; | |||
| std::unique_ptr<ColDescriptor> col_desc_; | |||
| std::vector<std::shared_ptr<Sampler>> child_; // Child nodes | |||
| std::unique_ptr<DataBuffer> child_ids_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -23,14 +23,14 @@ namespace dataset { | |||
| SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} | |||
| Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (id_count_ > num_samples_) { | |||
| RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); | |||
| } else if (id_count_ == num_samples_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||
| } | |||
| (*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone); | |||
| @@ -47,7 +47,7 @@ class SequentialSampler : public Sampler { | |||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | |||
| // @param int32_t workerId - not meant to be used | |||
| // @return - The error code return | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| @@ -72,13 +72,13 @@ Status SubsetRandomSampler::Reset() { | |||
| } | |||
| // Get the sample ids. | |||
| Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| // All samples have been drawn | |||
| if (sample_id_ == num_samples_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||
| } | |||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); | |||
| @@ -49,7 +49,7 @@ class SubsetRandomSampler : public Sampler { | |||
| // Get the sample ids. | |||
| // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | |||
| // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| private: | |||
| // A list of indices (already randomized in constructor). | |||
| @@ -95,7 +95,7 @@ Status WeightedRandomSampler::Reset() { | |||
| } | |||
| // Get the sample ids. | |||
| Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (weights_.size() > static_cast<size_t>(num_rows_)) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||
| "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); | |||
| @@ -109,7 +109,7 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf | |||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); | |||
| } | |||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); | |||
| @@ -51,7 +51,7 @@ class WeightedRandomSampler : public Sampler { | |||
| // Get the sample ids. | |||
| // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | |||
| // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| private: | |||
| // A list of weights for each sample. | |||
| @@ -123,7 +123,7 @@ Status VOCOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std:: | |||
| Status VOCOp::operator()() { | |||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | |||
| std::unique_ptr<DataBuffer> sampler_buffer; | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| while (true) { | |||
| std::vector<int64_t> keys; | |||
| keys.reserve(rows_per_buffer_); | |||
| @@ -134,7 +134,7 @@ Status VOCOp::operator()() { | |||
| RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); | |||
| } | |||
| RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| if (keys.empty() == false) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( | |||
| @@ -155,7 +155,7 @@ Status VOCOp::operator()() { | |||
| io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||
| wp_.Clear(); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| } | |||
| } | |||
| @@ -68,7 +68,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { | |||
| for (int i = 0; i < 6; i++) { | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 3, i % 3, (i < 3 ? false : true)); | |||
| sampler->HandshakeRandomAccessOp(&mock); | |||
| sampler->GetNextBuffer(&db); | |||
| sampler->GetNextSample(&db); | |||
| db->GetTensor(&tensor, 0, 0); | |||
| MS_LOG(DEBUG) << (*tensor); | |||
| if(i < 3) { // This is added due to std::shuffle() | |||
| @@ -90,17 +90,17 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| std::shared_ptr<Tensor> tensor; | |||
| sampler->HandshakeRandomAccessOp(&mock); | |||
| sampler->GetNextBuffer(&db); | |||
| sampler->GetNextSample(&db); | |||
| db->GetTensor(&tensor, 0, 0); | |||
| EXPECT_TRUE((*tensor) == (*label1)); | |||
| sampler->GetNextBuffer(&db); | |||
| sampler->GetNextSample(&db); | |||
| db->GetTensor(&tensor, 0, 0); | |||
| EXPECT_TRUE((*tensor) == (*label2)); | |||
| sampler->Reset(); | |||
| sampler->GetNextBuffer(&db); | |||
| sampler->GetNextSample(&db); | |||
| db->GetTensor(&tensor, 0, 0); | |||
| EXPECT_TRUE((*tensor) == (*label1)); | |||
| sampler->GetNextBuffer(&db); | |||
| sampler->GetNextSample(&db); | |||
| db->GetTensor(&tensor, 0, 0); | |||
| EXPECT_TRUE((*tensor) == (*label2)); | |||
| } | |||
| @@ -49,7 +49,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| TensorRow row; | |||
| std::vector<int64_t> out; | |||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||
| db->PopRow(&row); | |||
| for (const auto &t : row) { | |||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | |||
| @@ -61,7 +61,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { | |||
| ASSERT_NE(in_set.find(out[i]), in_set.end()); | |||
| } | |||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||
| ASSERT_EQ(db->eoe(), true); | |||
| } | |||
| @@ -79,7 +79,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { | |||
| TensorRow row; | |||
| std::vector<int64_t> out; | |||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||
| int epoch = 0; | |||
| while (!db->eoe()) { | |||
| epoch++; | |||
| @@ -91,7 +91,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { | |||
| } | |||
| db.reset(); | |||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||
| } | |||
| ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer); | |||
| @@ -111,7 +111,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||
| TensorRow row; | |||
| std::vector<int64_t> out; | |||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||
| db->PopRow(&row); | |||
| for (const auto &t : row) { | |||
| for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { | |||
| @@ -125,7 +125,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||
| sampler.Reset(); | |||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||
| ASSERT_EQ(db->eoe(), false); | |||
| db->PopRow(&row); | |||
| out.clear(); | |||
| @@ -139,6 +139,6 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||
| ASSERT_NE(in_set.find(out[i]), in_set.end()); | |||
| } | |||
| ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | |||
| ASSERT_EQ(db->eoe(), true); | |||
| } | |||
| @@ -58,7 +58,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| TensorRow row; | |||
| std::vector<uint64_t> out; | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| db->PopRow(&row); | |||
| for (const auto &t : row) { | |||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | |||
| @@ -69,7 +69,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { | |||
| ASSERT_EQ(num_samples, out.size()); | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| ASSERT_EQ(db->eoe(), true); | |||
| } | |||
| @@ -88,7 +88,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| TensorRow row; | |||
| std::vector<uint64_t> out; | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| db->PopRow(&row); | |||
| for (const auto &t : row) { | |||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | |||
| @@ -105,7 +105,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { | |||
| } | |||
| } | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| ASSERT_EQ(db->eoe(), true); | |||
| } | |||
| @@ -124,7 +124,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| TensorRow row; | |||
| std::vector<uint64_t> out; | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| int epoch = 0; | |||
| while (!db->eoe()) { | |||
| epoch++; | |||
| @@ -135,7 +135,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { | |||
| } | |||
| } | |||
| db.reset(); | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| } | |||
| ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer); | |||
| @@ -160,7 +160,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| TensorRow row; | |||
| std::vector<uint64_t> out; | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| int epoch = 0; | |||
| while (!db->eoe()) { | |||
| epoch++; | |||
| @@ -172,7 +172,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { | |||
| } | |||
| } | |||
| db.reset(); | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| } | |||
| // Without replacement, each sample only drawn once. | |||
| @@ -201,7 +201,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| TensorRow row; | |||
| std::vector<uint64_t> out; | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| db->PopRow(&row); | |||
| for (const auto &t : row) { | |||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | |||
| @@ -211,13 +211,13 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||
| } | |||
| ASSERT_EQ(num_samples, out.size()); | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| ASSERT_EQ(db->eoe(), true); | |||
| m_sampler.Reset(); | |||
| out.clear(); | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| db->PopRow(&row); | |||
| for (const auto &t : row) { | |||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | |||
| @@ -227,7 +227,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||
| } | |||
| ASSERT_EQ(num_samples, out.size()); | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| ASSERT_EQ(db->eoe(), true); | |||
| } | |||
| @@ -246,7 +246,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| TensorRow row; | |||
| std::vector<uint64_t> out; | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| db->PopRow(&row); | |||
| for (const auto &t : row) { | |||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | |||
| @@ -256,7 +256,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||
| } | |||
| ASSERT_EQ(num_samples, out.size()); | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| ASSERT_EQ(db->eoe(), true); | |||
| m_sampler.Reset(); | |||
| @@ -265,7 +265,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||
| freq.resize(total_samples, 0); | |||
| MS_LOG(INFO) << "Resetting sampler"; | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| db->PopRow(&row); | |||
| for (const auto &t : row) { | |||
| for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { | |||
| @@ -282,6 +282,6 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||
| } | |||
| } | |||
| ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); | |||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | |||
| ASSERT_EQ(db->eoe(), true); | |||
| } | |||