Browse Source

Minor fix for DatasetCache param validation

tags/v1.1.0
Lixia Chen 5 years ago
parent
commit
ab7427f1a9
3 changed files with 27 additions and 11 deletions
  1. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
  2. +14
    -2
      mindspore/dataset/engine/cache_client.py
  3. +11
    -7
      tests/ut/python/dataset/test_cache_map.py

+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc View File

@@ -46,8 +46,8 @@ Status CacheClient::Builder::SanityCheck() {
CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive"); CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "illegal port number");
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 1024, "Port must be in range (1025..65535)");
CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "Port must be in range (1025..65535)");
CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1", CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1",
"now cache client has to be on the same host with cache server"); "now cache client has to be on the same host with cache server");
return Status::OK(); return Status::OK();


+ 14
- 2
mindspore/dataset/engine/cache_client.py View File

@@ -18,7 +18,7 @@
import copy import copy
from mindspore._c_dataengine import CacheClient from mindspore._c_dataengine import CacheClient


from ..core.validator_helpers import type_check, check_uint32, check_uint64
from ..core.validator_helpers import type_check, check_uint32, check_uint64, check_positive, check_value




class DatasetCache: class DatasetCache:
@@ -29,8 +29,20 @@ class DatasetCache:
def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None, def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None,
prefetch_size=None): prefetch_size=None):
check_uint32(session_id, "session_id") check_uint32(session_id, "session_id")
check_uint64(size, "size")
type_check(size, (int,), "size")
if size != 0:
check_positive(size, "size")
check_uint64(size, "size")
type_check(spilling, (bool,), "spilling") type_check(spilling, (bool,), "spilling")
if hostname is not None:
type_check(hostname, (str,), "hostname")
if port is not None:
type_check(port, (int,), "port")
check_value(port, (1025, 65535), "port")
if num_connections is not None:
check_uint32(num_connections, "num_connections")
if prefetch_size is not None:
check_uint32(prefetch_size, "prefetch_size")


self.session_id = session_id self.session_id = session_id
self.size = size self.size = size


+ 11
- 7
tests/ut/python/dataset/test_cache_map.py View File

@@ -550,7 +550,7 @@ def test_cache_map_parameter_check():


with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
ds.DatasetCache(session_id=1, size=-1, spilling=True) ds.DatasetCache(session_id=1, size=-1, spilling=True)
assert "Input is not within the required interval" in str(info.value)
assert "Input size must be greater than 0" in str(info.value)


with pytest.raises(TypeError) as info: with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size="1", spilling=True) ds.DatasetCache(session_id=1, size="1", spilling=True)
@@ -564,6 +564,10 @@ def test_cache_map_parameter_check():
ds.DatasetCache(session_id=1, size=0, spilling="illegal") ds.DatasetCache(session_id=1, size=0, spilling="illegal")
assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value) assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value)


with pytest.raises(TypeError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname=50052)
assert "Argument hostname with value 50052 is not of type (<class 'str'>,)" in str(err.value)

with pytest.raises(RuntimeError) as err: with pytest.raises(RuntimeError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal") ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal")
assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value) assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value)
@@ -574,19 +578,19 @@ def test_cache_map_parameter_check():


with pytest.raises(TypeError) as info: with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal") ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal")
assert "incompatible constructor arguments" in str(info.value)
assert "Argument port with value illegal is not of type (<class 'int'>,)" in str(info.value)


with pytest.raises(TypeError) as info: with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052") ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052")
assert "incompatible constructor arguments" in str(info.value)
assert "Argument port with value 50052 is not of type (<class 'int'>,)" in str(info.value)


with pytest.raises(RuntimeError) as err:
with pytest.raises(ValueError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, port=0) ds.DatasetCache(session_id=1, size=0, spilling=True, port=0)
assert "Unexpected error. port must be positive" in str(err.value)
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)


with pytest.raises(RuntimeError) as err:
with pytest.raises(ValueError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536) ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536)
assert "Unexpected error. illegal port number" in str(err.value)
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)


with pytest.raises(TypeError) as err: with pytest.raises(TypeError) as err:
ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True) ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)


Loading…
Cancel
Save