diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 00b68878c1..cc99dc9052 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2438,6 +2438,8 @@ class ConcatDataset(DatasetOp): self._sampler = _select_sampler(None, sampler, None, None, None) cumulative_samples_nums = 0 for index, child in enumerate(self.children): + if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None: + raise ValueError("The parameter NumSamples of %s is not support to be set!" % (child)) if isinstance(child, BatchDataset): raise TypeError("The parameter %s of concat should't be BatchDataset!" % (child)) diff --git a/tests/ut/data/dataset/test_tf_file_3_images/datasetSchema.json b/tests/ut/data/dataset/test_tf_file_3_images/datasetSchema.json index eafcfd69ea..e00fd39c10 100644 --- a/tests/ut/data/dataset/test_tf_file_3_images/datasetSchema.json +++ b/tests/ut/data/dataset/test_tf_file_3_images/datasetSchema.json @@ -1,6 +1,5 @@ { "datasetType": "TF", - "numRows": 3, "columns": { "image": { "type": "uint8", diff --git a/tests/ut/python/dataset/test_paddeddataset.py b/tests/ut/python/dataset/test_paddeddataset.py index cd7ef07ae7..fd21ee5882 100644 --- a/tests/ut/python/dataset/test_paddeddataset.py +++ b/tests/ut/python/dataset/test_paddeddataset.py @@ -213,6 +213,23 @@ def test_raise_error(): ds3.use_sampler(testsampler) assert excinfo.type == 'ValueError' +def test_imagefolder_error(): + DATA_DIR = "../data/dataset/testPK/data" + data = ds.ImageFolderDataset(DATA_DIR, num_samples=14) + + data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)}, + {'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)}, + {'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)}, + {'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)}, + {'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)}, + {'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}] + + data2 = ds.PaddedDataset(data1) + data3 = data + data2 + with pytest.raises(ValueError) as excinfo: + testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None) + data3.use_sampler(testsampler) + assert excinfo.type == 'ValueError' def test_imagefolder_padded(): DATA_DIR = "../data/dataset/testPK/data"