Browse Source

!5184 fix: padded dataset when no div and with repeat op

Merge pull request !5184 from guozhijian/fix_padded_with_no_div_repeat
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
20b3134785
2 changed files with 18 additions and 0 deletions
  1. +3
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
  2. +15
    -0
      tests/ut/python/dataset/test_paddeddataset.py

+ 3
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc View File

@@ -75,6 +75,9 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
RETURN_STATUS_UNEXPECTED("Distributed Sampler Error"); RETURN_STATUS_UNEXPECTED("Distributed Sampler Error");
} else if (cnt_ == samples_per_buffer_ && (non_empty_ || !even_dist_)) { } else if (cnt_ == samples_per_buffer_ && (non_empty_ || !even_dist_)) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
if (!samples_per_buffer_) {
non_empty_ = false;
}
} else if (!samples_per_buffer_ && !non_empty_) { } else if (!samples_per_buffer_ && !non_empty_) {
// If the buffer is empty, we add samples with subscript 0 in the current dataset. // If the buffer is empty, we add samples with subscript 0 in the current dataset.
// This step is to make up for the solution that the code default buffer is not empty before. // This step is to make up for the solution that the code default buffer is not empty before.


+ 15
- 0
tests/ut/python/dataset/test_paddeddataset.py View File

@@ -454,6 +454,21 @@ def test_clue_padded_and_skip_with_0_samples():
count += 1 count += 1
assert count == 2 assert count == 2


def test_celeba_padded():
    """
    Iterate a CelebA dataset concatenated with a one-row PaddedDataset under a
    two-shard DistributedSampler (shard 1, no shuffle) with repeat(2), and
    assert that exactly two rows are produced.
    """
    celeba = ds.CelebADataset("../data/dataset/testCelebAData/")

    # A single padding row carrying the same column names used by the test data.
    padding_rows = [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}]
    combined = celeba + ds.PaddedDataset(padding_rows)

    # The sampler is attached to the concatenated dataset before repeat is applied.
    shard_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
    combined.use_sampler(shard_sampler)
    repeated = combined.repeat(2)

    # Drain the iterator and count every row it yields.
    count = sum(1 for _ in repeated.create_dict_iterator())
    assert count == 2

if __name__ == '__main__': if __name__ == '__main__':
test_TFRecord_Padded() test_TFRecord_Padded()
test_GeneratorDataSet_Padded() test_GeneratorDataSet_Padded()


Loading…
Cancel
Save