|
|
|
@@ -18,21 +18,20 @@ |
|
|
|
import copy |
|
|
|
from mindspore._c_dataengine import CacheClient |
|
|
|
|
|
|
|
from ..core.validator_helpers import type_check, check_uint32, check_uint64 |
|
|
|
|
|
|
|
class DatasetCache: |
|
|
|
""" |
|
|
|
A client to interface with tensor caching service |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, session_id=None, size=None, spilling=False): |
|
|
|
if session_id is None: |
|
|
|
raise RuntimeError("Session generation is not implemented yet. session id required") |
|
|
|
self.size = size if size is not None else 0 |
|
|
|
if size < 0: |
|
|
|
raise ValueError("cache size should be 0 or positive integer value but got: size={}".format(size)) |
|
|
|
if not isinstance(spilling, bool): |
|
|
|
raise ValueError( |
|
|
|
"spilling argument for cache should be a boolean value but got: spilling={}".format(spilling)) |
|
|
|
def __init__(self, session_id=None, size=0, spilling=False): |
|
|
|
check_uint32(session_id, "session_id") |
|
|
|
check_uint64(size, "size") |
|
|
|
type_check(spilling, (bool,), "spilling") |
|
|
|
|
|
|
|
self.session_id = session_id |
|
|
|
self.size = size |
|
|
|
self.spilling = spilling |
|
|
|
self.cache_client = CacheClient(session_id, size, spilling) |
|
|
|
|
|
|
|
|