From 5b035408f613dcec5fece207e4b145a2a7ae0ae7 Mon Sep 17 00:00:00 2001 From: liyong Date: Thu, 26 Nov 2020 14:48:47 +0800 Subject: [PATCH] fix deepycopy bug --- mindspore/dataset/engine/datasets.py | 36 ++++++++++++++++++++++++++++ mindspore/dataset/engine/samplers.py | 1 + 2 files changed, 37 insertions(+) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 72fb8c45c5..7bad45359a 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1889,6 +1889,26 @@ class BatchDataset(Dataset): for input_dataset in dataset.children: BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + cls = self.__class__ + new_op = cls.__new__(cls) + memodict[id(self)] = new_op + new_op.children = copy.deepcopy(self.children, memodict) + new_op.parent = copy.deepcopy(self.parent, memodict) + new_op.num_parallel_workers = self.num_parallel_workers + new_op.batch_size = self.batch_size + new_op.batch_size_func = self.batch_size_func + new_op.drop_remainder = self.drop_remainder + new_op.per_batch_map = self.per_batch_map + new_op.input_columns = copy.deepcopy(self.input_columns, memodict) + new_op.output_columns = copy.deepcopy(self.output_columns, memodict) + new_op.column_order = copy.deepcopy(self.column_order, memodict) + new_op.pad = self.pad + new_op.pad_info = copy.deepcopy(self.pad_info, memodict) + return new_op + class BatchInfo(cde.CBatchInfo): """ @@ -2751,6 +2771,22 @@ class TransferDataset(Dataset): if self._to_device is not None: self._to_device.release() + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + cls = self.__class__ + new_op = cls.__new__(cls) + memodict[id(self)] = new_op + new_op.children = copy.deepcopy(self.children, memodict) + new_op.parent = copy.deepcopy(self.parent, memodict) + new_op.num_parallel_workers = self.num_parallel_workers + new_op.queue_name = self.queue_name + new_op.device_type = self.device_type + new_op._send_epoch_end = self._send_epoch_end # pylint: disable=W0212 + new_op._create_data_info_queue = self._create_data_info_queue # pylint: disable=W0212 + + return new_op + class RangeDataset(MappableDataset): """ diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 58005443a4..bebc632203 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -312,6 +312,7 @@ class PKSampler(BuiltinSampler): Args: num_val (int): Number of elements to sample for each class. num_class (int, optional): Number of classes to sample (default=None, all classes). + The parameter does not supported to specify currently. shuffle (bool, optional): If True, the class IDs are shuffled (default=False). class_column (str, optional): Name of column with class labels for MindDataset (default='label'). num_samples (int, optional): The number of samples to draw (default=None, all elements).