From 363632ca9d2fa88278a235c477ae6395ae23a18f Mon Sep 17 00:00:00 2001 From: Yanjun Peng Date: Thu, 9 Apr 2020 11:04:13 +0800 Subject: [PATCH] fix dataset para validator check --- mindspore/dataset/engine/samplers.py | 1 - mindspore/dataset/engine/validators.py | 5 +++++ mindspore/dataset/transforms/vision/validators.py | 4 ++++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index ed36e72b65..62a3dbed18 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -127,7 +127,6 @@ class RandomSampler(): Raises: ValueError: If replacement is not boolean. - ValueError: If num_samples is not None and replacement is false. ValueError: If num_samples is not positive. """ diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 26d6241945..b5ebc24b39 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -556,6 +556,11 @@ def check_generatordataset(method): if column_names is None: raise ValueError("column_names is not provided.") + # check prefetch_size range + prefetch_size = param_dict.get('prefetch_size') + if prefetch_size is not None and (prefetch_size <= 0 or prefetch_size > 1024): + raise ValueError("prefetch_size exceeds the boundary.") + check_param_type(nreq_param_int, param_dict, int) check_param_type(nreq_param_list, param_dict, list) diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index caab120af4..ef4b879f8c 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -104,6 +104,10 @@ def check_padding(padding): raise ValueError("The size of the padding list or tuple should be 2 or 4.") else: raise TypeError("Padding can be any of: a number, a tuple or list of size 2 or 4.") + if not (isinstance(left, int) and isinstance(top, int) and isinstance(right, int) and isinstance(bottom, int)): + raise TypeError("Padding value should be integer.") + if left < 0 or top < 0 or right < 0 or bottom < 0: + raise ValueError("Padding value could not be negative.") return left, top, right, bottom