|
|
|
@@ -40,7 +40,7 @@ from mindspore import log as logger |
|
|
|
from . import samplers |
|
|
|
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp |
|
|
|
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ |
|
|
|
check_rename, check_numpyslicesdataset, \ |
|
|
|
check_rename, check_numpyslicesdataset, check_device_send, \ |
|
|
|
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ |
|
|
|
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ |
|
|
|
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ |
|
|
|
@@ -953,6 +953,7 @@ class Dataset: |
|
|
|
raise TypeError("apply_func must return a dataset.") |
|
|
|
return dataset |
|
|
|
|
|
|
|
@check_device_send |
|
|
|
def device_que(self, prefetch_size=None, send_epoch_end=True): |
|
|
|
""" |
|
|
|
Return a transferredDataset that transfer data through device. |
|
|
|
@@ -971,6 +972,7 @@ class Dataset: |
|
|
|
""" |
|
|
|
return self.to_device(send_epoch_end=send_epoch_end) |
|
|
|
|
|
|
|
@check_device_send |
|
|
|
def to_device(self, send_epoch_end=True): |
|
|
|
""" |
|
|
|
Transfer data through CPU, GPU or Ascend devices. |
|
|
|
|