
!8218 dataset: modify api comment and fix the total_batch attribute being lost after deepcopy

From: @ms_yan
Reviewed-by: @jonyguo,@liucunwei
Signed-off-by: @liucunwei
tags/v1.1.0
Committed by mindspore-ci-bot on Gitee, 5 years ago
commit e47366cd4c
1 changed file with 8 additions and 3 deletions

mindspore/dataset/engine/datasets.py (+8, -3)

@@ -2166,6 +2166,8 @@ class MapDataset(DatasetOp):
         new_op.operations = self.operations
         new_op.dataset_size = self.dataset_size
         new_op.callbacks = self.callbacks
+        if hasattr(self, "__total_batch__"):
+            new_op.__total_batch__ = self.__total_batch__
         return new_op


     # Iterator bootstrap will be called on iterator construction.
@@ -3640,6 +3642,8 @@ class GeneratorDataset(MappableDataset):
         new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
         new_op.dataset_size = self.dataset_size
         new_op.sampler = copy.deepcopy(self.sampler)
+        if hasattr(self, "__total_batch__"):
+            new_op.__total_batch__ = self.__total_batch__
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
             if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
                                            samplers.RandomSampler, samplers.SubsetRandomSampler,
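
For context, the two hunks above address a common pitfall of hand-written __deepcopy__ methods: only attributes that are copied explicitly survive the copy, so an optional attribute such as __total_batch__ (set only on some dataset instances) has to be carried over behind a hasattr guard. A minimal, self-contained sketch of the pattern; the Node class and its fields are illustrative stand-ins, not the MindSpore internals:

import copy

class Node:
    """Toy stand-in for a dataset op with a hand-written __deepcopy__."""

    def __init__(self, operations):
        self.operations = operations
        # __total_batch__ is only attached to some instances, so it is optional.

    def __deepcopy__(self, memodict):
        new_op = Node(copy.deepcopy(self.operations, memodict))
        # Without this guard, the copy silently loses the optional attribute.
        if hasattr(self, "__total_batch__"):
            new_op.__total_batch__ = self.__total_batch__
        return new_op

node = Node(["resize", "rescale"])
node.__total_batch__ = 10
clone = copy.deepcopy(node)
assert clone.__total_batch__ == 10  # preserved thanks to the hasattr guard
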
@@ -5705,10 +5709,11 @@ class NumpySlicesDataset(GeneratorDataset):


     Args:
         data (Union[list, tuple, dict]) Input of given data. Supported data types include: list, tuple, dict and other
-            NumPy formats. Input data will be sliced along the first dimension and generate additional rows.
-            Large data is not recommended to be loaded in this way as data is loading into memory.
+            NumPy formats. Input data will be sliced along the first dimension and generate additional rows, if input is
+            list, there will be one column in each row, otherwise there tends to be multi columns. Large data is not
+            recommended to be loaded in this way as data is loading into memory.
         column_names (list[str], optional): List of column names of the dataset (default=None). If column_names is not
-            provided, when data is dict, column_names will be its keys, otherwise it will be like column_1, column_2 ...
+            provided, when data is dict, column_names will be its keys, otherwise it will be like column_0, column_1 ...
         num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images).
         num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
         shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
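
The reworded docstring describes how NumpySlicesDataset slices its input along the first dimension: a list yields a single column per row, while a dict yields one column per key. A usage sketch of that behavior, assuming a working MindSpore install; the values shown in the comments are the expected rows, not captured output:

import mindspore.dataset as ds

# List input: sliced along the first dimension, a single column per row.
data_list = [1, 2, 3]
list_ds = ds.NumpySlicesDataset(data_list, column_names=["col_0"], shuffle=False)
for row in list_ds.create_dict_iterator():
    print(row)  # one row per element, e.g. {'col_0': 1}, {'col_0': 2}, ...

# Dict input: when column_names is not given, the keys become the column names.
data_dict = {"a": [1, 2], "b": [3, 4]}
dict_ds = ds.NumpySlicesDataset(data_dict, shuffle=False)
for row in dict_ds.create_dict_iterator():
    print(row)  # e.g. {'a': 1, 'b': 3}, then {'a': 2, 'b': 4}
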

