Browse Source

Fix source len is not divisible by batch_size in user defined sampler

tags/v1.2.0-rc1
luoyang 5 years ago
parent
commit
25ff4c9312
1 changed files with 5 additions and 1 deletions
  1. +5
    -1
      mindspore/dataset/engine/samplers.py

+ 5
- 1
mindspore/dataset/engine/samplers.py View File

@@ -237,6 +237,7 @@ class Sampler(BuiltinSampler):

# Indices fetcher
# Do not override this method!
# pylint: disable=missing-docstring
def _get_indices(self):
sampler_iter = iter(self)
ret = []
@@ -246,7 +247,10 @@ class Sampler(BuiltinSampler):
ret.append(idx)
except StopIteration:
break
return np.array(ret)
indices = np.array(ret)
if indices.dtype == object:
raise RuntimeError("Fetched indices can not be converted to a valid ndarray.")
return indices

# Instance fetcher
# Do not override this method!


Loading…
Cancel
Save