Browse Source

User explicit deepcopy

tags/v1.2.0-rc1
hesham 5 years ago
parent
commit
58193bc469
2 changed files with 21 additions and 6 deletions
  1. +7
    -6
      mindspore/dataset/engine/datasets.py
  2. +14
    -0
      tests/ut/python/dataset/test_datasets_generator.py

+ 7
- 6
mindspore/dataset/engine/datasets.py View File

@@ -3433,6 +3433,7 @@ class GeneratorDataset(MappableDataset):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
self.source = source
self.prepared_source = None # source to be sent to C++

self.python_multiprocessing = python_multiprocessing

@@ -3463,9 +3464,9 @@ class GeneratorDataset(MappableDataset):
if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
if new_op.num_parallel_workers > 1:
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
else:
new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
new_op.sample_fn = sample_fn
else:
try:
@@ -3476,11 +3477,11 @@ class GeneratorDataset(MappableDataset):
iter(self.source)
except TypeError:
# Use generator function if input callable
new_op.source = (lambda: _generator_fn(self.source, new_op.num_samples))
new_op.prepared_source = (lambda: _generator_fn(self.source, new_op.num_samples))
else:
# Use iterator function if input is iterable
# Random accessible input is also iterable
new_op.source = (lambda: _iter_fn(self.source, new_op.num_samples))
new_op.prepared_source = (lambda: _iter_fn(self.source, new_op.num_samples))

return new_op

@@ -3492,12 +3493,12 @@ class GeneratorDataset(MappableDataset):

def parse(self, children=None):
if self.schema is None:
return cde.GeneratorNode(self.source, self.column_names, self.column_types, self.source_len,
return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
self.sampler)
schema = self.schema
if isinstance(schema, Schema):
schema = self.schema.cpp_schema
return cde.GeneratorNode(self.source, schema, self.source_len, self.sampler)
return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler)


class TFRecordDataset(SourceDataset):


+ 14
- 0
tests/ut/python/dataset/test_datasets_generator.py View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
import numpy as np
import pytest

@@ -745,6 +746,18 @@ def manual_test_generator_keyboard_interrupt():
pass


def test_explicit_deepcopy():
"""
Test explicit_deepcopy
"""
logger.info("Test explicit_deepcopy")

ds1 = ds.NumpySlicesDataset([1, 2], shuffle=False)
ds2 = copy.deepcopy(ds1)
for d1, d2 in zip(ds1, ds2):
assert d1 == d2


if __name__ == "__main__":
test_generator_0()
test_generator_1()
@@ -780,3 +793,4 @@ if __name__ == "__main__":
test_generator_dataset_size_3()
test_generator_dataset_size_4()
test_generator_dataset_size_5()
test_explicit_deepcopy()

Loading…
Cancel
Save