You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

cache_client.py 4.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # Copyright 2019-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Cache client
  16. """
  17. import copy
  18. from mindspore._c_dataengine import CacheClient
  19. from ..core.validator_helpers import type_check, check_uint32, check_uint64, check_positive, check_value
  20. class DatasetCache:
  21. """
  22. A client to interface with tensor caching service.
  23. For details, please check `Chinese tutorial <https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_cache.html>`_,
  24. `Chinese programming guide <https://www.mindspore.cn/doc/programming_guide/zh-CN/master/cache.html?highlight=datasetcache>`_.
  25. Args:
  26. session_id (int): A user assigned session id for the current pipeline.
  27. size (int, optional): Size of the memory set aside for the row caching (default=0 which means unlimited,
  28. note that it might bring in the risk of running out of memory on the machine).
  29. spilling (bool, optional): Whether or not spilling to disk if out of memory (default=False).
  30. hostname (str, optional): Host name (default="127.0.0.1").
  31. port (int, optional): Port to connect to server (default=50052).
  32. num_connections (int, optional): Number of tcp/ip connections (default=12).
  33. prefetch_size (int, optional): Prefetch size (default=20).
  34. Examples:
  35. >>> import mindspore.dataset as ds
  36. >>>
  37. >>> # create a cache instance, in which session_id is generated from command line `cache_admin -g`
  38. >>> some_cache = ds.DatasetCache(session_id=session_id, size=0)
  39. >>>
  40. >>> dataset_dir = "path/to/imagefolder_directory"
  41. >>> ds1 = ds.ImageFolderDataset(dataset_dir, cache=some_cache)
  42. """
  43. def __init__(self, session_id, size=0, spilling=False, hostname=None, port=None, num_connections=None,
  44. prefetch_size=None):
  45. check_uint32(session_id, "session_id")
  46. type_check(size, (int,), "size")
  47. if size != 0:
  48. check_positive(size, "size")
  49. check_uint64(size, "size")
  50. type_check(spilling, (bool,), "spilling")
  51. if hostname is not None:
  52. type_check(hostname, (str,), "hostname")
  53. if port is not None:
  54. type_check(port, (int,), "port")
  55. check_value(port, (1025, 65535), "port")
  56. if num_connections is not None:
  57. check_uint32(num_connections, "num_connections")
  58. if prefetch_size is not None:
  59. check_uint32(prefetch_size, "prefetch_size")
  60. self.session_id = session_id
  61. self.size = size
  62. self.spilling = spilling
  63. self.hostname = hostname
  64. self.port = port
  65. self.prefetch_size = prefetch_size
  66. self.num_connections = num_connections
  67. self.cache_client = CacheClient(session_id, size, spilling, hostname, port, num_connections, prefetch_size)
  68. def GetStat(self):
  69. return self.cache_client.GetStat()
  70. def __deepcopy__(self, memodict):
  71. if id(self) in memodict:
  72. return memodict[id(self)]
  73. cls = self.__class__
  74. new_cache = cls.__new__(cls)
  75. memodict[id(self)] = new_cache
  76. new_cache.session_id = copy.deepcopy(self.session_id, memodict)
  77. new_cache.spilling = copy.deepcopy(self.spilling, memodict)
  78. new_cache.size = copy.deepcopy(self.size, memodict)
  79. new_cache.hostname = copy.deepcopy(self.hostname, memodict)
  80. new_cache.port = copy.deepcopy(self.port, memodict)
  81. new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
  82. new_cache.num_connections = copy.deepcopy(self.num_connections, memodict)
  83. new_cache.cache_client = self.cache_client
  84. return new_cache