You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cache_client.py 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Cache client
  16. """
  17. import os
  18. import copy
  19. from ..core.validator_helpers import type_check, check_uint32, check_uint64
  20. class DatasetCache:
  21. """
  22. A client to interface with tensor caching service
  23. """
  24. def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None,
  25. prefetch_size=None):
  26. check_uint32(session_id, "session_id")
  27. check_uint64(size, "size")
  28. type_check(spilling, (bool,), "spilling")
  29. self.session_id = session_id
  30. self.size = size
  31. self.spilling = spilling
  32. self.hostname = hostname
  33. self.port = port
  34. self.prefetch_size = prefetch_size
  35. self.num_connections = num_connections
  36. if os.getenv('MS_ENABLE_CACHE') != 'TRUE':
  37. # temporary disable cache feature in the current release
  38. self.cache_client = None
  39. else:
  40. from mindspore._c_dataengine import CacheClient
  41. self.cache_client = CacheClient(session_id, size, spilling, hostname, port, num_connections, prefetch_size)
  42. def GetStat(self):
  43. return self.cache_client.GetStat()
  44. def __deepcopy__(self, memodict):
  45. if id(self) in memodict:
  46. return memodict[id(self)]
  47. cls = self.__class__
  48. new_cache = cls.__new__(cls)
  49. memodict[id(self)] = new_cache
  50. new_cache.session_id = copy.deepcopy(self.session_id, memodict)
  51. new_cache.spilling = copy.deepcopy(self.spilling, memodict)
  52. new_cache.size = copy.deepcopy(self.size, memodict)
  53. new_cache.hostname = copy.deepcopy(self.hostname, memodict)
  54. new_cache.port = copy.deepcopy(self.port, memodict)
  55. new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
  56. new_cache.num_connections = copy.deepcopy(self.num_connections, memodict)
  57. new_cache.cache_client = self.cache_client
  58. return new_cache