Browse Source

set dataset_size in generator when source has len

tags/v0.6.0-beta
Zirui Wu 5 years ago
parent
commit
3b42c360b6
1 changed files with 4 additions and 0 deletions
  1. +4
    -0
      mindspore/dataset/engine/datasets.py

+ 4
- 0
mindspore/dataset/engine/datasets.py View File

@@ -3157,6 +3157,9 @@ class GeneratorDataset(MappableDataset):
self.column_names.append(col["name"])
self.column_types.append(DataType(col["type"]))

if source is not None and hasattr(source, "__len__"):
self._dataset_size = len(source)

def get_args(self):
args = super().get_args()
args["source"] = self.source
@@ -3177,6 +3180,7 @@ class GeneratorDataset(MappableDataset):
return self._dataset_size
if self._dataset_size is None:
return None

return min(rows_from_sampler, self._dataset_size)

# manually set dataset_size as a temporary solution.


Loading…
Cancel
Save