Browse Source

!466 Deepcopy problem when pyfunc cannot be pickled

Merge pull request !466 from h.farahat/deepcopy
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
0e3054d527
2 changed files with 67 additions and 1 deletions
  1. +36
    -0
      mindspore/dataset/engine/datasets.py
  2. +31
    -1
      tests/ut/python/dataset/test_iterator.py

+ 36
- 0
mindspore/dataset/engine/datasets.py View File

@@ -30,7 +30,9 @@ from enum import Enum
from importlib import import_module
import threading

import copy
import numpy as np

from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
MindRecordOp, TextFileOp, CBatchInfo
from mindspore._c_expression import typing
@@ -1376,6 +1378,23 @@ class MapDataset(DatasetOp):
"""
return self.input[0].get_dataset_size()

def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
cls = self.__class__
new_op = cls.__new__(cls)
memodict[id(self)] = new_op
new_op.input = copy.deepcopy(self.input, memodict)
new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
new_op.columns_order = copy.deepcopy(self.columns_order, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.output = copy.deepcopy(self.output, memodict)
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
new_op.operations = self.operations
return new_op

# Iterator bootstrap will be called on iterator construction.
# A deep copy of Dataset object is created prior of iterator_bootstrap.
# This method will create per iterator process pool and bind pyfunc execution to the pool.
@@ -2600,6 +2619,23 @@ class GeneratorDataset(SourceDataset):
else:
raise ValueError('set dataset_size with negative value {}'.format(value))

def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
cls = self.__class__
new_op = cls.__new__(cls)
memodict[id(self)] = new_op
new_op.input = copy.deepcopy(self.input, memodict)
new_op.output = copy.deepcopy(self.output, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.column_types = copy.deepcopy(self.column_types, memodict)
new_op.column_names = copy.deepcopy(self.column_names, memodict)

new_op.source = self.source
new_op.sampler = self.sampler

return new_op


class TFRecordDataset(SourceDataset):
"""


+ 31
- 1
tests/ut/python/dataset/test_iterator.py View File

@@ -14,7 +14,7 @@
# ==============================================================================
import numpy as np
import pytest
import copy
import mindspore.dataset as ds
from mindspore.dataset.engine.iterators import ITERATORS_LIST, _cleanup

@@ -81,3 +81,33 @@ def test_iterator_weak_ref():
assert sum(itr() is not None for itr in ITERATORS_LIST) == 2

_cleanup()


class MyDict(dict):
def __getattr__(self, key):
return self[key]

def __setattr__(self, key, value):
self[key] = value

def __call__(self, t):
return t


def test_tree_copy():
# Testing copying the tree with a pyfunc that cannot be pickled

data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS)
data1 = data.map(operations=[MyDict()])

itr = data1.create_tuple_iterator()

assert id(data1) != id(itr.dataset)
assert id(data) != id(itr.dataset.input[0])
assert id(data1.operations[0]) == id(itr.dataset.operations[0])

itr.release()


if __name__ == '__main__':
test_tree_copy()

Loading…
Cancel
Save