Browse Source

Minor fix for DatasetCache param validation

tags/v1.1.0
Lixia Chen 5 years ago
parent
commit
ab7427f1a9
3 changed files with 27 additions and 11 deletions
  1. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
  2. +14
    -2
      mindspore/dataset/engine/cache_client.py
  3. +11
    -7
      tests/ut/python/dataset/test_cache_map.py

+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc View File

@@ -46,8 +46,8 @@ Status CacheClient::Builder::SanityCheck() {
CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "illegal port number");
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 1024, "Port must be in range (1025..65535)");
CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "Port must be in range (1025..65535)");
CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1",
"now cache client has to be on the same host with cache server");
return Status::OK();


+ 14
- 2
mindspore/dataset/engine/cache_client.py View File

@@ -18,7 +18,7 @@
import copy
from mindspore._c_dataengine import CacheClient

from ..core.validator_helpers import type_check, check_uint32, check_uint64
from ..core.validator_helpers import type_check, check_uint32, check_uint64, check_positive, check_value


class DatasetCache:
@@ -29,8 +29,20 @@ class DatasetCache:
def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None,
prefetch_size=None):
check_uint32(session_id, "session_id")
check_uint64(size, "size")
type_check(size, (int,), "size")
if size != 0:
check_positive(size, "size")
check_uint64(size, "size")
type_check(spilling, (bool,), "spilling")
if hostname is not None:
type_check(hostname, (str,), "hostname")
if port is not None:
type_check(port, (int,), "port")
check_value(port, (1025, 65535), "port")
if num_connections is not None:
check_uint32(num_connections, "num_connections")
if prefetch_size is not None:
check_uint32(prefetch_size, "prefetch_size")

self.session_id = session_id
self.size = size


+ 11
- 7
tests/ut/python/dataset/test_cache_map.py View File

@@ -550,7 +550,7 @@ def test_cache_map_parameter_check():

with pytest.raises(ValueError) as info:
ds.DatasetCache(session_id=1, size=-1, spilling=True)
assert "Input is not within the required interval" in str(info.value)
assert "Input size must be greater than 0" in str(info.value)

with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size="1", spilling=True)
@@ -564,6 +564,10 @@ def test_cache_map_parameter_check():
ds.DatasetCache(session_id=1, size=0, spilling="illegal")
assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value)

with pytest.raises(TypeError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname=50052)
assert "Argument hostname with value 50052 is not of type (<class 'str'>,)" in str(err.value)

with pytest.raises(RuntimeError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal")
assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value)
@@ -574,19 +578,19 @@ def test_cache_map_parameter_check():

with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal")
assert "incompatible constructor arguments" in str(info.value)
assert "Argument port with value illegal is not of type (<class 'int'>,)" in str(info.value)

with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052")
assert "incompatible constructor arguments" in str(info.value)
assert "Argument port with value 50052 is not of type (<class 'int'>,)" in str(info.value)

with pytest.raises(RuntimeError) as err:
with pytest.raises(ValueError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, port=0)
assert "Unexpected error. port must be positive" in str(err.value)
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)

with pytest.raises(RuntimeError) as err:
with pytest.raises(ValueError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536)
assert "Unexpected error. illegal port number" in str(err.value)
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)

with pytest.raises(TypeError) as err:
ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)


Loading…
Cancel
Save