Browse Source

fix generator dataset np copy problem

tags/v0.3.0-alpha
Yanjun Peng 5 years ago
parent
commit
4bf016e67f
1 changed files with 13 additions and 13 deletions
  1. +13
    -13
      mindspore/dataset/engine/datasets.py

+ 13
- 13
mindspore/dataset/engine/datasets.py View File

@@ -2294,11 +2294,11 @@ def _iter_fn(dataset, num_samples):
except StopIteration:
return
# convert output tensors to ndarrays
yield tuple([np.array(x) for x in val])
yield tuple([np.array(x, copy=False) for x in val])
else:
for val in dataset:
# convert output tensors to ndarrays
yield tuple([np.array(x) for x in val])
yield tuple([np.array(x, copy=False) for x in val])


def _generator_fn(generator, num_samples):
@@ -2332,12 +2332,12 @@ def _py_sampler_fn(sampler, num_samples, dataset):
return
val = dataset[idx]
# convert output tensors to ndarrays
yield tuple([np.array(x) for x in val])
yield tuple([np.array(x, copy=False) for x in val])
else:
for i in sampler:
val = dataset[i]
# convert output tensors to ndarrays
yield tuple([np.array(x) for x in val])
yield tuple([np.array(x, copy=False) for x in val])


def _cpp_sampler_fn(sampler, dataset):
@@ -2348,7 +2348,7 @@ def _cpp_sampler_fn(sampler, dataset):
for i in indices:
val = dataset[i]
# convert output tensors to ndarrays
yield tuple([np.array(x) for x in val])
yield tuple([np.array(x, copy=False) for x in val])


def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
@@ -2437,7 +2437,7 @@ def _sampler_fn_mp(indices, dataset, num_worker):
# Set eoe event once all indices are sent
if idx_cursor == len(indices) and not eoe.is_set():
eoe.set()
yield tuple([np.array(x) for x in result])
yield tuple([np.array(x, copy=False) for x in result])


def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
@@ -2549,35 +2549,35 @@ class GeneratorDataset(SourceDataset):
when num_shards is also specified. Random accessible input is required.

Examples:
>>> import mindspore.dataengine as de
>>> import mindspore.dataset as ds
>>> # 1) Multidimensional generator function as callable input
>>> def generator_md():
>>> for i in range(64):
>>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
>>> # create multi_dimension_generator_dataset from generator_md with column name "multi_dimensional_data"
>>> multi_dimension_generator_dataset = de.GeneratorDataset(generator_md, ["multi_dimensional_data"])
>>> multi_dimension_generator_dataset = ds.GeneratorDataset(generator_md, ["multi_dimensional_data"])
>>> # 2) Multi-column generator function as callable input
>>> def generator_mc(maxid = 64):
>>> for i in range(maxid):
>>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
>>> # create multi_column_generator_dataset from generator_mc with column names "col1" and "col2"
>>> multi_column_generator_dataset = de.GeneratorDataset(generator_mc, ["col1", "col2"])
>>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1", "col2"])
>>> # 3) Iterable dataset as iterable input
>>> class MyIterable():
>>> def __iter__(self):
>>> return # User implementation
>>> # create iterable_generator_dataset with MyIterable object
>>> iterable_generator_dataset = de.GeneratorDataset(MyIterable(), ["col1"])
>>> iterable_generator_dataset = ds.GeneratorDataset(MyIterable(), ["col1"])
>>> # 4) Random-accessible dataset as random-accessible input
>>> class MyRA():
>>> def __getitem__(self, index):
>>> return # User implementation
>>> # create ra_generator_dataset with MyRA object
>>> ra_generator_dataset = de.GeneratorDataset(MyRA(), ["col1"])
>>> ra_generator_dataset = ds.GeneratorDataset(MyRA(), ["col1"])
>>> # List, dict and tuple objects are also random accessible
>>> list_generator = de.GeneratorDataset([(np.array(0),), (np.array(1)), (np.array(2))], ["col1"])
>>> list_generator = ds.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col1"])
>>> # 5) Built-in Sampler
>>> my_generator = de.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
>>> my_generator = ds.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
>>>
"""



Loading…
Cancel
Save