diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
index 44a9806e29..a310c50fae 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
@@ -505,8 +505,8 @@ Status CelebAOp::GetDatasetSize(int64_t *dataset_size) {
     num_rows = std::min(num_rows, partition_num);
   }
 
-  sample_size = sampler_->GetNumSamples();
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
+  sample_size = sampler_->CalculateNumSamples(num_rows);
+  *dataset_size = sample_size;
   return Status::OK();
 }
 }  // namespace dataset
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc
index ade39e0e2e..ca4b4b7f31 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc
@@ -518,8 +518,8 @@ Status CifarOp::GetDatasetSize(int64_t *dataset_size) {
   num_rows = num_rows_;
   if (num_rows_ <= 0)
     RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &num_rows));
-  sample_size = sampler_->GetNumSamples();
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
+  sample_size = sampler_->CalculateNumSamples(num_rows);
+  *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
 }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc
index c9adcc79de..c41b8603f3 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc
@@ -705,8 +705,8 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
   if (image_ids_.size() == 0) {
     RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
   }
-  sample_size = sampler_->GetNumSamples();
-  *dataset_size = sample_size != 0 ? std::min(num_rows, sample_size) : num_rows;
+  sample_size = sampler_->CalculateNumSamples(num_rows);
+  *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
 }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
index a3588c1146..b8b0b3c0a1 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
@@ -475,8 +475,8 @@ Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) {
     // GetDatasetSize will not be impacted by class_index_
     RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, nullptr, {}));
   }
-  sample_size = sampler_->GetNumSamples();
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
+  sample_size = sampler_->CalculateNumSamples(num_rows);
+  *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
 }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
index 8e0450ea32..24c48a4b07 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
@@ -469,8 +469,8 @@ Status ManifestOp::GetDatasetSize(int64_t *dataset_size) {
   RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
   RETURN_IF_NOT_OK(op->ParseManifestFile());
   num_rows = static_cast<int64_t>(op->image_labelname_.size());
-  sample_size = sampler_->GetNumSamples();
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
+  sample_size = sampler_->CalculateNumSamples(num_rows);
+  *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
 }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc
index 72ce81cdb5..45e99b6ac7 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc
@@ -480,8 +480,8 @@ Status MnistOp::GetDatasetSize(int64_t *dataset_size) {
   int64_t num_rows, sample_size;
   num_rows = num_rows_;
   if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, &num_rows));
-  sample_size = sampler_->GetNumSamples();
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
+  sample_size = sampler_->CalculateNumSamples(num_rows);
+  *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
 }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc
index 75c43d8c61..99a67e355b 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc
@@ -427,10 +427,13 @@ Status RandomDataOp::GetDatasetSize(int64_t *dataset_size) {
     *dataset_size = dataset_size_;
     return Status::OK();
   }
-  int64_t num_rows, sample_size = 0;
+  int64_t num_rows;
   num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
-  if (sampler_ != nullptr) sample_size = sampler_->GetNumSamples();
-  *dataset_size = sample_size != 0 ? std::min(num_rows, sample_size) : num_rows;
+  if (sampler_ != nullptr) {
+    *dataset_size = sampler_->CalculateNumSamples(num_rows);
+  } else {
+    *dataset_size = num_rows;
+  }
   dataset_size_ = *dataset_size;
   return Status::OK();
 }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
index 1f4671a092..bcb046257c 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
@@ -15,6 +15,7 @@
  */
 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
 
+#include <algorithm>
 #include <limits>
 #include <memory>
 
@@ -160,6 +161,23 @@ Status DistributedSamplerRT::ResetSampler() {
   return Status::OK();
 }
 
+// Calculate the number of samples for a dataset of num_rows rows: the count reported by the first child
+// sampler (num_rows when there is no child), capped by num_samples_ when that is positive, then divided by
+// num_devices_ and rounded up.
+// @param num_rows - the number of rows in the dataset
+// @return number of samples
+int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
+  int64_t child_num_samples = num_rows;
+  if (!child_.empty()) {
+    child_num_samples = child_[0]->CalculateNumSamples(num_rows);
+  }
+  const int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_samples, num_samples_) : child_num_samples;
+  // Integer ceiling division: going through double loses precision above 2^53 and needs an implicit
+  // double -> int64_t conversion on return.
+  // NOTE(review): assumes num_devices_ > 0 - confirm it is validated before this is called.
+  return num_samples / num_devices_ + ((num_samples % num_devices_ > 0) ? 1 : 0);
+}
+
 void DistributedSamplerRT::Print(std::ostream &out, bool show_all) const {
   out << "\nSampler: DistributedSampler";
   if (show_all) {
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
index 015ad23fd3..288b7eeb55 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
@@ -63,6 +63,11 @@ class DistributedSamplerRT : public SamplerRT {
 
   int64_t GetDeviceNum() { return num_devices_; }
 
+  // Calculate num samples, taking the division across num_devices_ into account
+  // @param num_rows - the number of rows in the dataset
+  // @return number of samples
+  int64_t CalculateNumSamples(int64_t num_rows) override;
+
   void Print(std::ostream &out, bool show_all) const override;
 
  private:
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc
index 7441b62771..a754f1eee7 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc
@@ -15,6 +15,7 @@
  */
 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
 
+#include <algorithm>
 #include <string>
 
 namespace mindspore {
@@ -131,6 +132,17 @@ Status SamplerRT::SetNumSamples(int64_t num_samples) {
 
 int64_t SamplerRT::GetNumSamples() { return num_samples_; }
 
+int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) {
+  // Start from num_rows, or from what the first child sampler reports when samplers are chained.
+  int64_t child_num_samples = num_rows;
+  if (!child_.empty()) {
+    child_num_samples = child_[0]->CalculateNumSamples(num_rows);
+  }
+
+  // num_samples_ only caps the result when it is positive; otherwise the child count passes through.
+  return (num_samples_ > 0) ? std::min(child_num_samples, num_samples_) : child_num_samples;
+}
+
 Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) {
   CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0.");
   num_rows_ = num_rows;
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
index 76a8dee4a8..00c48caf2b 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
@@ -99,10 +99,15 @@ class SamplerRT {
   Status SetNumSamples(int64_t num_samples);
 
   // getter for num samples
-  // @param num_samples - the number of samples to return.
-  // @return status error code
+  // @return number of samples
   int64_t GetNumSamples();
 
+  // Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of
+  // num_samples_
+  // @param num_rows - the number of rows in the dataset
+  // @return number of samples
+  virtual int64_t CalculateNumSamples(int64_t num_rows);
+
   // setter for num or records in the dataset
   // @param num_rows - the number of records
   // @return status error code
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc
index a68b4fc59d..0e20db5252 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc
@@ -537,8 +537,8 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
       num_rows = static_cast<int64_t>(op->image_ids_.size());
     }
   }
-  sample_size = sampler_->GetNumSamples();
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
+  sample_size = sampler_->CalculateNumSamples(num_rows);
+  *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
 }
diff --git a/tests/ut/cpp/dataset/c_api_samplers_test.cc b/tests/ut/cpp/dataset/c_api_samplers_test.cc
index b4931d0e3d..01a3092f22 100644
--- a/tests/ut/cpp/dataset/c_api_samplers_test.cc
+++ b/tests/ut/cpp/dataset/c_api_samplers_test.cc
@@ -82,6 +82,81 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) {
   iter->Stop();
 }
 
+TEST_F(MindDataTestPipeline, TestCalculateNumSamples) {
+  int64_t num_rows = 30;  // dummy variable for number of rows in the dataset
+  std::shared_ptr<SamplerObj> sampl = DistributedSampler(2, 1, false, 6);
+  EXPECT_NE(sampl, nullptr);
+  std::shared_ptr<SamplerRT> sampler_rt = sampl->Build();
+  EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 3);
+
+  // num_samples not divisible by the number of shards: the result is rounded up
+  sampl = DistributedSampler(2, 1, false, 7);
+  EXPECT_NE(sampl, nullptr);
+  sampler_rt = sampl->Build();
+  EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 4);
+
+  sampl = PKSampler(3, false);
+  EXPECT_NE(sampl, nullptr);
+  sampler_rt = sampl->Build();
+  EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);
+
+  sampl = RandomSampler(false, 12);
+  EXPECT_NE(sampl, nullptr);
+  sampler_rt = sampl->Build();
+  EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
+
+  sampl = SequentialSampler(0, 10);
+  EXPECT_NE(sampl, nullptr);
+  sampler_rt = sampl->Build();
+  EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10);
+
+  std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
+  sampl = WeightedRandomSampler(weights, 12);
+  EXPECT_NE(sampl, nullptr);
+  sampler_rt = sampl->Build();
+  EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
+
+  std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21};
+  sampl = SubsetRandomSampler(indices, 11);
+  EXPECT_NE(sampl, nullptr);
+  sampler_rt = sampl->Build();
+  EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11);
+
+  // Testing chains
+  // Parent and child have num_samples
+  std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights, 12);
+  EXPECT_NE(sampl1, nullptr);
+  std::shared_ptr<SamplerRT> sampler_rt1 = sampl1->Build();
+
+  std::shared_ptr<SamplerObj> sampl2 = SequentialSampler(0, 10);
+  EXPECT_NE(sampl2, nullptr);
+  std::shared_ptr<SamplerRT> sampler_rt2 = sampl2->Build();
+  sampler_rt2->AddChild(sampler_rt1);
+  EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10);
+
+  // Parent doesn't have num_samples
+  std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights, 12);
+  EXPECT_NE(sampl3, nullptr);
+  std::shared_ptr<SamplerRT> sampler_rt3 = sampl3->Build();
+
+  std::shared_ptr<SamplerObj> sampl4 = SubsetRandomSampler(indices);
+  EXPECT_NE(sampl4, nullptr);
+  std::shared_ptr<SamplerRT> sampler_rt4 = sampl4->Build();
+  sampler_rt4->AddChild(sampler_rt3);
+  EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 12);
+
+  // Child doesn't have num_samples
+  std::shared_ptr<SamplerObj> sampl5 = RandomSampler(false);
+  EXPECT_NE(sampl5, nullptr);
+  std::shared_ptr<SamplerRT> sampler_rt5 = sampl5->Build();
+
+  std::shared_ptr<SamplerObj> sampl6 = PKSampler(3, false, 7);
+  EXPECT_NE(sampl6, nullptr);
+  std::shared_ptr<SamplerRT> sampler_rt6 = sampl6->Build();
+  sampler_rt6->AddChild(sampler_rt5);
+  EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
+}
+
 TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
   std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
   std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices);