|
|
@@ -134,8 +134,8 @@ class Dataset:
     """

     def __init__(self, num_parallel_workers=None):
-        self.input = []
-        self.output = []
+        self.children = []
+        self.parent = []
         self.num_parallel_workers = num_parallel_workers
         self._device_iter = 0
         self._input_indexs = ()
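
Reviewer note: every hunk in this patch is the same mechanical rename of the operator-tree links — `self.input`/`self.output` become `self.children`/`self.parent` — so the tree pointers no longer collide with the unrelated `input_columns`/`output_columns` arguments. A minimal sketch of the doubly-linked tree the new names describe (illustrative code, not part of the patch or of MindSpore's API):

```python
# Sketch of the doubly-linked operator tree after the rename. `Node` and
# `connect` are illustrative names only.
class Node:
    def __init__(self):
        self.children = []  # upstream ops this op consumes from
        self.parent = []    # downstream ops that consume this op

    def connect(self, child):
        """Attach `child` as an upstream source, keeping both links in sync."""
        self.children.append(child)
        child.parent.append(self)

root, leaf = Node(), Node()
root.connect(leaf)
assert root.children == [leaf] and leaf.parent == [root]
```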
|
|
@@ -1006,9 +1006,9 @@ class Dataset:
                     dev_id = output_dataset.shard_id
                 return "", dev_id

-            if not output_dataset.input:
+            if not output_dataset.children:
                 raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset)))
-            input_dataset = output_dataset.input[0]
+            input_dataset = output_dataset.children[0]
             return get_distribution(input_dataset)

         distribution_path, device_id = get_distribution(self)
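
The `get_distribution` hunk above only swaps the attribute; the recursion itself is untouched. Its shape is a first-child descent: an op that cannot answer asks the op below it. A sketch of that shape, reusing the illustrative `Node` from the note above:

```python
# First-child descent mirroring get_distribution's recursion. An op with no
# children is a leaf source (e.g. a file-backed dataset) and must answer.
def find_leaf(op):
    if not op.children:
        return op
    return find_leaf(op.children[0])
```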
|
|
@@ -1129,8 +1129,8 @@ class Dataset:
         Return:
             Number, number of batches.
         """
-        if self.input:
-            return self.input[0].get_dataset_size()
+        if self.children:
+            return self.children[0].get_dataset_size()
         return None

     def num_classes(self):
|
|
@@ -1140,23 +1140,23 @@ class Dataset:
         Return:
             Number, number of classes.
         """
-        if self.input:
-            return self.input[0].num_classes()
+        if self.children:
+            return self.children[0].num_classes()
         return None

     def get_sync_notifiers(self):
-        if self.input:
-            return self.input[0].get_sync_notifiers()
+        if self.children:
+            return self.children[0].get_sync_notifiers()
         return {}

     def disable_sync(self):
-        if self.input:
-            return self.input[0].disable_sync()
+        if self.children:
+            return self.children[0].disable_sync()
         return {}

     def is_sync(self):
-        if self.input:
-            return self.input[0].is_sync()
+        if self.children:
+            return self.children[0].is_sync()
         return False

     def sync_update(self, condition_name, num_batch=None, data=None):
|
|
@@ -1190,8 +1190,8 @@ class Dataset:
         Return:
             Number, the number of data in a batch.
         """
-        if self.input:
-            return self.input[0].get_batch_size()
+        if self.children:
+            return self.children[0].get_batch_size()
         return 1

     def get_repeat_count(self):
|
|
@@ -1201,8 +1201,8 @@ class Dataset:
         Return:
             Number, the count of repeat.
         """
-        if self.input:
-            return self.input[0].get_repeat_count()
+        if self.children:
+            return self.children[0].get_repeat_count()
         return 1

     def get_class_indexing(self):
|
|
@@ -1212,22 +1212,22 @@ class Dataset:
         Return:
             Dict, A str-to-int mapping from label name to index.
         """
-        if self.input:
-            return self.input[0].get_class_indexing()
+        if self.children:
+            return self.children[0].get_class_indexing()
         raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))

     def reset(self):
         """Reset the dataset for next epoch."""

     def is_shuffled(self):
-        for input_dataset in self.input:
+        for input_dataset in self.children:
             if input_dataset.is_shuffled():
                 return True

         return False

     def is_sharded(self):
-        for input_dataset in self.input:
+        for input_dataset in self.children:
             if input_dataset.is_sharded():
                 return True
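
Two traversal shapes appear in the `Dataset` getters renamed above: the scalar getters (`get_dataset_size`, `num_classes`, `get_batch_size`, `get_repeat_count`, ...) delegate down `children[0]` and return a default at the leaf, while the predicates `is_shuffled`/`is_sharded` fan out over every child. A sketch of both shapes, with illustrative names:

```python
# Delegation: follow the first-child chain; the leaf supplies the default.
def lookup_first(op, default):
    return lookup_first(op.children[0], default) if op.children else default

# Fan-out: true if the predicate holds anywhere in the subtrees below `op`.
def any_below(op, predicate):
    return any(predicate(c) or any_below(c, predicate) for c in op.children)
```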
|
|
|
|
|
|
|
|
@@ -1466,8 +1466,8 @@ class BucketBatchByLengthDataset(DatasetOp):
         self.pad_to_bucket_boundary = pad_to_bucket_boundary
         self.drop_remainder = drop_remainder

-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
|
|
@@ -1529,8 +1529,8 @@ class BatchDataset(DatasetOp):
         self.per_batch_map = per_batch_map
         self.input_columns = input_columns
         self.pad_info = pad_info
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
|
|
@@ -1549,7 +1549,7 @@ class BatchDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.input[0].get_dataset_size()
+        child_size = self.children[0].get_dataset_size()
         if child_size is not None:
             if self.drop_remainder:
                 return math.floor(child_size / self.batch_size)
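
Context for the `BatchDataset.get_dataset_size` hunk: with `drop_remainder` the partial final batch is discarded, so the count is a floor of `child_size / batch_size`; when the remainder is kept, the partial batch still counts, which corresponds to a ceiling (that branch sits outside this hunk). A quick worked check in plain Python:

```python
import math

# 10 samples batched by 3: dropping the remainder yields 3 full batches,
# keeping it yields 4 (the last batch holds a single sample).
child_size, batch_size = 10, 3
assert math.floor(child_size / batch_size) == 3
assert math.ceil(child_size / batch_size) == 4
```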
|
|
@@ -1578,7 +1578,7 @@ class BatchDataset(DatasetOp):
         if isinstance(dataset, RepeatDataset):
             return True
         flag = False
-        for input_dataset in dataset.input:
+        for input_dataset in dataset.children:
             flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
         return flag
|
|
|
|
|
|
|
|
@@ -1593,7 +1593,7 @@ class BatchDataset(DatasetOp):
         """
         if isinstance(dataset, SyncWaitDataset):
             dataset.update_sync_batch_size(batch_size)
-        for input_dataset in dataset.input:
+        for input_dataset in dataset.children:
             BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1699,21 +1699,21 @@ class SyncWaitDataset(DatasetOp):

     def __init__(self, input_dataset, condition_name, num_batch, callback=None):
         super().__init__()
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         # set to the default value, waiting for the batch to update it
         self._condition_name = condition_name
         if isinstance(num_batch, int) and num_batch <= 0:
             raise ValueError("num_batch need to be greater than 0.")

         self._pair = BlockReleasePair(num_batch, callback)
-        if self._condition_name in self.input[0].get_sync_notifiers():
+        if self._condition_name in self.children[0].get_sync_notifiers():
             raise RuntimeError("Condition name is already in use")
         logger.warning("Please remember to add dataset.sync_update(condition=%s), otherwise will result in hanging",
                        condition_name)

     def get_sync_notifiers(self):
-        return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
+        return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}

     def is_sync(self):
         return True
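
In `get_sync_notifiers` the rename leaves the dict-merge idiom intact: `{**upstream, **{name: release}}` copies the upstream notifiers and adds this op's release callback; on a key clash the right-hand dict would win, which the constructor's "already in use" check forbids. A standalone illustration of the merge semantics:

```python
# {**a, **b} copies a, then overlays b; later keys win on a clash.
upstream = {"cond_a": lambda: "released a"}
merged = {**upstream, **{"cond_b": lambda: "released b"}}
assert set(merged) == {"cond_a", "cond_b"}
assert merged["cond_b"]() == "released b"
```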
|
|
@@ -1746,7 +1746,7 @@ class SyncWaitDataset(DatasetOp):
         if isinstance(dataset, BatchDataset):
             return True
         flag = False
-        for input_dataset in dataset.input:
+        for input_dataset in dataset.children:
             flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
         return flag
|
|
|
|
|
|
|
|
@@ -1766,9 +1766,9 @@ class ShuffleDataset(DatasetOp):
     def __init__(self, input_dataset, buffer_size):
         super().__init__()
         self.buffer_size = buffer_size
-        self.input.append(input_dataset)
+        self.children.append(input_dataset)
         self.reshuffle_each_epoch = None
-        input_dataset.output.append(self)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs
         if self.is_sync():
             raise RuntimeError("No shuffle after sync operators")
|
|
@@ -1864,7 +1864,7 @@ class MapDataset(DatasetOp):
     def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
                  num_parallel_workers=None, python_multiprocessing=False):
         super().__init__(num_parallel_workers)
-        self.input.append(input_dataset)
+        self.children.append(input_dataset)
         if input_columns is not None and not isinstance(input_columns, list):
             input_columns = [input_columns]
         self.input_columns = input_columns
|
|
@@ -1881,7 +1881,7 @@ class MapDataset(DatasetOp):
                 and self.columns_order is None:
             raise ValueError("When (len(input_columns) != len(output_columns)), columns_order must be specified.")

-        input_dataset.output.append(self)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs
         self.python_multiprocessing = python_multiprocessing
         self.process_pool = None
|
|
@@ -1901,7 +1901,7 @@ class MapDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        return self.input[0].get_dataset_size()
+        return self.children[0].get_dataset_size()

     def __deepcopy__(self, memodict):
         if id(self) in memodict:
|
|
@@ -1909,12 +1909,12 @@ class MapDataset(DatasetOp):
         cls = self.__class__
         new_op = cls.__new__(cls)
         memodict[id(self)] = new_op
-        new_op.input = copy.deepcopy(self.input, memodict)
+        new_op.children = copy.deepcopy(self.children, memodict)
         new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
         new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
         new_op.columns_order = copy.deepcopy(self.columns_order, memodict)
         new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
-        new_op.output = copy.deepcopy(self.output, memodict)
+        new_op.parent = copy.deepcopy(self.parent, memodict)
         new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
         new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
         new_op.operations = self.operations
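
The `__deepcopy__` hunk deserves a second look: because `children` and `parent` point at each other, a naive deep copy would recurse forever. Registering `new_op` in `memodict` before copying either link is what breaks the cycle, and both links must go through the same `memodict` so the copied tree stays internally consistent. A minimal self-contained sketch of that pattern (illustrative class, not patch code):

```python
import copy

class Op:
    def __init__(self):
        self.children, self.parent = [], []

    def __deepcopy__(self, memodict):
        cls = self.__class__
        new_op = cls.__new__(cls)
        memodict[id(self)] = new_op  # register first: breaks the cycle
        new_op.children = copy.deepcopy(self.children, memodict)
        new_op.parent = copy.deepcopy(self.parent, memodict)
        return new_op

a, b = Op(), Op()
a.children.append(b)
b.parent.append(a)
a2 = copy.deepcopy(a)
assert a2.children[0].parent[0] is a2  # copied links point at the copy
```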
|
|
@@ -1975,8 +1975,8 @@ class FilterDataset(DatasetOp):
     def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
         super().__init__(num_parallel_workers)
         self.predicate = lambda *args: bool(predicate(*args))
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         if input_columns is not None and not isinstance(input_columns, list):
             input_columns = [input_columns]
         self.input_columns = input_columns
|
|
@@ -2012,8 +2012,8 @@ class RepeatDataset(DatasetOp):
             self.count = -1
         else:
             self.count = count
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
|
|
@@ -2028,7 +2028,7 @@ class RepeatDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.input[0].get_dataset_size()
+        child_size = self.children[0].get_dataset_size()
         if child_size is not None:
             return child_size
         return None
|
|
@@ -2055,8 +2055,8 @@ class SkipDataset(DatasetOp):
     def __init__(self, input_dataset, count):
         super().__init__()
         self.count = count
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
|
|
@@ -2071,7 +2071,7 @@ class SkipDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.input[0].get_dataset_size()
+        child_size = self.children[0].get_dataset_size()
         output_size = 0
         if self.count >= 0 and self.count < child_size:
             output_size = child_size - self.count
|
|
@@ -2090,8 +2090,8 @@ class TakeDataset(DatasetOp):
     def __init__(self, input_dataset, count):
         super().__init__()
         self.count = count
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
|
|
@@ -2106,7 +2106,7 @@ class TakeDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.input[0].get_dataset_size()
+        child_size = self.children[0].get_dataset_size()
         if child_size < self.count:
             return child_size
         return self.count
|
|
@@ -2130,8 +2130,8 @@ class ZipDataset(DatasetOp):
                 raise TypeError("The parameter %s of zip has type error!" % (dataset))
         self.datasets = datasets
         for data in datasets:
-            self.input.append(data)
-            data.output.append(self)
+            self.children.append(data)
+            data.parent.append(self)

     def get_dataset_size(self):
         """
|
|
@@ -2140,7 +2140,7 @@ class ZipDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        children_sizes = [c.get_dataset_size() for c in self.input]
+        children_sizes = [c.get_dataset_size() for c in self.children]
         if all(c is not None for c in children_sizes):
             return min(children_sizes)
         return None
|
|
@@ -2155,7 +2155,7 @@ class ZipDataset(DatasetOp):
         return None

     def is_sync(self):
-        return any([c.is_sync() for c in self.input])
+        return any([c.is_sync() for c in self.children])

     def get_args(self):
         args = super().get_args()
|
|
@@ -2180,8 +2180,8 @@ class ConcatDataset(DatasetOp):
                 raise TypeError("The parameter %s of concat has type error!" % (dataset))
         self.datasets = datasets
         for data in datasets:
-            self.input.append(data)
-            data.output.append(self)
+            self.children.append(data)
+            data.parent.append(self)

     def get_dataset_size(self):
         """
|
|
@@ -2190,7 +2190,7 @@ class ConcatDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        children_sizes = [c.get_dataset_size() for c in self.input]
+        children_sizes = [c.get_dataset_size() for c in self.children]
         dataset_size = sum(children_sizes)
         return dataset_size
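
The two sizing hunks make the zip/concat contrast explicit: `ZipDataset` pairs rows positionally, so its size is bounded by the shortest child, while `ConcatDataset` appends children end to end and sums them. A plain-Python check:

```python
# Zip stops at the shortest input; concat appends all of them.
sizes = [5, 8, 3]
assert min(sizes) == 3   # what ZipDataset.get_dataset_size computes here
assert sum(sizes) == 16  # what ConcatDataset.get_dataset_size computes here
```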
|
|
|
|
|
|
|
|
@@ -2213,8 +2213,8 @@ class RenameDataset(DatasetOp):
             output_columns = [output_columns]
         self.input_column_names = input_columns
         self.output_column_names = output_columns
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
|
|
@@ -2240,10 +2240,10 @@ class ProjectDataset(DatasetOp):
         if not isinstance(columns, list):
             columns = [columns]
         self.columns = columns
-        self.input.append(input_dataset)
+        self.children.append(input_dataset)
         self.prefetch_size = prefetch_size

-        input_dataset.output.append(self)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
|
|
@@ -2267,8 +2267,8 @@ class TransferDataset(DatasetOp):

     def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None):
         super().__init__()
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self.queue_name = queue_name
         self._input_indexs = input_dataset.input_indexs
         self._device_type = device_type
|
|
@@ -3170,8 +3170,8 @@ class GeneratorDataset(MappableDataset):
         cls = self.__class__
         new_op = cls.__new__(cls)
         memodict[id(self)] = new_op
-        new_op.input = copy.deepcopy(self.input, memodict)
-        new_op.output = copy.deepcopy(self.output, memodict)
+        new_op.children = copy.deepcopy(self.children, memodict)
+        new_op.parent = copy.deepcopy(self.parent, memodict)
         new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
         new_op.column_types = copy.deepcopy(self.column_types, memodict)
         new_op.column_names = copy.deepcopy(self.column_names, memodict)
|
|
@@ -4879,14 +4879,14 @@ class BuildVocabDataset(DatasetOp):
                  prefetch_size=None):
         super().__init__()
         self.columns = columns
-        self.input.append(input_dataset)
+        self.children.append(input_dataset)
         self.prefetch_size = prefetch_size
         self.vocab = vocab
         self.freq_range = freq_range
         self.top_k = top_k
         self.special_tokens = special_tokens
         self.special_first = special_first
-        input_dataset.output.append(self)
+        input_dataset.parent.append(self)

     def get_args(self):
         args = super().get_args()
|
|
@@ -4905,11 +4905,11 @@ class BuildVocabDataset(DatasetOp):
         cls = self.__class__
         new_op = cls.__new__(cls)
         memodict[id(self)] = new_op
-        new_op.input = copy.deepcopy(self.input, memodict)
+        new_op.children = copy.deepcopy(self.children, memodict)
         new_op.columns = copy.deepcopy(self.columns, memodict)
         new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
         new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
-        new_op.output = copy.deepcopy(self.output, memodict)
+        new_op.parent = copy.deepcopy(self.parent, memodict)
         new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
         new_op.top_k = copy.deepcopy(self.top_k, memodict)
         new_op.vocab = self.vocab
|
|
|