You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

queue.py 6.3 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. This dataset module creates an internal queue class to more optimally pass data
  17. between multiple processes in Python. It has the same API as multiprocessing.Queue
  18. but it will pass large data through shared memory.
  19. """
  20. import multiprocessing.queues
  21. import multiprocessing
  22. import types
  23. import numpy as np
  24. from mindspore import log as logger
  25. from ..transforms.py_transforms_util import ExceptionHandler
  26. class _SharedQueue(multiprocessing.queues.Queue):
  27. """
  28. Class to implement a queue using shared memory for better performance.
  29. Args:
  30. size: Number of elements in the queue.
  31. copy_out: Flag to indidcate whether an extra copy should be done before returning. If data will immediately be
  32. copied before returning, then this can be set to False.
  33. max_rowsize: Maximum size of any element in the Queue in MB.
  34. """
  35. def __init__(self, size, copy_out=False, max_rowsize=6):
  36. super().__init__(size, ctx=multiprocessing.get_context())
  37. self.copy_out = copy_out
  38. # change max_rowsize in MB into bytes
  39. self.seg_size = max_rowsize * 1024 * 1024
  40. ##pipe can hold up to 65,636 bytes at a time
  41. self.min_shared_mem = 10000
  42. self.shm_list = []
  43. self.seg_pos = 0
  44. # num_seg has to be 2 more than the queue size. We can have remote worker filling a buffer, main process
  45. # reading a buffer and also have a full queue of buffers in the meta-data queue
  46. self.num_seg = size + 2
  47. self.data_immediate = 0
  48. self.data_shared = 1
  49. self.print_error = True
  50. try:
  51. for _ in range(self.num_seg):
  52. a = multiprocessing.Array("b", self.seg_size)
  53. self.shm_list.append(a)
  54. except Exception:
  55. raise RuntimeError(
  56. "_SharedQueue: Error allocating "
  57. + str(self.seg_size)
  58. + "bytes, "
  59. + str(self.num_seg)
  60. + " elements."
  61. + " This might be caused by insufficient shm, and the recommended shm size is at least 5 GB."
  62. )
  63. def put(self, data, timeout=None):
  64. if isinstance(data, ExceptionHandler):
  65. super().put(data, timeout=timeout)
  66. else:
  67. name_list = []
  68. count = 0
  69. start_bytes = 0
  70. if not isinstance(data, tuple) and not isinstance(data, np.ndarray):
  71. raise TypeError("return value of user defined python function in GeneratorDataset or"
  72. " map should be numpy array or tuple of numpy array.")
  73. for r in data:
  74. # the map:pyfunc is a yield generator which can't be serialize
  75. if isinstance(r, types.GeneratorType):
  76. raise TypeError("Can not pickle {} object, please verify pyfunc return with numpy array"
  77. .format(type(r)))
  78. if (isinstance(r, np.ndarray) and r.size > self.min_shared_mem
  79. and start_bytes + r.nbytes < self.seg_size):
  80. # need to convert start_bytes to offset in array
  81. start_offset = start_bytes
  82. dest = np.ndarray(r.shape, r.dtype, buffer=self.shm_list[self.seg_pos].get_obj(),
  83. offset=start_offset)
  84. np.copyto(dest, r)
  85. byte = r.nbytes
  86. byte = 8 * ((byte + 7) // 8)
  87. start_bytes += byte
  88. name_list.append((self.data_shared, self.seg_pos, byte, r.dtype, r.shape))
  89. count += 1
  90. else:
  91. if isinstance(r, np.ndarray) and r.size >= self.min_shared_mem:
  92. # Only print out error the first time it happens
  93. if self.print_error:
  94. logger.warning(
  95. "Using shared memory queue, but rowsize is larger than allocated memory "
  96. + "max_rowsize "
  97. + str(self.seg_size)
  98. + " current rowsize "
  99. + str(start_bytes + r.nbytes)
  100. )
  101. self.print_error = False
  102. name_list.append((self.data_immediate, r))
  103. super().put(name_list, timeout=timeout)
  104. # note above could generate a queue full exception. It will be handled by teh caller
  105. # only increment seg_pos after successfully adding to metadata queue
  106. if start_bytes > 0:
  107. self.seg_pos = (self.seg_pos + 1) % self.num_seg
  108. def get(self, timeout=None):
  109. result = super().get(timeout=timeout)
  110. if isinstance(result, ExceptionHandler):
  111. return result
  112. r = []
  113. start_bytes = 0
  114. for x in result:
  115. if x[0] == self.data_shared:
  116. seg_pos = x[1]
  117. byte = x[2]
  118. dtype = x[3]
  119. shape = x[4]
  120. start_offset = start_bytes
  121. b = self.shm_list[seg_pos]
  122. data = np.ndarray(shape, dtype, buffer=b.get_obj(), offset=start_offset)
  123. start_bytes += byte
  124. if self.copy_out:
  125. data2 = np.copy(data)
  126. r.append(data2)
  127. else:
  128. r.append(data)
  129. elif x[0] == self.data_immediate:
  130. r.append(x[1])
  131. else:
  132. raise RuntimeError("SharedQueue, invalid entry in metadata.")
  133. return tuple(r)