Browse Source

!477 Fix VOC dataset test cases

Merge pull request !477 from xiefangqi/xfq_fix_voc
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
eda63a559a
3 changed files with 19 additions and 27 deletions
  1. +15 -23 mindspore/dataset/engine/datasets.py
  2. +2 -1 mindspore/dataset/engine/serializer_deserializer.py
  3. +2 -3 mindspore/dataset/engine/validators.py

+ 15
- 23
mindspore/dataset/engine/datasets.py View File

@@ -3335,14 +3335,17 @@ class VOCDataset(SourceDataset):
decode (bool, optional): Decode the images after reading (default=False).
sampler (Sampler, optional): Object used to choose samples from the dataset
(default=None, expected order behavior shown in the table).
distribution (str, optional): Path to the json distribution file to configure
dataset sharding (default=None). This argument should be specified
only when no 'sampler' is used.
num_shards (int, optional): Number of shards that the dataset should be divided
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.

Raises:
RuntimeError: If distribution and sampler are specified at the same time.
RuntimeError: If distribution is failed to read.
RuntimeError: If shuffle and sampler are specified at the same time.
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If shard_id is invalid (< 0 or >= num_shards).

Examples:
>>> import mindspore.dataset as ds
@@ -3356,27 +3359,15 @@ class VOCDataset(SourceDataset):

@check_vocdataset
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
shuffle=None, decode=False, sampler=None, distribution=None):
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.sampler = sampler
if distribution is not None:
if sampler is not None:
raise RuntimeError("Cannot specify distribution and sampler at the same time.")
try:
with open(distribution, 'r') as load_d:
json.load(load_d)
except json.decoder.JSONDecodeError:
raise RuntimeError("Json decode error when load distribution file")
except Exception:
raise RuntimeError("Distribution file has failed to load.")
elif shuffle is not None:
if sampler is not None:
raise RuntimeError("Cannot specify shuffle and sampler at the same time.")
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
self.decode = decode
self.distribution = distribution
self.shuffle_level = shuffle
self.num_shards = num_shards
self.shard_id = shard_id

def get_args(self):
args = super().get_args()
@@ -3385,7 +3376,8 @@ class VOCDataset(SourceDataset):
args["sampler"] = self.sampler
args["decode"] = self.decode
args["shuffle"] = self.shuffle_level
args["distribution"] = self.distribution
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
return args

def get_dataset_size(self):


+ 2
- 1
mindspore/dataset/engine/serializer_deserializer.py View File

@@ -286,7 +286,8 @@ def create_node(node):
elif dataset_op == 'VOCDataset':
sampler = construct_sampler(node.get('sampler'))
pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
node.get('shuffle'), node.get('decode'), sampler, node.get('distribution'))
node.get('shuffle'), node.get('decode'), sampler, node.get('num_shards'),
node.get('shard_id'))

elif dataset_op == 'CelebADataset':
sampler = construct_sampler(node.get('sampler'))


+ 2
- 3
mindspore/dataset/engine/validators.py View File

@@ -443,9 +443,8 @@ def check_vocdataset(method):
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)

nreq_param_int = ['num_samples', 'num_parallel_workers']
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle', 'decode']
nreq_param_str = ['distribution']

# check dataset_dir; required argument
dataset_dir = param_dict.get('dataset_dir')
@@ -457,7 +456,7 @@ def check_vocdataset(method):

check_param_type(nreq_param_bool, param_dict, bool)

check_param_type(nreq_param_str, param_dict, str)
check_sampler_shuffle_shard_options(param_dict)

return method(*args, **kwargs)



Loading…
Cancel
Save