Browse Source

fix pk sampler's get_dataset_size error due to num_classes being unavailable at pre-runtime

tags/v1.2.0-rc1
Zirui Wu 4 years ago
parent
commit
2692e3cc3d
15 changed files with 72 additions and 4 deletions
  1. +5
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h
  2. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc
  3. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
  4. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
  5. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.cc
  6. +3
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc
  7. +3
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc
  8. +5
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc
  9. +3
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc
  10. +3
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
  11. +3
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc
  12. +3
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc
  13. +3
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc
  14. +2
    -2
      tests/ut/cpp/dataset/ir_sampler_test.cc
  15. +32
    -0
      tests/ut/python/dataset/test_datasets_cifarop.py

+ 5
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h View File

@@ -66,6 +66,11 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override; Status to_json(nlohmann::json *out_json) override;


/// \brief PKSampler cannot report an exact sample count here because num_classes is not known until runtime,
///     so this override always returns the sentinel value -1.
/// \param[in] num_rows Row count supplied by the caller; ignored by this override
/// \return -1, meaning the number of samples is undetermined for PKSampler
int64_t CalculateNumSamples(int64_t num_rows) override { return -1; }

private: private:
bool shuffle_; bool shuffle_;
uint32_t seed_; uint32_t seed_;


+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc View File

@@ -140,6 +140,8 @@ int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t child_num_rows = num_rows; int64_t child_num_rows = num_rows;
if (!child_.empty()) { if (!child_.empty()) {
child_num_rows = child_[0]->CalculateNumSamples(num_rows); child_num_rows = child_[0]->CalculateNumSamples(num_rows);
// return -1 if child_num_rows is undetermined
if (child_num_rows == -1) return child_num_rows;
} }


return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h View File

@@ -108,7 +108,7 @@ class SamplerRT {


// Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of // Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of
// num_samples_ // num_samples_
// @return number of samples
// @return number of samples, return -1 if sampler cannot determine this value (e.g. PKSampler)
virtual int64_t CalculateNumSamples(int64_t num_rows); virtual int64_t CalculateNumSamples(int64_t num_rows);


// setter for num or records in the dataset // setter for num or records in the dataset


+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc View File

@@ -109,6 +109,8 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t child_num_rows = num_rows; int64_t child_num_rows = num_rows;
if (!child_.empty()) { if (!child_.empty()) {
child_num_rows = child_[0]->CalculateNumSamples(num_rows); child_num_rows = child_[0]->CalculateNumSamples(num_rows);
// return -1 if child_num_rows is undetermined
if (child_num_rows == -1) return child_num_rows;
} }
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
// For this sampler we need to take start_index into account. Because for example in the case we are given n rows // For this sampler we need to take start_index into account. Because for example in the case we are given n rows


+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.cc View File

@@ -139,6 +139,8 @@ int64_t SubsetSamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t child_num_rows = num_rows; int64_t child_num_rows = num_rows;
if (!child_.empty()) { if (!child_.empty()) {
child_num_rows = child_[0]->CalculateNumSamples(num_rows); child_num_rows = child_[0]->CalculateNumSamples(num_rows);
// return -1 if child_num_rows is undetermined
if (child_num_rows == -1) return child_num_rows;
} }
int64_t res = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; int64_t res = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
res = std::min(res, static_cast<int64_t>(indices_.size())); res = std::min(res, static_cast<int64_t>(indices_.size()));


+ 3
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc View File

@@ -144,6 +144,9 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
std::shared_ptr<SamplerRT> sampler_rt = nullptr; std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows); sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size; *dataset_size = sample_size;
return Status::OK(); return Status::OK();
} }


+ 3
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc View File

@@ -95,7 +95,9 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
std::shared_ptr<SamplerRT> sampler_rt = nullptr; std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows); sample_size = sampler_rt->CalculateNumSamples(num_rows);

if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 5
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc View File

@@ -88,12 +88,17 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
*dataset_size = dataset_size_; *dataset_size = dataset_size_;
return Status::OK(); return Status::OK();
} }

int64_t num_rows, sample_size; int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows)); RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows));
std::shared_ptr<SamplerRT> sampler_rt = nullptr; std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows); sample_size = sampler_rt->CalculateNumSamples(num_rows);


if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}

*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 3
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc View File

@@ -151,6 +151,9 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
std::shared_ptr<SamplerRT> sampler_rt = nullptr; std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows); sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 3
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc View File

@@ -100,6 +100,9 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter>
std::shared_ptr<SamplerRT> sampler_rt = nullptr; std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows); sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 3
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc View File

@@ -123,6 +123,9 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
std::shared_ptr<SamplerRT> sampler_rt = nullptr; std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows); sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 3
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc View File

@@ -88,6 +88,9 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
std::shared_ptr<SamplerRT> sampler_rt = nullptr; std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows); sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 3
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc View File

@@ -139,6 +139,9 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
std::shared_ptr<SamplerRT> sampler_rt = nullptr; std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows); sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 2
- 2
tests/ut/cpp/dataset/ir_sampler_test.cc View File

@@ -36,7 +36,7 @@ TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) {
sampl = std::make_shared<PKSamplerObj>(3, false, 0); sampl = std::make_shared<PKSamplerObj>(3, false, 0);
EXPECT_NE(sampl, nullptr); EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt); sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), -1);


sampl = std::make_shared<RandomSamplerObj>(false, 12); sampl = std::make_shared<RandomSamplerObj>(false, 12);
EXPECT_NE(sampl, nullptr); EXPECT_NE(sampl, nullptr);
@@ -98,7 +98,7 @@ TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) {
std::shared_ptr<SamplerRT> sampler_rt6; std::shared_ptr<SamplerRT> sampler_rt6;
sampl6->SamplerBuild(&sampler_rt6); sampl6->SamplerBuild(&sampler_rt6);
sampler_rt6->AddChild(sampler_rt5); sampler_rt6->AddChild(sampler_rt5);
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), -1);
} }


TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) { TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) {


+ 32
- 0
tests/ut/python/dataset/test_datasets_cifarop.py View File

@@ -501,6 +501,35 @@ def test_cifar_exception_file_path():
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e) assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)




def test_cifar10_pk_sampler_get_dataset_size():
    """
    Check that get_dataset_size matches the iterated row count for Cifar10Dataset with a PKSampler
    """
    pk_sampler = ds.PKSampler(3)
    dataset = ds.Cifar10Dataset(DATA_DIR_10, sampler=pk_sampler)
    # Ask for the size first, then count the rows actually produced by one epoch.
    reported_size = dataset.get_dataset_size()
    row_count = sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))

    assert reported_size == row_count == 30


def test_cifar10_with_chained_sampler_get_dataset_size():
    """
    Check get_dataset_size for Cifar10Dataset when a SequentialSampler has a PKSampler as its child
    """
    parent_sampler = ds.SequentialSampler(start_index=0, num_samples=5)
    parent_sampler.add_child(ds.PKSampler(4))
    dataset = ds.Cifar10Dataset(DATA_DIR_10, sampler=parent_sampler)
    # Ask for the size first, then count the rows actually produced by one epoch.
    reported_size = dataset.get_dataset_size()
    row_count = sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert reported_size == row_count == 5


if __name__ == '__main__': if __name__ == '__main__':
test_cifar10_content_check() test_cifar10_content_check()
test_cifar10_basic() test_cifar10_basic()
@@ -517,3 +546,6 @@ if __name__ == '__main__':


test_cifar_usage() test_cifar_usage()
test_cifar_exception_file_path() test_cifar_exception_file_path()

test_cifar10_with_chained_sampler_get_dataset_size()
test_cifar10_pk_sampler_get_dataset_size()

Loading…
Cancel
Save