Browse Source

!13130 Fix source len is not divisible by batch_size in user defined sampler

From: @luoyang42
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
a0b17ea6c6
1 changed files with 5 additions and 1 deletions
  1. +5
    -1
      mindspore/dataset/engine/samplers.py

+ 5
- 1
mindspore/dataset/engine/samplers.py View File

@@ -242,6 +242,7 @@ class Sampler(BuiltinSampler):

# Indices fetcher
# Do not override this method!
# pylint: disable=missing-docstring
def _get_indices(self):
sampler_iter = iter(self)
ret = []
@@ -251,7 +252,10 @@ class Sampler(BuiltinSampler):
ret.append(idx)
except StopIteration:
break
return np.array(ret)
indices = np.array(ret)
if indices.dtype == object:
raise RuntimeError("Fetched indices can not be converted to a valid ndarray.")
return indices

# Instance fetcher
# Do not override this method!


Loading…
Cancel
Save