Browse Source

!2380 Fix CocoDataset issue

Merge pull request !2380 from xiefangqi/xfq_fix_coco_issue_01
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
78a8bc302d
3 changed files with 17 additions and 3 deletions
  1. +4
    -3
      mindspore/dataset/engine/datasets.py
  2. +3
    -0
      mindspore/dataset/engine/validators.py
  3. +10
    -0
      tests/ut/python/dataset/test_datasets_coco.py

+ 4
- 3
mindspore/dataset/engine/datasets.py View File

@@ -996,7 +996,8 @@ class Dataset:
def get_distribution(output_dataset):
dev_id = 0
if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
ManifestDataset, MnistDataset, VOCDataset, CelebADataset, MindDataset)):
ManifestDataset, MnistDataset, VOCDataset, CocoDataset, CelebADataset,
MindDataset)):
sampler = output_dataset.sampler
if isinstance(sampler, samplers.DistributedSampler):
dev_id = sampler.shard_id
@@ -4171,8 +4172,8 @@ class CocoDataset(MappableDataset):
- task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
['iscrowd', dtype=uint32], ['area', dtype=uint32]].

This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
below shows what input args are allowed and their expected behavior.
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. CocoDataset doesn't support
PKSampler. Table below shows what input args are allowed and their expected behavior.

.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
:widths: 25 25 50


+ 3
- 0
mindspore/dataset/engine/validators.py View File

@@ -397,6 +397,9 @@ def check_cocodataset(method):

check_param_type(nreq_param_bool, param_dict, bool)

sampler = param_dict.get('sampler')
if sampler is not None and isinstance(sampler, samplers.PKSampler):
raise ValueError("CocoDataset doesn't support PKSampler")
check_sampler_shuffle_shard_options(param_dict)

return method(*args, **kwargs)


+ 10
- 0
tests/ut/python/dataset/test_datasets_coco.py View File

@@ -251,6 +251,16 @@ def test_coco_case_exception():
except RuntimeError as e:
assert "json.exception.parse_error" in str(e)

try:
sampler = ds.PKSampler(3)
data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_FILE, task="Detection", sampler=sampler)
for _ in data1.__iter__():
pass
assert False
except ValueError as e:
assert "CocoDataset doesn't support PKSampler" in str(e)


if __name__ == '__main__':
test_coco_detection()
test_coco_stuff()


Loading…
Cancel
Save