From 1a064a4f67b82fccf705c1a0344f88693b1f5eaa Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Thu, 7 Jan 2021 15:15:21 +0800 Subject: [PATCH] minddata fix generatordataset get_dataset_size issue --- mindspore/dataset/engine/datasets.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 54332d0627..a3f0644722 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3780,6 +3780,18 @@ class GeneratorDataset(MappableDataset): self.schema = schema if not isinstance(schema, Schema): self.schema = Schema(schema) + # Move get dataset_size by len from parse to here, because self.source will + # lose attribution of '__len__' after deepcopy. + self.dataset_size = None + if hasattr(self.source, "__len__"): + if not self.num_shards: + self.dataset_size = len(self.source) + else: + self.dataset_size = math.ceil(len(self.source) / self.num_shards) + + rows_from_sampler = self._get_sampler_dataset_size() + if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: + self.dataset_size = rows_from_sampler def __deepcopy__(self, memodict): if id(self) in memodict: @@ -3838,24 +3850,16 @@ class GeneratorDataset(MappableDataset): return self.sampler.is_sharded() def parse(self, children=None): - dataset_size = -1 - if hasattr(self.source, "__len__"): - if not self.num_shards: - dataset_size = len(self.source) - else: - dataset_size = math.ceil(len(self.source) / self.num_shards) - - rows_from_sampler = self._get_sampler_dataset_size() - if rows_from_sampler is not None and rows_from_sampler < dataset_size: - dataset_size = rows_from_sampler + if self.dataset_size is None: + self.dataset_size = -1 if self.schema is None: return cde.GeneratorNode(self.source, self.column_names, self.column_types).SetGeneratorDatasetSize( - dataset_size) \ + self.dataset_size) \ .SetNumWorkers(self.num_parallel_workers) schema = self.schema if isinstance(schema, Schema): schema = self.schema.cpp_schema - return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(dataset_size).SetNumWorkers( + return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(self.dataset_size).SetNumWorkers( self.num_parallel_workers)