Merge pull request !2256 from JesseKLee/samplertags/v0.5.0-beta
| @@ -409,7 +409,7 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Reset Sampler and wakeup Master thread (functor) | // Reset Sampler and wakeup Master thread (functor) | ||||
| Status CelebAOp::Reset() { | Status CelebAOp::Reset() { | ||||
| RETURN_IF_NOT_OK(sampler_->Reset()); | |||||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||||
| wp_.Set(); // wake up master thread after reset is done | wp_.Set(); // wake up master thread after reset is done | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -241,7 +241,7 @@ void CifarOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Reset Sampler and wakeup Master thread (functor) | // Reset Sampler and wakeup Master thread (functor) | ||||
| Status CifarOp::Reset() { | Status CifarOp::Reset() { | ||||
| RETURN_IF_NOT_OK(sampler_->Reset()); | |||||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||||
| row_cnt_ = 0; | row_cnt_ = 0; | ||||
| wp_.Set(); // wake up master thread after reset is done | wp_.Set(); // wake up master thread after reset is done | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -207,7 +207,7 @@ void CocoOp::Print(std::ostream &out, bool show_all) const { | |||||
| } | } | ||||
| Status CocoOp::Reset() { | Status CocoOp::Reset() { | ||||
| RETURN_IF_NOT_OK(sampler_->Reset()); | |||||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||||
| row_cnt_ = 0; | row_cnt_ = 0; | ||||
| wp_.Set(); | wp_.Set(); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -252,7 +252,7 @@ void ImageFolderOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Reset Sampler and wakeup Master thread (functor) | // Reset Sampler and wakeup Master thread (functor) | ||||
| Status ImageFolderOp::Reset() { | Status ImageFolderOp::Reset() { | ||||
| RETURN_IF_NOT_OK(sampler_->Reset()); | |||||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||||
| row_cnt_ = 0; | row_cnt_ = 0; | ||||
| wp_.Set(); // wake up master thread after reset is done | wp_.Set(); // wake up master thread after reset is done | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -241,7 +241,7 @@ void ManifestOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Reset Sampler and wakeup Master thread (functor) | // Reset Sampler and wakeup Master thread (functor) | ||||
| Status ManifestOp::Reset() { | Status ManifestOp::Reset() { | ||||
| RETURN_IF_NOT_OK(sampler_->Reset()); | |||||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||||
| row_cnt_ = 0; | row_cnt_ = 0; | ||||
| wp_.Set(); // wake up master thread after reset is done | wp_.Set(); // wake up master thread after reset is done | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -204,7 +204,7 @@ void MnistOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Reset Sampler and wakeup Master thread (functor) | // Reset Sampler and wakeup Master thread (functor) | ||||
| Status MnistOp::Reset() { | Status MnistOp::Reset() { | ||||
| RETURN_IF_NOT_OK(sampler_->Reset()); | |||||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||||
| row_cnt_ = 0; | row_cnt_ = 0; | ||||
| wp_.Set(); // wake up master thread after reset is done | wp_.Set(); // wake up master thread after reset is done | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -89,7 +89,7 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DistributedSampler::Reset() { | |||||
| Status DistributedSampler::ResetSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); | CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); | ||||
| cnt_ = 0; | cnt_ = 0; | ||||
| @@ -100,7 +100,7 @@ Status DistributedSampler::Reset() { | |||||
| } | } | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->Reset()); | |||||
| RETURN_IF_NOT_OK(child_[0]->ResetSampler()); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -47,7 +47,7 @@ class DistributedSampler : public Sampler { | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status Reset() override; | |||||
| Status ResetSampler() override; | |||||
| void Print(std::ostream &out, bool show_all) const override; | void Print(std::ostream &out, bool show_all) const override; | ||||
| @@ -94,13 +94,13 @@ Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PKSampler::Reset() { | |||||
| Status PKSampler::ResetSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | ||||
| next_id_ = 0; | next_id_ = 0; | ||||
| rnd_.seed(seed_++); | rnd_.seed(seed_++); | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->Reset()); | |||||
| RETURN_IF_NOT_OK(child_[0]->ResetSampler()); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -54,7 +54,7 @@ class PKSampler : public Sampler { // NOT YET FINISHED | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status Reset() override; | |||||
| Status ResetSampler() override; | |||||
| private: | private: | ||||
| bool shuffle_; | bool shuffle_; | ||||
| @@ -84,7 +84,7 @@ Status PythonSampler::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PythonSampler::Reset() { | |||||
| Status PythonSampler::ResetSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); | CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); | ||||
| need_to_reset_ = false; | need_to_reset_ = false; | ||||
| py::gil_scoped_acquire gil_acquire; | py::gil_scoped_acquire gil_acquire; | ||||
| @@ -98,7 +98,7 @@ Status PythonSampler::Reset() { | |||||
| } | } | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->Reset()); | |||||
| RETURN_IF_NOT_OK(child_[0]->ResetSampler()); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -42,7 +42,7 @@ class PythonSampler : public Sampler { | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status Reset() override; | |||||
| Status ResetSampler() override; | |||||
| // Op calls this to get next Buffer that contains all the sampleIds | // Op calls this to get next Buffer that contains all the sampleIds | ||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| @@ -91,7 +91,7 @@ Status RandomSampler::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status RandomSampler::Reset() { | |||||
| Status RandomSampler::ResetSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | ||||
| next_id_ = 0; | next_id_ = 0; | ||||
| @@ -106,7 +106,7 @@ Status RandomSampler::Reset() { | |||||
| } | } | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->Reset()); | |||||
| RETURN_IF_NOT_OK(child_[0]->ResetSampler()); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -48,7 +48,7 @@ class RandomSampler : public Sampler { | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status Reset() override; | |||||
| Status ResetSampler() override; | |||||
| virtual void Print(std::ostream &out, bool show_all) const; | virtual void Print(std::ostream &out, bool show_all) const; | ||||
| @@ -113,7 +113,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { | |||||
| RETURN_IF_NOT_OK(GetNextSample(&db)); | RETURN_IF_NOT_OK(GetNextSample(&db)); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); | CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); | ||||
| // Reset Sampler since this is the end of the epoch | // Reset Sampler since this is the end of the epoch | ||||
| RETURN_IF_NOT_OK(Reset()); | |||||
| RETURN_IF_NOT_OK(ResetSampler()); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -62,6 +62,8 @@ class Sampler { | |||||
| // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| explicit Sampler(int64_t num_samples, int64_t samples_per_buffer); | explicit Sampler(int64_t num_samples, int64_t samples_per_buffer); | ||||
| Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {} | |||||
| // default destructor | // default destructor | ||||
| ~Sampler() = default; | ~Sampler() = default; | ||||
| @@ -77,7 +79,7 @@ class Sampler { | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| virtual Status Reset() = 0; | |||||
| virtual Status ResetSampler() = 0; | |||||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | // first handshake between leaf source op and Sampler. This func will determine the amount of data | ||||
| // in the dataset that we can sample from. | // in the dataset that we can sample from. | ||||
| @@ -109,8 +111,16 @@ class Sampler { | |||||
| // @return - The error code returned. | // @return - The error code returned. | ||||
| Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); | Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); | ||||
| // A print method typically used for debugging | |||||
| // @param out - The output stream to write output to | |||||
| // @param show_all - A bool to control if you want to show all info or just a summary | |||||
| virtual void Print(std::ostream &out, bool show_all) const; | virtual void Print(std::ostream &out, bool show_all) const; | ||||
| // << Stream output operator overload | |||||
| // @notes This allows you to write the debug print info using stream operators | |||||
| // @param out - reference to the output stream being overloaded | |||||
| // @param sampler - reference to teh sampler to print | |||||
| // @return - the output stream must be returned | |||||
| friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { | friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { | ||||
| sampler.Print(out, false); | sampler.Print(out, false); | ||||
| return out; | return out; | ||||
| @@ -77,13 +77,13 @@ Status SequentialSampler::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status SequentialSampler::Reset() { | |||||
| Status SequentialSampler::ResetSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); | CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); | ||||
| current_id_ = start_index_; | current_id_ = start_index_; | ||||
| id_count_ = 0; | id_count_ = 0; | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->Reset()); | |||||
| RETURN_IF_NOT_OK(child_[0]->ResetSampler()); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -41,7 +41,7 @@ class SequentialSampler : public Sampler { | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status Reset() override; | |||||
| Status ResetSampler() override; | |||||
| // Op calls this to get next Buffer that contains all the sampleIds | // Op calls this to get next Buffer that contains all the sampleIds | ||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| @@ -55,7 +55,7 @@ Status SubsetRandomSampler::InitSampler() { | |||||
| } | } | ||||
| // Reset the internal variable to the initial state. | // Reset the internal variable to the initial state. | ||||
| Status SubsetRandomSampler::Reset() { | |||||
| Status SubsetRandomSampler::ResetSampler() { | |||||
| // Reset the internal counters. | // Reset the internal counters. | ||||
| sample_id_ = 0; | sample_id_ = 0; | ||||
| buffer_id_ = 0; | buffer_id_ = 0; | ||||
| @@ -65,7 +65,7 @@ Status SubsetRandomSampler::Reset() { | |||||
| std::shuffle(indices_.begin(), indices_.end(), rand_gen_); | std::shuffle(indices_.begin(), indices_.end(), rand_gen_); | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->Reset()); | |||||
| RETURN_IF_NOT_OK(child_[0]->ResetSampler()); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -44,7 +44,7 @@ class SubsetRandomSampler : public Sampler { | |||||
| // Reset the internal variable to the initial state and reshuffle the indices. | // Reset the internal variable to the initial state and reshuffle the indices. | ||||
| // @return Status | // @return Status | ||||
| Status Reset() override; | |||||
| Status ResetSampler() override; | |||||
| // Get the sample ids. | // Get the sample ids. | ||||
| // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | ||||
| @@ -77,7 +77,7 @@ void WeightedRandomSampler::InitOnePassSampling() { | |||||
| } | } | ||||
| // Reset the internal variable to the initial state and reshuffle the indices. | // Reset the internal variable to the initial state and reshuffle the indices. | ||||
| Status WeightedRandomSampler::Reset() { | |||||
| Status WeightedRandomSampler::ResetSampler() { | |||||
| sample_id_ = 0; | sample_id_ = 0; | ||||
| buffer_id_ = 0; | buffer_id_ = 0; | ||||
| rand_gen_.seed(GetSeed()); | rand_gen_.seed(GetSeed()); | ||||
| @@ -88,7 +88,7 @@ Status WeightedRandomSampler::Reset() { | |||||
| } | } | ||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| RETURN_IF_NOT_OK(child_[0]->Reset()); | |||||
| RETURN_IF_NOT_OK(child_[0]->ResetSampler()); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -46,7 +46,7 @@ class WeightedRandomSampler : public Sampler { | |||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| // Reset the internal variable to the initial state and reshuffle the indices. | // Reset the internal variable to the initial state and reshuffle the indices. | ||||
| Status Reset() override; | |||||
| Status ResetSampler() override; | |||||
| // Get the sample ids. | // Get the sample ids. | ||||
| // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | ||||
| @@ -177,7 +177,7 @@ void VOCOp::Print(std::ostream &out, bool show_all) const { | |||||
| } | } | ||||
| Status VOCOp::Reset() { | Status VOCOp::Reset() { | ||||
| RETURN_IF_NOT_OK(sampler_->Reset()); | |||||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||||
| row_cnt_ = 0; | row_cnt_ = 0; | ||||
| wp_.Set(); | wp_.Set(); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -94,7 +94,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) { | |||||
| sampler->GetNextSample(&db); | sampler->GetNextSample(&db); | ||||
| db->GetTensor(&tensor, 0, 0); | db->GetTensor(&tensor, 0, 0); | ||||
| EXPECT_TRUE((*tensor) == (*label2)); | EXPECT_TRUE((*tensor) == (*label2)); | ||||
| sampler->Reset(); | |||||
| sampler->ResetSampler(); | |||||
| sampler->GetNextSample(&db); | sampler->GetNextSample(&db); | ||||
| db->GetTensor(&tensor, 0, 0); | db->GetTensor(&tensor, 0, 0); | ||||
| EXPECT_TRUE((*tensor) == (*label1)); | EXPECT_TRUE((*tensor) == (*label1)); | ||||
| @@ -123,7 +123,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||||
| ASSERT_NE(in_set.find(out[i]), in_set.end()); | ASSERT_NE(in_set.find(out[i]), in_set.end()); | ||||
| } | } | ||||
| sampler.Reset(); | |||||
| sampler.ResetSampler(); | |||||
| ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); | ||||
| ASSERT_EQ(db->eoe(), false); | ASSERT_EQ(db->eoe(), false); | ||||
| @@ -214,7 +214,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | ||||
| ASSERT_EQ(db->eoe(), true); | ASSERT_EQ(db->eoe(), true); | ||||
| m_sampler.Reset(); | |||||
| m_sampler.ResetSampler(); | |||||
| out.clear(); | out.clear(); | ||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | ||||
| @@ -259,7 +259,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||||
| ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); | ||||
| ASSERT_EQ(db->eoe(), true); | ASSERT_EQ(db->eoe(), true); | ||||
| m_sampler.Reset(); | |||||
| m_sampler.ResetSampler(); | |||||
| out.clear(); | out.clear(); | ||||
| freq.clear(); | freq.clear(); | ||||
| freq.resize(total_samples, 0); | freq.resize(total_samples, 0); | ||||