You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cache_client.py 4.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright 2019-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Cache client
  16. """
  17. import copy
  18. from mindspore._c_dataengine import CacheClient
  19. from ..core.validator_helpers import type_check, check_pos_int32, check_pos_uint32, check_uint64, check_positive, \
  20. check_value
  21. class DatasetCache:
  22. """
  23. A client to interface with tensor caching service.
  24. For details, please check `Tutorial <https://www.mindspore.cn/docs/programming_guide/en/master/enable_cache.html>`_,
  25. `Programming guide <https://www.mindspore.cn/docs/programming_guide/en/master/cache.html>`_.
  26. Args:
  27. session_id (int): A user assigned session id for the current pipeline.
  28. size (int, optional): Size of the memory set aside for the row caching (default=0, which means unlimited,
  29. note that it might bring in the risk of running out of memory on the machine).
  30. spilling (bool, optional): Whether or not spilling to disk if out of memory (default=False).
  31. hostname (str, optional): Host name (default=None, use default hostname '127.0.0.1').
  32. port (int, optional): Port to connect to server (default=None, use default port 50052).
  33. num_connections (int, optional): Number of tcp/ip connections (default=None, use default value 12).
  34. prefetch_size (int, optional): The size of the cache queue between operations
  35. (default=None, use default value 20).
  36. Examples:
  37. >>> import mindspore.dataset as ds
  38. >>>
  39. >>> # create a cache instance, in which session_id is generated from command line `cache_admin -g`
  40. >>> some_cache = ds.DatasetCache(session_id=session_id, size=0)
  41. >>>
  42. >>> dataset_dir = "path/to/imagefolder_directory"
  43. >>> ds1 = ds.ImageFolderDataset(dataset_dir, cache=some_cache)
  44. """
  45. def __init__(self, session_id, size=0, spilling=False, hostname=None, port=None, num_connections=None,
  46. prefetch_size=None):
  47. check_pos_uint32(session_id, "session_id")
  48. type_check(size, (int,), "size")
  49. if size != 0:
  50. check_positive(size, "size")
  51. check_uint64(size, "size")
  52. type_check(spilling, (bool,), "spilling")
  53. if hostname is not None:
  54. type_check(hostname, (str,), "hostname")
  55. if port is not None:
  56. type_check(port, (int,), "port")
  57. check_value(port, (1025, 65535), "port")
  58. if num_connections is not None:
  59. check_pos_int32(num_connections, "num_connections")
  60. if prefetch_size is not None:
  61. check_pos_int32(prefetch_size, "prefetch_size")
  62. self.session_id = session_id
  63. self.size = size
  64. self.spilling = spilling
  65. self.hostname = hostname
  66. self.port = port
  67. self.prefetch_size = prefetch_size
  68. self.num_connections = num_connections
  69. self.cache_client = CacheClient(session_id, size, spilling, hostname, port, num_connections, prefetch_size)
  70. def get_stat(self):
  71. """Get the statistics from a cache."""
  72. return self.cache_client.GetStat()
  73. def __deepcopy__(self, memodict):
  74. if id(self) in memodict:
  75. return memodict[id(self)]
  76. cls = self.__class__
  77. new_cache = cls.__new__(cls)
  78. memodict[id(self)] = new_cache
  79. new_cache.session_id = copy.deepcopy(self.session_id, memodict)
  80. new_cache.spilling = copy.deepcopy(self.spilling, memodict)
  81. new_cache.size = copy.deepcopy(self.size, memodict)
  82. new_cache.hostname = copy.deepcopy(self.hostname, memodict)
  83. new_cache.port = copy.deepcopy(self.port, memodict)
  84. new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
  85. new_cache.num_connections = copy.deepcopy(self.num_connections, memodict)
  86. new_cache.cache_client = self.cache_client
  87. return new_cache