From c0e2a63fdb989ff598869b38e184c5049cea1948 Mon Sep 17 00:00:00 2001 From: Cathy Wong Date: Wed, 8 Apr 2020 16:36:06 -0400 Subject: [PATCH] Correct dataset error checking --- mindspore/dataset/engine/datasets.py | 2 -- mindspore/dataset/engine/validators.py | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index ab2290c13c..2058bbf826 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -82,8 +82,6 @@ def zip(datasets): if len(datasets) <= 1: raise ValueError( "Can't zip empty or just one dataset!") - if not isinstance(datasets, tuple): - raise TypeError("The zip function %s type error!" % (datasets)) return ZipDataset(datasets) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 26d6241945..4c84cfe354 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -105,13 +105,13 @@ def check(method): "The %s function %s exceeds the boundary!" % ( func_name, param_name)) if isinstance(arg, int) and param_name == "num_parallel_workers" and ( - arg <= 0 or arg > cpu_count()): + arg < 1 or arg > cpu_count()): raise ValueError( "The %s function %s exceeds the boundary(%s)!" % ( func_name, param_name, cpu_count())) if isinstance(arg, int) and param_name != "seed" \ and param_name != "count" and param_name != "prefetch_size" \ - and param_name != "num_parallel_workers" and (arg <= 0 or arg > 2147483647): + and param_name != "num_parallel_workers" and (arg < 1 or arg > 2147483647): raise ValueError( "The %s function %s exceeds the boundary!" % ( func_name, param_name)) @@ -271,8 +271,8 @@ def check_interval_closed(param, param_name, valid_range): def check_num_parallel_workers(value): check_type(value, 'num_parallel_workers', int) - if value <= 0 or value > cpu_count(): - raise ValueError("num_parallel_workers exceeds the boundary between 0 and {}!".format(cpu_count())) + if value < 1 or value > cpu_count(): + raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count())) def check_num_samples(value):