Browse Source

!11092 [MD][bugfix] Fix GeneratorDataset get_dataset_size issue

From: @xiefangqi
Reviewed-by: @pandoublefeng,@heleiwang
Signed-off-by: @heleiwang
tags/v1.1.1
mindspore-ci-bot Gitee 5 years ago
parent
commit
bd566d1e08
1 changed files with 16 additions and 12 deletions
  1. +16
    -12
      mindspore/dataset/engine/datasets.py

+ 16
- 12
mindspore/dataset/engine/datasets.py View File

@@ -3780,6 +3780,18 @@ class GeneratorDataset(MappableDataset):
self.schema = schema
if not isinstance(schema, Schema):
self.schema = Schema(schema)
# Move get dataset_size by len from parse to here, because self.source will
# lose attribution of '__len__' after deepcopy.
self.dataset_size = None
if hasattr(self.source, "__len__"):
if not self.num_shards:
self.dataset_size = len(self.source)
else:
self.dataset_size = math.ceil(len(self.source) / self.num_shards)

rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler

def __deepcopy__(self, memodict):
if id(self) in memodict:
@@ -3838,24 +3850,16 @@ class GeneratorDataset(MappableDataset):
return self.sampler.is_sharded()

def parse(self, children=None):
dataset_size = -1
if hasattr(self.source, "__len__"):
if not self.num_shards:
dataset_size = len(self.source)
else:
dataset_size = math.ceil(len(self.source) / self.num_shards)

rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < dataset_size:
dataset_size = rows_from_sampler
if self.dataset_size is None:
self.dataset_size = -1
if self.schema is None:
return cde.GeneratorNode(self.source, self.column_names, self.column_types).SetGeneratorDatasetSize(
dataset_size) \
self.dataset_size) \
.SetNumWorkers(self.num_parallel_workers)
schema = self.schema
if isinstance(schema, Schema):
schema = self.schema.cpp_schema
return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(dataset_size).SetNumWorkers(
return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(self.dataset_size).SetNumWorkers(
self.num_parallel_workers)




Loading…
Cancel
Save