|
|
|
@@ -1889,6 +1889,26 @@ class BatchDataset(Dataset): |
|
|
|
for input_dataset in dataset.children: |
|
|
|
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) |
|
|
|
|
|
|
|
def __deepcopy__(self, memodict): |
|
|
|
if id(self) in memodict: |
|
|
|
return memodict[id(self)] |
|
|
|
cls = self.__class__ |
|
|
|
new_op = cls.__new__(cls) |
|
|
|
memodict[id(self)] = new_op |
|
|
|
new_op.children = copy.deepcopy(self.children, memodict) |
|
|
|
new_op.parent = copy.deepcopy(self.parent, memodict) |
|
|
|
new_op.num_parallel_workers = self.num_parallel_workers |
|
|
|
new_op.batch_size = self.batch_size |
|
|
|
new_op.batch_size_func = self.batch_size_func |
|
|
|
new_op.drop_remainder = self.drop_remainder |
|
|
|
new_op.per_batch_map = self.per_batch_map |
|
|
|
new_op.input_columns = copy.deepcopy(self.input_columns, memodict) |
|
|
|
new_op.output_columns = copy.deepcopy(self.output_columns, memodict) |
|
|
|
new_op.column_order = copy.deepcopy(self.column_order, memodict) |
|
|
|
new_op.pad = self.pad |
|
|
|
new_op.pad_info = copy.deepcopy(self.pad_info, memodict) |
|
|
|
return new_op |
|
|
|
|
|
|
|
|
|
|
|
class BatchInfo(cde.CBatchInfo): |
|
|
|
""" |
|
|
|
@@ -2753,6 +2773,22 @@ class TransferDataset(Dataset): |
|
|
|
if self._to_device is not None: |
|
|
|
self._to_device.release() |
|
|
|
|
|
|
|
def __deepcopy__(self, memodict): |
|
|
|
if id(self) in memodict: |
|
|
|
return memodict[id(self)] |
|
|
|
cls = self.__class__ |
|
|
|
new_op = cls.__new__(cls) |
|
|
|
memodict[id(self)] = new_op |
|
|
|
new_op.children = copy.deepcopy(self.children, memodict) |
|
|
|
new_op.parent = copy.deepcopy(self.parent, memodict) |
|
|
|
new_op.num_parallel_workers = self.num_parallel_workers |
|
|
|
new_op.queue_name = self.queue_name |
|
|
|
new_op.device_type = self.device_type |
|
|
|
new_op._send_epoch_end = self._send_epoch_end # pylint: disable=W0212 |
|
|
|
new_op._create_data_info_queue = self._create_data_info_queue # pylint: disable=W0212 |
|
|
|
|
|
|
|
return new_op |
|
|
|
|
|
|
|
|
|
|
|
class RangeDataset(MappableDataset): |
|
|
|
""" |
|
|
|
|