| @@ -505,8 +505,8 @@ Status CelebAOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| num_rows = std::min(num_rows, partition_num); | num_rows = std::min(num_rows, partition_num); | ||||
| } | } | ||||
| sample_size = sampler_->GetNumSamples(); | |||||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||||
| sample_size = sampler_->CalculateNumSamples(num_rows); | |||||
| *dataset_size = sample_size; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -518,8 +518,8 @@ Status CifarOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| num_rows = num_rows_; | num_rows = num_rows_; | ||||
| if (num_rows_ <= 0) | if (num_rows_ <= 0) | ||||
| RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &num_rows)); | RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &num_rows)); | ||||
| sample_size = sampler_->GetNumSamples(); | |||||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||||
| sample_size = sampler_->CalculateNumSamples(num_rows); | |||||
| *dataset_size = sample_size; | |||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -705,8 +705,8 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| if (image_ids_.size() == 0) { | if (image_ids_.size() == 0) { | ||||
| RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows)); | RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows)); | ||||
| } | } | ||||
| sample_size = sampler_->GetNumSamples(); | |||||
| *dataset_size = sample_size != 0 ? std::min(num_rows, sample_size) : num_rows; | |||||
| sample_size = sampler_->CalculateNumSamples(num_rows); | |||||
| *dataset_size = sample_size; | |||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -475,8 +475,8 @@ Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| // GetDatasetSize will not be impacted by class_index_ | // GetDatasetSize will not be impacted by class_index_ | ||||
| RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, nullptr, {})); | RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, nullptr, {})); | ||||
| } | } | ||||
| sample_size = sampler_->GetNumSamples(); | |||||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||||
| sample_size = sampler_->CalculateNumSamples(num_rows); | |||||
| *dataset_size = sample_size; | |||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -469,8 +469,8 @@ Status ManifestOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op)); | RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op)); | ||||
| RETURN_IF_NOT_OK(op->ParseManifestFile()); | RETURN_IF_NOT_OK(op->ParseManifestFile()); | ||||
| num_rows = static_cast<int64_t>(op->image_labelname_.size()); | num_rows = static_cast<int64_t>(op->image_labelname_.size()); | ||||
| sample_size = sampler_->GetNumSamples(); | |||||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||||
| sample_size = sampler_->CalculateNumSamples(num_rows); | |||||
| *dataset_size = sample_size; | |||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -480,8 +480,8 @@ Status MnistOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| int64_t num_rows, sample_size; | int64_t num_rows, sample_size; | ||||
| num_rows = num_rows_; | num_rows = num_rows_; | ||||
| if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, &num_rows)); | if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, &num_rows)); | ||||
| sample_size = sampler_->GetNumSamples(); | |||||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||||
| sample_size = sampler_->CalculateNumSamples(num_rows); | |||||
| *dataset_size = sample_size; | |||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -427,10 +427,15 @@ Status RandomDataOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| *dataset_size = dataset_size_; | *dataset_size = dataset_size_; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| int64_t num_rows, sample_size = 0; | |||||
| int64_t num_rows; | |||||
| num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); | num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); | ||||
| if (sampler_ != nullptr) sample_size = sampler_->GetNumSamples(); | |||||
| *dataset_size = sample_size != 0 ? std::min(num_rows, sample_size) : num_rows; | |||||
| if (sampler_ != nullptr) { | |||||
| int64_t sample_size; | |||||
| sample_size = sampler_->CalculateNumSamples(num_rows); | |||||
| *dataset_size = sample_size; | |||||
| } else { | |||||
| *dataset_size = num_rows; | |||||
| } | |||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" | ||||
| #include <algorithm> | |||||
| #include <limits> | #include <limits> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -160,6 +161,15 @@ Status DistributedSamplerRT::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Compute how many samples one device receives for a dataset with num_rows rows. | |||||
| // The count reported by the first child sampler (or num_rows when there is no child) is capped by num_samples_ | |||||
| // when num_samples_ is positive, and the result is divided across num_devices_, rounding up. | |||||
| // @param num_rows - the number of rows in the dataset | |||||
| // @return number of samples for a single device | |||||
| int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) { | |||||
|   int64_t child_num_samples = num_rows; | |||||
|   if (!child_.empty()) { | |||||
|     child_num_samples = child_[0]->CalculateNumSamples(num_rows); | |||||
|   } | |||||
|   const int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_samples, num_samples_) : child_num_samples; | |||||
|   // Ceiling division in pure integer arithmetic: routing through double (std::ceil) loses precision for counts | |||||
|   // above 2^53, narrows implicitly back to int64_t, and needs <cmath>. | |||||
|   return num_samples / num_devices_ + ((num_samples % num_devices_ > 0) ? 1 : 0); | |||||
| } | |||||
| void DistributedSamplerRT::Print(std::ostream &out, bool show_all) const { | void DistributedSamplerRT::Print(std::ostream &out, bool show_all) const { | ||||
| out << "\nSampler: DistributedSampler"; | out << "\nSampler: DistributedSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| @@ -63,6 +63,8 @@ class DistributedSamplerRT : public SamplerRT { | |||||
| int64_t GetDeviceNum() { return num_devices_; } | int64_t GetDeviceNum() { return num_devices_; } | ||||
| // Calculate the number of samples a single device gets: the sample count (taken from the first child sampler if | |||||
| // there is one, capped by num_samples_ when it is positive) divided by the number of devices, rounded up. | |||||
| // @param num_rows - the number of rows in the dataset | |||||
| // @return number of samples | |||||
| int64_t CalculateNumSamples(int64_t num_rows) override; | |||||
| void Print(std::ostream &out, bool show_all) const override; | void Print(std::ostream &out, bool show_all) const override; | ||||
| private: | private: | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| #include <algorithm> | |||||
| #include <string> | #include <string> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -131,6 +132,15 @@ Status SamplerRT::SetNumSamples(int64_t num_samples) { | |||||
| int64_t SamplerRT::GetNumSamples() { return num_samples_; } | int64_t SamplerRT::GetNumSamples() { return num_samples_; } | ||||
| // Work out how many samples this sampler yields for a dataset with num_rows rows. | |||||
| // The starting point is the count reported by the first child sampler, or num_rows itself when no child is | |||||
| // attached; a positive num_samples_ then acts as an upper bound on that count. | |||||
| // @param num_rows - the number of rows in the dataset | |||||
| // @return number of samples | |||||
| int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) { | |||||
|   const int64_t upstream = child_.empty() ? num_rows : child_[0]->CalculateNumSamples(num_rows); | |||||
|   // A non-positive num_samples_ is not applied as a cap. | |||||
|   if (num_samples_ <= 0) { | |||||
|     return upstream; | |||||
|   } | |||||
|   return std::min(upstream, num_samples_); | |||||
| } | |||||
| Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) { | Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0."); | CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0."); | ||||
| num_rows_ = num_rows; | num_rows_ = num_rows; | ||||
| @@ -99,10 +99,14 @@ class SamplerRT { | |||||
| Status SetNumSamples(int64_t num_samples); | Status SetNumSamples(int64_t num_samples); | ||||
| // getter for num samples | // getter for num samples | ||||
| // @param num_samples - the number of samples to return. | |||||
| // @return status error code | |||||
| // @return number of samples | |||||
| int64_t GetNumSamples(); | int64_t GetNumSamples(); | ||||
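| // Calculate the number of samples this sampler would produce. Unlike GetNumSamples, it is not a getter and | |||||
| // doesn't necessarily return the value of num_samples_ | |||||
| // @param num_rows - the number of rows in the dataset | |||||
| // @return number of samples | |||||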
| virtual int64_t CalculateNumSamples(int64_t num_rows); | |||||
| // setter for num of records in the dataset | // setter for num of records in the dataset | ||||
| // @param num_rows - the number of records | // @param num_rows - the number of records | ||||
| // @return status error code | // @return status error code | ||||
| @@ -537,8 +537,8 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| num_rows = static_cast<int64_t>(op->image_ids_.size()); | num_rows = static_cast<int64_t>(op->image_ids_.size()); | ||||
| } | } | ||||
| } | } | ||||
| sample_size = sampler_->GetNumSamples(); | |||||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||||
| sample_size = sampler_->CalculateNumSamples(num_rows); | |||||
| *dataset_size = sample_size; | |||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -82,6 +82,75 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) { | |||||
| iter->Stop(); | iter->Stop(); | ||||
| } | } | ||||
| // Verifies CalculateNumSamples() on individually built samplers and on parent/child sampler chains. | |||||
| // Every factory result is checked with ASSERT_NE (not EXPECT_NE) because it is dereferenced on the following | |||||
| // line; a null pointer must abort this test instead of crashing the whole test binary. | |||||
| TEST_F(MindDataTestPipeline, TestCalculateNumSamples) { | |||||
|   int64_t num_rows = 30;  // dummy variable for number of rows in the dataset | |||||
|   // Distributed sampler: expectation 3 corresponds to 6 samples split over 2 devices, rounded up | |||||
|   // (presumably the arguments are num_shards=2, shard_id=1, shuffle=false, num_samples=6). | |||||
|   std::shared_ptr<SamplerObj> sampl = DistributedSampler(2, 1, false, 6); | |||||
|   ASSERT_NE(sampl, nullptr); | |||||
|   std::shared_ptr<SamplerRT> sampler_rt = sampl->Build(); | |||||
|   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 3); | |||||
|   // PK sampler built without a sample count: the full row count is expected. | |||||
|   sampl = PKSampler(3, false); | |||||
|   ASSERT_NE(sampl, nullptr); | |||||
|   sampler_rt = sampl->Build(); | |||||
|   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30); | |||||
|   // Random sampler limited to 12 samples. | |||||
|   sampl = RandomSampler(false, 12); | |||||
|   ASSERT_NE(sampl, nullptr); | |||||
|   sampler_rt = sampl->Build(); | |||||
|   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); | |||||
|   // Sequential sampler limited to 10 samples. | |||||
|   sampl = SequentialSampler(0, 10); | |||||
|   ASSERT_NE(sampl, nullptr); | |||||
|   sampler_rt = sampl->Build(); | |||||
|   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10); | |||||
|   // Weighted random sampler limited to 12 samples. | |||||
|   std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1}; | |||||
|   sampl = WeightedRandomSampler(weights, 12); | |||||
|   ASSERT_NE(sampl, nullptr); | |||||
|   sampler_rt = sampl->Build(); | |||||
|   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); | |||||
|   // Subset random sampler limited to 11 samples. | |||||
|   std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21}; | |||||
|   sampl = SubsetRandomSampler(indices, 11); | |||||
|   ASSERT_NE(sampl, nullptr); | |||||
|   sampler_rt = sampl->Build(); | |||||
|   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11); | |||||
|   // Testing chains | |||||
|   // Parent and child have num_samples | |||||
|   std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights, 12); | |||||
|   ASSERT_NE(sampl1, nullptr); | |||||
|   std::shared_ptr<SamplerRT> sampler_rt1 = sampl1->Build(); | |||||
|   std::shared_ptr<SamplerObj> sampl2 = SequentialSampler(0, 10); | |||||
|   ASSERT_NE(sampl2, nullptr); | |||||
|   std::shared_ptr<SamplerRT> sampler_rt2 = sampl2->Build(); | |||||
|   sampler_rt2->AddChild(sampler_rt1); | |||||
|   EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10); | |||||
|   // Parent doesn't have num_samples | |||||
|   std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights, 12); | |||||
|   ASSERT_NE(sampl3, nullptr); | |||||
|   std::shared_ptr<SamplerRT> sampler_rt3 = sampl3->Build(); | |||||
|   std::shared_ptr<SamplerObj> sampl4 = SubsetRandomSampler(indices); | |||||
|   ASSERT_NE(sampl4, nullptr); | |||||
|   std::shared_ptr<SamplerRT> sampler_rt4 = sampl4->Build(); | |||||
|   sampler_rt4->AddChild(sampler_rt3); | |||||
|   EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 12); | |||||
|   // Child doesn't have num_samples | |||||
|   std::shared_ptr<SamplerObj> sampl5 = RandomSampler(false); | |||||
|   ASSERT_NE(sampl5, nullptr); | |||||
|   std::shared_ptr<SamplerRT> sampler_rt5 = sampl5->Build(); | |||||
|   std::shared_ptr<SamplerObj> sampl6 = PKSampler(3, false, 7); | |||||
|   ASSERT_NE(sampl6, nullptr); | |||||
|   std::shared_ptr<SamplerRT> sampler_rt6 = sampl6->Build(); | |||||
|   sampler_rt6->AddChild(sampler_rt5); | |||||
|   EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) { | TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) { | ||||
| std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; | std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; | ||||
| std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices); | std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices); | ||||