Browse Source

fix num_samples in concatDataset

tags/v1.0.0
liyong 5 years ago
parent
commit
16147669a6
3 changed files with 19 additions and 1 deletions
  1. +2
    -0
      mindspore/dataset/engine/datasets.py
  2. +0
    -1
      tests/ut/data/dataset/test_tf_file_3_images/datasetSchema.json
  3. +17
    -0
      tests/ut/python/dataset/test_paddeddataset.py

+ 2
- 0
mindspore/dataset/engine/datasets.py View File

@@ -2438,6 +2438,8 @@ class ConcatDataset(DatasetOp):
self._sampler = _select_sampler(None, sampler, None, None, None)
cumulative_samples_nums = 0
for index, child in enumerate(self.children):
if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
raise ValueError("The parameter NumSamples of %s is not support to be set!" % (child))

if isinstance(child, BatchDataset):
raise TypeError("The parameter %s of concat should't be BatchDataset!" % (child))


+ 0
- 1
tests/ut/data/dataset/test_tf_file_3_images/datasetSchema.json View File

@@ -1,6 +1,5 @@
{
"datasetType": "TF",
"numRows": 3,
"columns": {
"image": {
"type": "uint8",


+ 17
- 0
tests/ut/python/dataset/test_paddeddataset.py View File

@@ -213,6 +213,23 @@ def test_raise_error():
ds3.use_sampler(testsampler)
assert excinfo.type == 'ValueError'

def test_imagefolder_error():
DATA_DIR = "../data/dataset/testPK/data"
data = ds.ImageFolderDataset(DATA_DIR, num_samples=14)

data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)},
{'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)},
{'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)},
{'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)},
{'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)},
{'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}]

data2 = ds.PaddedDataset(data1)
data3 = data + data2
with pytest.raises(ValueError) as excinfo:
testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
data3.use_sampler(testsampler)
assert excinfo.type == 'ValueError'

def test_imagefolder_padded():
DATA_DIR = "../data/dataset/testPK/data"


Loading…
Cancel
Save