Browse Source

!3334 dataset: add param check for device_que and to_device

Merge pull request !3334 from ms_yan/device_que_param
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
d874150fb3
2 changed files with 22 additions and 1 deletions
  1. +3
    -1
      mindspore/dataset/engine/datasets.py
  2. +19
    -0
      mindspore/dataset/engine/validators.py

+ 3
- 1
mindspore/dataset/engine/datasets.py View File

@@ -40,7 +40,7 @@ from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, \
check_rename, check_numpyslicesdataset, check_device_send, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
@@ -953,6 +953,7 @@ class Dataset:
raise TypeError("apply_func must return a dataset.")
return dataset

@check_device_send
def device_que(self, prefetch_size=None, send_epoch_end=True):
"""
Return a transferredDataset that transfer data through device.
@@ -971,6 +972,7 @@ class Dataset:
"""
return self.to_device(send_epoch_end=send_epoch_end)

@check_device_send
def to_device(self, send_epoch_end=True):
"""
Transfer data through CPU, GPU or Ascend devices.


+ 19
- 0
mindspore/dataset/engine/validators.py View File

@@ -652,6 +652,25 @@ def check_positive_int32(method):
return new_method


def check_device_send(method):
"""check the input argument for to_device and device_que."""

@wraps(method)
def new_method(self, *args, **kwargs):
param, param_dict = parse_user_args(method, *args, **kwargs)
para_list = list(param_dict.keys())
if "prefetch_size" in para_list:
if param[0] is not None:
check_pos_int32(param[0], "prefetch_size")
type_check(param[1], (bool,), "send_epoch_end")
else:
type_check(param[0], (bool,), "send_epoch_end")

return method(self, *args, **kwargs)

return new_method


def check_zip(method):
"""check the input arguments of zip."""



Loading…
Cancel
Save