Browse Source

!12348 Fix SequentialSampler issue

From: @mahdirahmanihanzaki
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
1e75ac45c4
2 changed files with 6 additions and 8 deletions
  1. +3
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
  2. +3
    -3
      tests/ut/python/dataset/test_sampler_chain.py

+ 3
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc View File

@@ -113,13 +113,11 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
// For this sampler we need to take start_index into account. Because for example in the case we are given n rows
// and start_index != 0 and num_samples >= n then we can't return all the n rows.
if (child_num_rows - (start_index_ - current_id_) <= 0) {
if (child_num_rows - start_index_ <= 0) {
return 0;
}
if (child_num_rows - (start_index_ - current_id_) < num_samples)
num_samples = child_num_rows - (start_index_ - current_id_) > num_samples
? num_samples
: num_samples - (start_index_ - current_id_);
if (child_num_rows - start_index_ < num_samples)
num_samples = child_num_rows - start_index_ > num_samples ? num_samples : num_samples - start_index_;
return num_samples;
}



+ 3
- 3
tests/ut/python/dataset/test_sampler_chain.py View File

@@ -70,16 +70,16 @@ def test_numpyslices_sampler_chain():
# Use 1 statement to add child sampler
np_data = [1, 2, 3, 4]
sampler = ds.SequentialSampler(start_index=1, num_samples=2)
sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)

# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 4
assert data1_size == 1

# Verify number of rows
assert sum([1 for _ in data1]) == 4
assert sum([1 for _ in data1]) == 1

# Verify dataset contents
res = []


Loading…
Cancel
Save