|
|
@@ -3335,14 +3335,17 @@ class VOCDataset(SourceDataset): |
|
|
decode (bool, optional): Decode the images after reading (default=False). |
|
|
decode (bool, optional): Decode the images after reading (default=False). |
|
|
sampler (Sampler, optional): Object used to choose samples from the dataset |
|
|
sampler (Sampler, optional): Object used to choose samples from the dataset |
|
|
(default=None, expected order behavior shown in the table). |
|
|
(default=None, expected order behavior shown in the table). |
|
|
distribution (str, optional): Path to the json distribution file to configure |
|
|
|
|
|
dataset sharding (default=None). This argument should be specified |
|
|
|
|
|
only when no 'sampler' is used. |
|
|
|
|
|
|
|
|
num_shards (int, optional): Number of shards that the dataset should be divided |
|
|
|
|
|
into (default=None). |
|
|
|
|
|
shard_id (int, optional): The shard ID within num_shards (default=None). This |
|
|
|
|
|
argument should be specified only when num_shards is also specified. |
|
|
|
|
|
|
|
|
Raises: |
|
|
Raises: |
|
|
RuntimeError: If distribution and sampler are specified at the same time. |
|
|
|
|
|
RuntimeError: If distribution is failed to read. |
|
|
|
|
|
RuntimeError: If shuffle and sampler are specified at the same time. |
|
|
|
|
|
|
|
|
RuntimeError: If sampler and shuffle are specified at the same time. |
|
|
|
|
|
RuntimeError: If sampler and sharding are specified at the same time. |
|
|
|
|
|
RuntimeError: If num_shards is specified but shard_id is None. |
|
|
|
|
|
RuntimeError: If shard_id is specified but num_shards is None. |
|
|
|
|
|
ValueError: If shard_id is invalid (< 0 or >= num_shards). |
|
|
|
|
|
|
|
|
Examples: |
|
|
Examples: |
|
|
>>> import mindspore.dataset as ds |
|
|
>>> import mindspore.dataset as ds |
|
|
@@ -3356,27 +3359,15 @@ class VOCDataset(SourceDataset): |
|
|
|
|
|
|
|
|
@check_vocdataset |
|
|
@check_vocdataset |
|
|
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, |
|
|
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, |
|
|
shuffle=None, decode=False, sampler=None, distribution=None): |
|
|
|
|
|
|
|
|
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): |
|
|
super().__init__(num_parallel_workers) |
|
|
super().__init__(num_parallel_workers) |
|
|
self.dataset_dir = dataset_dir |
|
|
self.dataset_dir = dataset_dir |
|
|
self.sampler = sampler |
|
|
|
|
|
if distribution is not None: |
|
|
|
|
|
if sampler is not None: |
|
|
|
|
|
raise RuntimeError("Cannot specify distribution and sampler at the same time.") |
|
|
|
|
|
try: |
|
|
|
|
|
with open(distribution, 'r') as load_d: |
|
|
|
|
|
json.load(load_d) |
|
|
|
|
|
except json.decoder.JSONDecodeError: |
|
|
|
|
|
raise RuntimeError("Json decode error when load distribution file") |
|
|
|
|
|
except Exception: |
|
|
|
|
|
raise RuntimeError("Distribution file has failed to load.") |
|
|
|
|
|
elif shuffle is not None: |
|
|
|
|
|
if sampler is not None: |
|
|
|
|
|
raise RuntimeError("Cannot specify shuffle and sampler at the same time.") |
|
|
|
|
|
|
|
|
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) |
|
|
self.num_samples = num_samples |
|
|
self.num_samples = num_samples |
|
|
self.decode = decode |
|
|
self.decode = decode |
|
|
self.distribution = distribution |
|
|
|
|
|
self.shuffle_level = shuffle |
|
|
self.shuffle_level = shuffle |
|
|
|
|
|
self.num_shards = num_shards |
|
|
|
|
|
self.shard_id = shard_id |
|
|
|
|
|
|
|
|
def get_args(self): |
|
|
def get_args(self): |
|
|
args = super().get_args() |
|
|
args = super().get_args() |
|
|
@@ -3385,7 +3376,8 @@ class VOCDataset(SourceDataset): |
|
|
args["sampler"] = self.sampler |
|
|
args["sampler"] = self.sampler |
|
|
args["decode"] = self.decode |
|
|
args["decode"] = self.decode |
|
|
args["shuffle"] = self.shuffle_level |
|
|
args["shuffle"] = self.shuffle_level |
|
|
args["distribution"] = self.distribution |
|
|
|
|
|
|
|
|
args["num_shards"] = self.num_shards |
|
|
|
|
|
args["shard_id"] = self.shard_id |
|
|
return args |
|
|
return args |
|
|
|
|
|
|
|
|
def get_dataset_size(self): |
|
|
def get_dataset_size(self): |
|
|
|