diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 56ef705f60..e91dca9ce0 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -115,6 +115,8 @@ class Sampler: return self.child_sampler.is_sharded() def get_num_samples(self): + if self.num_samples is None: + return None return self._get_indices().size