Browse Source

!1704 fix num rows bug in sampler

Merge pull request !1704 from Peilin/fix-random-sampler-bug
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
c366b3fb18
2 changed files with 19 additions and 0 deletions
  1. mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc (+1, -0)
  2. tests/ut/python/dataset/test_datasets_imagefolder.py (+18, -0)

mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc (+1, -0) View File

@@ -76,6 +76,7 @@ Status RandomSampler::InitSampler() {

if (replacement_ == false) {
num_samples_ = std::min(num_samples_, num_rows_);
num_samples_ = std::min(num_samples_, user_num_samples_);

shuffled_ids_.reserve(num_rows_);
for (int64_t i = 0; i < num_rows_; i++) {


tests/ut/python/dataset/test_datasets_imagefolder.py (+18, -0) View File

@@ -57,6 +57,24 @@ def test_imagefolder_numsamples():
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 10

random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)

num_iter = 0
for item in data1.create_dict_iterator():
num_iter += 1

assert num_iter == 3

random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)

num_iter = 0
for item in data1.create_dict_iterator():
num_iter += 1

assert num_iter == 3


def test_imagefolder_numshards():
logger.info("Test Case numShards")


Loading…
Cancel
Save