Browse Source

!2256 Rename Sampler::Reset() to Sampler::ResetSampler()

Merge pull request !2256 from JesseKLee/sampler
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
a4048e192c
26 changed files with 44 additions and 34 deletions
  1. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
  2. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
  3. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc
  4. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
  5. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
  6. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
  7. +2
    -2
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
  8. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h
  9. +2
    -2
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
  10. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h
  11. +2
    -2
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc
  12. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h
  13. +2
    -2
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc
  14. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h
  15. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
  16. +11
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
  17. +2
    -2
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
  18. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h
  19. +2
    -2
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
  20. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h
  21. +2
    -2
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
  22. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
  23. +1
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
  24. +1
    -1
      tests/ut/cpp/dataset/stand_alone_samplers_test.cc
  25. +1
    -1
      tests/ut/cpp/dataset/subset_random_sampler_test.cc
  26. +2
    -2
      tests/ut/cpp/dataset/weighted_random_sampler_test.cc

+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc View File

@@ -409,7 +409,7 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const {


// Reset Sampler and wakeup Master thread (functor) // Reset Sampler and wakeup Master thread (functor)
Status CelebAOp::Reset() { Status CelebAOp::Reset() {
RETURN_IF_NOT_OK(sampler_->Reset());
RETURN_IF_NOT_OK(sampler_->ResetSampler());
wp_.Set(); // wake up master thread after reset is done wp_.Set(); // wake up master thread after reset is done
return Status::OK(); return Status::OK();
} }


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc View File

@@ -241,7 +241,7 @@ void CifarOp::Print(std::ostream &out, bool show_all) const {


// Reset Sampler and wakeup Master thread (functor) // Reset Sampler and wakeup Master thread (functor)
Status CifarOp::Reset() { Status CifarOp::Reset() {
RETURN_IF_NOT_OK(sampler_->Reset());
RETURN_IF_NOT_OK(sampler_->ResetSampler());
row_cnt_ = 0; row_cnt_ = 0;
wp_.Set(); // wake up master thread after reset is done wp_.Set(); // wake up master thread after reset is done
return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc View File

@@ -207,7 +207,7 @@ void CocoOp::Print(std::ostream &out, bool show_all) const {
} }


Status CocoOp::Reset() { Status CocoOp::Reset() {
RETURN_IF_NOT_OK(sampler_->Reset());
RETURN_IF_NOT_OK(sampler_->ResetSampler());
row_cnt_ = 0; row_cnt_ = 0;
wp_.Set(); wp_.Set();
return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc View File

@@ -252,7 +252,7 @@ void ImageFolderOp::Print(std::ostream &out, bool show_all) const {


// Reset Sampler and wakeup Master thread (functor) // Reset Sampler and wakeup Master thread (functor)
Status ImageFolderOp::Reset() { Status ImageFolderOp::Reset() {
RETURN_IF_NOT_OK(sampler_->Reset());
RETURN_IF_NOT_OK(sampler_->ResetSampler());
row_cnt_ = 0; row_cnt_ = 0;
wp_.Set(); // wake up master thread after reset is done wp_.Set(); // wake up master thread after reset is done
return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc View File

@@ -241,7 +241,7 @@ void ManifestOp::Print(std::ostream &out, bool show_all) const {


// Reset Sampler and wakeup Master thread (functor) // Reset Sampler and wakeup Master thread (functor)
Status ManifestOp::Reset() { Status ManifestOp::Reset() {
RETURN_IF_NOT_OK(sampler_->Reset());
RETURN_IF_NOT_OK(sampler_->ResetSampler());
row_cnt_ = 0; row_cnt_ = 0;
wp_.Set(); // wake up master thread after reset is done wp_.Set(); // wake up master thread after reset is done
return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc View File

@@ -204,7 +204,7 @@ void MnistOp::Print(std::ostream &out, bool show_all) const {


// Reset Sampler and wakeup Master thread (functor) // Reset Sampler and wakeup Master thread (functor)
Status MnistOp::Reset() { Status MnistOp::Reset() {
RETURN_IF_NOT_OK(sampler_->Reset());
RETURN_IF_NOT_OK(sampler_->ResetSampler());
row_cnt_ = 0; row_cnt_ = 0;
wp_.Set(); // wake up master thread after reset is done wp_.Set(); // wake up master thread after reset is done
return Status::OK(); return Status::OK();


+ 2
- 2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc View File

@@ -89,7 +89,7 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
return Status::OK(); return Status::OK();
} }


Status DistributedSampler::Reset() {
Status DistributedSampler::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
cnt_ = 0; cnt_ = 0;


@@ -100,7 +100,7 @@ Status DistributedSampler::Reset() {
} }


if (HasChildSampler()) { if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
} }


return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h View File

@@ -47,7 +47,7 @@ class DistributedSampler : public Sampler {


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return
Status Reset() override;
Status ResetSampler() override;


void Print(std::ostream &out, bool show_all) const override; void Print(std::ostream &out, bool show_all) const override;




+ 2
- 2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc View File

@@ -94,13 +94,13 @@ Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
return Status::OK(); return Status::OK();
} }


Status PKSampler::Reset() {
Status PKSampler::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0; next_id_ = 0;
rnd_.seed(seed_++); rnd_.seed(seed_++);


if (HasChildSampler()) { if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
} }


return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h View File

@@ -54,7 +54,7 @@ class PKSampler : public Sampler { // NOT YET FINISHED


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return
Status Reset() override;
Status ResetSampler() override;


private: private:
bool shuffle_; bool shuffle_;


+ 2
- 2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc View File

@@ -84,7 +84,7 @@ Status PythonSampler::InitSampler() {
return Status::OK(); return Status::OK();
} }


Status PythonSampler::Reset() {
Status PythonSampler::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch");
need_to_reset_ = false; need_to_reset_ = false;
py::gil_scoped_acquire gil_acquire; py::gil_scoped_acquire gil_acquire;
@@ -98,7 +98,7 @@ Status PythonSampler::Reset() {
} }


if (HasChildSampler()) { if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
} }


return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h View File

@@ -42,7 +42,7 @@ class PythonSampler : public Sampler {


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return
Status Reset() override;
Status ResetSampler() override;


// Op calls this to get next Buffer that contains all the sampleIds // Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp


+ 2
- 2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc View File

@@ -91,7 +91,7 @@ Status RandomSampler::InitSampler() {
return Status::OK(); return Status::OK();
} }


Status RandomSampler::Reset() {
Status RandomSampler::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0; next_id_ = 0;


@@ -106,7 +106,7 @@ Status RandomSampler::Reset() {
} }


if (HasChildSampler()) { if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
} }


return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h View File

@@ -48,7 +48,7 @@ class RandomSampler : public Sampler {


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return
Status Reset() override;
Status ResetSampler() override;


virtual void Print(std::ostream &out, bool show_all) const; virtual void Print(std::ostream &out, bool show_all) const;




+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc View File

@@ -113,7 +113,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
RETURN_IF_NOT_OK(GetNextSample(&db)); RETURN_IF_NOT_OK(GetNextSample(&db));
CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received");
// Reset Sampler since this is the end of the epoch // Reset Sampler since this is the end of the epoch
RETURN_IF_NOT_OK(Reset());
RETURN_IF_NOT_OK(ResetSampler());
return Status::OK(); return Status::OK();
} }




+ 11
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h View File

@@ -62,6 +62,8 @@ class Sampler {
// @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit Sampler(int64_t num_samples, int64_t samples_per_buffer); explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);


Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {}

// default destructor // default destructor
~Sampler() = default; ~Sampler() = default;


@@ -77,7 +79,7 @@ class Sampler {


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return
virtual Status Reset() = 0;
virtual Status ResetSampler() = 0;


// first handshake between leaf source op and Sampler. This func will determine the amount of data // first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from. // in the dataset that we can sample from.
@@ -109,8 +111,16 @@ class Sampler {
// @return - The error code returned. // @return - The error code returned.
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);


// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
virtual void Print(std::ostream &out, bool show_all) const; virtual void Print(std::ostream &out, bool show_all) const;


// << Stream output operator overload
// @notes This allows you to write the debug print info using stream operators
// @param out - reference to the output stream being overloaded
// @param sampler - reference to teh sampler to print
// @return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
sampler.Print(out, false); sampler.Print(out, false);
return out; return out;


+ 2
- 2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc View File

@@ -77,13 +77,13 @@ Status SequentialSampler::InitSampler() {
return Status::OK(); return Status::OK();
} }


Status SequentialSampler::Reset() {
Status SequentialSampler::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late");
current_id_ = start_index_; current_id_ = start_index_;
id_count_ = 0; id_count_ = 0;


if (HasChildSampler()) { if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
} }


return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h View File

@@ -41,7 +41,7 @@ class SequentialSampler : public Sampler {


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return
Status Reset() override;
Status ResetSampler() override;


// Op calls this to get next Buffer that contains all the sampleIds // Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp


+ 2
- 2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc View File

@@ -55,7 +55,7 @@ Status SubsetRandomSampler::InitSampler() {
} }


// Reset the internal variable to the initial state. // Reset the internal variable to the initial state.
Status SubsetRandomSampler::Reset() {
Status SubsetRandomSampler::ResetSampler() {
// Reset the internal counters. // Reset the internal counters.
sample_id_ = 0; sample_id_ = 0;
buffer_id_ = 0; buffer_id_ = 0;
@@ -65,7 +65,7 @@ Status SubsetRandomSampler::Reset() {
std::shuffle(indices_.begin(), indices_.end(), rand_gen_); std::shuffle(indices_.begin(), indices_.end(), rand_gen_);


if (HasChildSampler()) { if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
} }


return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h View File

@@ -44,7 +44,7 @@ class SubsetRandomSampler : public Sampler {


// Reset the internal variable to the initial state and reshuffle the indices. // Reset the internal variable to the initial state and reshuffle the indices.
// @return Status // @return Status
Status Reset() override;
Status ResetSampler() override;


// Get the sample ids. // Get the sample ids.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.


+ 2
- 2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc View File

@@ -77,7 +77,7 @@ void WeightedRandomSampler::InitOnePassSampling() {
} }


// Reset the internal variable to the initial state and reshuffle the indices. // Reset the internal variable to the initial state and reshuffle the indices.
Status WeightedRandomSampler::Reset() {
Status WeightedRandomSampler::ResetSampler() {
sample_id_ = 0; sample_id_ = 0;
buffer_id_ = 0; buffer_id_ = 0;
rand_gen_.seed(GetSeed()); rand_gen_.seed(GetSeed());
@@ -88,7 +88,7 @@ Status WeightedRandomSampler::Reset() {
} }


if (HasChildSampler()) { if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
} }


return Status::OK(); return Status::OK();


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h View File

@@ -46,7 +46,7 @@ class WeightedRandomSampler : public Sampler {
Status InitSampler() override; Status InitSampler() override;


// Reset the internal variable to the initial state and reshuffle the indices. // Reset the internal variable to the initial state and reshuffle the indices.
Status Reset() override;
Status ResetSampler() override;


// Get the sample ids. // Get the sample ids.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.


+ 1
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc View File

@@ -177,7 +177,7 @@ void VOCOp::Print(std::ostream &out, bool show_all) const {
} }


Status VOCOp::Reset() { Status VOCOp::Reset() {
RETURN_IF_NOT_OK(sampler_->Reset());
RETURN_IF_NOT_OK(sampler_->ResetSampler());
row_cnt_ = 0; row_cnt_ = 0;
wp_.Set(); wp_.Set();
return Status::OK(); return Status::OK();


+ 1
- 1
tests/ut/cpp/dataset/stand_alone_samplers_test.cc View File

@@ -94,7 +94,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
sampler->GetNextSample(&db); sampler->GetNextSample(&db);
db->GetTensor(&tensor, 0, 0); db->GetTensor(&tensor, 0, 0);
EXPECT_TRUE((*tensor) == (*label2)); EXPECT_TRUE((*tensor) == (*label2));
sampler->Reset();
sampler->ResetSampler();
sampler->GetNextSample(&db); sampler->GetNextSample(&db);
db->GetTensor(&tensor, 0, 0); db->GetTensor(&tensor, 0, 0);
EXPECT_TRUE((*tensor) == (*label1)); EXPECT_TRUE((*tensor) == (*label1));


+ 1
- 1
tests/ut/cpp/dataset/subset_random_sampler_test.cc View File

@@ -123,7 +123,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
ASSERT_NE(in_set.find(out[i]), in_set.end()); ASSERT_NE(in_set.find(out[i]), in_set.end());
} }


sampler.Reset();
sampler.ResetSampler();


ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), false); ASSERT_EQ(db->eoe(), false);


+ 2
- 2
tests/ut/cpp/dataset/weighted_random_sampler_test.cc View File

@@ -214,7 +214,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true); ASSERT_EQ(db->eoe(), true);


m_sampler.Reset();
m_sampler.ResetSampler();
out.clear(); out.clear();


ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
@@ -259,7 +259,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true); ASSERT_EQ(db->eoe(), true);


m_sampler.Reset();
m_sampler.ResetSampler();
out.clear(); out.clear();
freq.clear(); freq.clear();
freq.resize(total_samples, 0); freq.resize(total_samples, 0);


Loading…
Cancel
Save