|
|
|
@@ -30,7 +30,9 @@ from enum import Enum |
|
|
|
from importlib import import_module |
|
|
|
import threading |
|
|
|
|
|
|
|
import copy |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ |
|
|
|
MindRecordOp, TextFileOp, CBatchInfo |
|
|
|
from mindspore._c_expression import typing |
|
|
|
@@ -1376,6 +1378,23 @@ class MapDataset(DatasetOp): |
|
|
|
""" |
|
|
|
return self.input[0].get_dataset_size() |
|
|
|
|
|
|
|
def __deepcopy__(self, memodict): |
|
|
|
if id(self) in memodict: |
|
|
|
return memodict[id(self)] |
|
|
|
cls = self.__class__ |
|
|
|
new_op = cls.__new__(cls) |
|
|
|
memodict[id(self)] = new_op |
|
|
|
new_op.input = copy.deepcopy(self.input, memodict) |
|
|
|
new_op.input_columns = copy.deepcopy(self.input_columns, memodict) |
|
|
|
new_op.output_columns = copy.deepcopy(self.output_columns, memodict) |
|
|
|
new_op.columns_order = copy.deepcopy(self.columns_order, memodict) |
|
|
|
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) |
|
|
|
new_op.output = copy.deepcopy(self.output, memodict) |
|
|
|
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) |
|
|
|
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) |
|
|
|
new_op.operations = self.operations |
|
|
|
return new_op |
|
|
|
|
|
|
|
# Iterator bootstrap will be called on iterator construction. |
|
|
|
# A deep copy of Dataset object is created prior of iterator_bootstrap. |
|
|
|
# This method will create per iterator process pool and bind pyfunc execution to the pool. |
|
|
|
@@ -2600,6 +2619,23 @@ class GeneratorDataset(SourceDataset): |
|
|
|
else: |
|
|
|
raise ValueError('set dataset_size with negative value {}'.format(value)) |
|
|
|
|
|
|
|
def __deepcopy__(self, memodict): |
|
|
|
if id(self) in memodict: |
|
|
|
return memodict[id(self)] |
|
|
|
cls = self.__class__ |
|
|
|
new_op = cls.__new__(cls) |
|
|
|
memodict[id(self)] = new_op |
|
|
|
new_op.input = copy.deepcopy(self.input, memodict) |
|
|
|
new_op.output = copy.deepcopy(self.output, memodict) |
|
|
|
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) |
|
|
|
new_op.column_types = copy.deepcopy(self.column_types, memodict) |
|
|
|
new_op.column_names = copy.deepcopy(self.column_names, memodict) |
|
|
|
|
|
|
|
new_op.source = self.source |
|
|
|
new_op.sampler = self.sampler |
|
|
|
|
|
|
|
return new_op |
|
|
|
|
|
|
|
|
|
|
|
class TFRecordDataset(SourceDataset): |
|
|
|
""" |
|
|
|
|