| @@ -66,6 +66,11 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED | |||||
| /// \return Status of the function | /// \return Status of the function | ||||
| Status to_json(nlohmann::json *out_json) override; | Status to_json(nlohmann::json *out_json) override; | ||||
| /// \brief PK cannot return an exact value because num_classes is not known until runtime, hence -1 is used | |||||
| /// \param[out] num_rows | |||||
| /// \return -1, which means PKSampler doesn't know how much data | |||||
| int64_t CalculateNumSamples(int64_t num_rows) override { return -1; } | |||||
| private: | private: | ||||
| bool shuffle_; | bool shuffle_; | ||||
| uint32_t seed_; | uint32_t seed_; | ||||
| @@ -140,6 +140,8 @@ int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) { | |||||
| int64_t child_num_rows = num_rows; | int64_t child_num_rows = num_rows; | ||||
| if (!child_.empty()) { | if (!child_.empty()) { | ||||
| child_num_rows = child_[0]->CalculateNumSamples(num_rows); | child_num_rows = child_[0]->CalculateNumSamples(num_rows); | ||||
| // return -1 if child_num_rows is undetermined | |||||
| if (child_num_rows == -1) return child_num_rows; | |||||
| } | } | ||||
| return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; | return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; | ||||
| @@ -108,7 +108,7 @@ class SamplerRT { | |||||
| // Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of | // Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of | ||||
| // num_samples_ | // num_samples_ | ||||
| // @return number of samples | |||||
| // @return number of samples, return -1 if sampler cannot determine this value (e.g. PKSampler) | |||||
| virtual int64_t CalculateNumSamples(int64_t num_rows); | virtual int64_t CalculateNumSamples(int64_t num_rows); | ||||
| // setter for num or records in the dataset | // setter for num or records in the dataset | ||||
| @@ -109,6 +109,8 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) { | |||||
| int64_t child_num_rows = num_rows; | int64_t child_num_rows = num_rows; | ||||
| if (!child_.empty()) { | if (!child_.empty()) { | ||||
| child_num_rows = child_[0]->CalculateNumSamples(num_rows); | child_num_rows = child_[0]->CalculateNumSamples(num_rows); | ||||
| // return -1 if child_num_rows is undetermined | |||||
| if (child_num_rows == -1) return child_num_rows; | |||||
| } | } | ||||
| int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; | int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; | ||||
| // For this sampler we need to take start_index into account. Because for example in the case we are given n rows | // For this sampler we need to take start_index into account. Because for example in the case we are given n rows | ||||
| @@ -139,6 +139,8 @@ int64_t SubsetSamplerRT::CalculateNumSamples(int64_t num_rows) { | |||||
| int64_t child_num_rows = num_rows; | int64_t child_num_rows = num_rows; | ||||
| if (!child_.empty()) { | if (!child_.empty()) { | ||||
| child_num_rows = child_[0]->CalculateNumSamples(num_rows); | child_num_rows = child_[0]->CalculateNumSamples(num_rows); | ||||
| // return -1 if child_num_rows is undetermined | |||||
| if (child_num_rows == -1) return child_num_rows; | |||||
| } | } | ||||
| int64_t res = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; | int64_t res = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; | ||||
| res = std::min(res, static_cast<int64_t>(indices_.size())); | res = std::min(res, static_cast<int64_t>(indices_.size())); | ||||
| @@ -144,6 +144,9 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size | |||||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | std::shared_ptr<SamplerRT> sampler_rt = nullptr; | ||||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | ||||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | sample_size = sampler_rt->CalculateNumSamples(num_rows); | ||||
| if (sample_size == -1) { | |||||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); | |||||
| } | |||||
| *dataset_size = sample_size; | *dataset_size = sample_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -95,7 +95,9 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | std::shared_ptr<SamplerRT> sampler_rt = nullptr; | ||||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | ||||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | sample_size = sampler_rt->CalculateNumSamples(num_rows); | ||||
| if (sample_size == -1) { | |||||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); | |||||
| } | |||||
| *dataset_size = sample_size; | *dataset_size = sample_size; | ||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -88,12 +88,17 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz | |||||
| *dataset_size = dataset_size_; | *dataset_size = dataset_size_; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| int64_t num_rows, sample_size; | int64_t num_rows, sample_size; | ||||
| RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows)); | RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows)); | ||||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | std::shared_ptr<SamplerRT> sampler_rt = nullptr; | ||||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | ||||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | sample_size = sampler_rt->CalculateNumSamples(num_rows); | ||||
| if (sample_size == -1) { | |||||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); | |||||
| } | |||||
| *dataset_size = sample_size; | *dataset_size = sample_size; | ||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -151,6 +151,9 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g | |||||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | std::shared_ptr<SamplerRT> sampler_rt = nullptr; | ||||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | ||||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | sample_size = sampler_rt->CalculateNumSamples(num_rows); | ||||
| if (sample_size == -1) { | |||||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); | |||||
| } | |||||
| *dataset_size = sample_size; | *dataset_size = sample_size; | ||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -100,6 +100,9 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> | |||||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | std::shared_ptr<SamplerRT> sampler_rt = nullptr; | ||||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | ||||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | sample_size = sampler_rt->CalculateNumSamples(num_rows); | ||||
| if (sample_size == -1) { | |||||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); | |||||
| } | |||||
| *dataset_size = sample_size; | *dataset_size = sample_size; | ||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -123,6 +123,9 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | std::shared_ptr<SamplerRT> sampler_rt = nullptr; | ||||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | ||||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | sample_size = sampler_rt->CalculateNumSamples(num_rows); | ||||
| if (sample_size == -1) { | |||||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); | |||||
| } | |||||
| *dataset_size = sample_size; | *dataset_size = sample_size; | ||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -88,6 +88,9 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ | |||||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | std::shared_ptr<SamplerRT> sampler_rt = nullptr; | ||||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | ||||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | sample_size = sampler_rt->CalculateNumSamples(num_rows); | ||||
| if (sample_size == -1) { | |||||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); | |||||
| } | |||||
| *dataset_size = sample_size; | *dataset_size = sample_size; | ||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -139,6 +139,9 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge | |||||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | std::shared_ptr<SamplerRT> sampler_rt = nullptr; | ||||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | ||||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | sample_size = sampler_rt->CalculateNumSamples(num_rows); | ||||
| if (sample_size == -1) { | |||||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); | |||||
| } | |||||
| *dataset_size = sample_size; | *dataset_size = sample_size; | ||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -36,7 +36,7 @@ TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) { | |||||
| sampl = std::make_shared<PKSamplerObj>(3, false, 0); | sampl = std::make_shared<PKSamplerObj>(3, false, 0); | ||||
| EXPECT_NE(sampl, nullptr); | EXPECT_NE(sampl, nullptr); | ||||
| sampl->SamplerBuild(&sampler_rt); | sampl->SamplerBuild(&sampler_rt); | ||||
| EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30); | |||||
| EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), -1); | |||||
| sampl = std::make_shared<RandomSamplerObj>(false, 12); | sampl = std::make_shared<RandomSamplerObj>(false, 12); | ||||
| EXPECT_NE(sampl, nullptr); | EXPECT_NE(sampl, nullptr); | ||||
| @@ -98,7 +98,7 @@ TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) { | |||||
| std::shared_ptr<SamplerRT> sampler_rt6; | std::shared_ptr<SamplerRT> sampler_rt6; | ||||
| sampl6->SamplerBuild(&sampler_rt6); | sampl6->SamplerBuild(&sampler_rt6); | ||||
| sampler_rt6->AddChild(sampler_rt5); | sampler_rt6->AddChild(sampler_rt5); | ||||
| EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7); | |||||
| EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), -1); | |||||
| } | } | ||||
| TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) { | TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) { | ||||
| @@ -501,6 +501,35 @@ def test_cifar_exception_file_path(): | |||||
| assert "map operation: [PyFunc] failed. The corresponding data files" in str(e) | assert "map operation: [PyFunc] failed. The corresponding data files" in str(e) | ||||
| def test_cifar10_pk_sampler_get_dataset_size(): | |||||
| """ | |||||
| Test Cifar10Dataset with PKSampler and get_dataset_size | |||||
| """ | |||||
| sampler = ds.PKSampler(3) | |||||
| data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler) | |||||
| num_iter = 0 | |||||
| ds_sz = data.get_dataset_size() | |||||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||||
| num_iter += 1 | |||||
| assert ds_sz == num_iter == 30 | |||||
| def test_cifar10_with_chained_sampler_get_dataset_size(): | |||||
| """ | |||||
| Test Cifar10Dataset with PKSampler chained with a SequentialSampler and get_dataset_size | |||||
| """ | |||||
| sampler = ds.SequentialSampler(start_index=0, num_samples=5) | |||||
| child_sampler = ds.PKSampler(4) | |||||
| sampler.add_child(child_sampler) | |||||
| data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler) | |||||
| num_iter = 0 | |||||
| ds_sz = data.get_dataset_size() | |||||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||||
| num_iter += 1 | |||||
| assert ds_sz == num_iter == 5 | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_cifar10_content_check() | test_cifar10_content_check() | ||||
| test_cifar10_basic() | test_cifar10_basic() | ||||
| @@ -517,3 +546,6 @@ if __name__ == '__main__': | |||||
| test_cifar_usage() | test_cifar_usage() | ||||
| test_cifar_exception_file_path() | test_cifar_exception_file_path() | ||||
| test_cifar10_with_chained_sampler_get_dataset_size() | |||||
| test_cifar10_pk_sampler_get_dataset_size() | |||||