You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_summary_scheduler.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Schedule the event writer process."""
  16. import multiprocessing as mp
  17. from enum import Enum, unique
  18. from mindspore import log as logger
  19. from ..._c_expression import Tensor
  20. from ._summary_adapter import SummaryType, package_summary_event, save_summary_data
# String names of the supported summary types, as they appear in a tag's
# "[:Type]" suffix (see _parse_tag_format).
FORMAT_SCALAR_STR = "Scalar"
FORMAT_TENSOR_STR = "Tensor"
FORMAT_IMAGE_STR = "Image"
FORMAT_HISTOGRAM_STR = "Histogram"
# Markers that delimit the type suffix inside a tag: "name[:Type]".
FORMAT_BEGIN_SLICE = "[:"
FORMAT_END_SLICE = "]"

# Global cache of summary data, keyed by an auto-incremented id.
# {id: SummaryData}
# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
# NOTE(review): these globals are shared with worker processes spawned via
# multiprocessing — presumably relying on fork semantics; confirm on platforms
# that use the spawn start method.
g_summary_data_id = 0
g_summary_data_dict = {}

# Global registry of event writers, keyed by an auto-incremented writer id.
g_summary_writer_id = 0
g_summary_file = {}
@unique
class ScheduleMethod(Enum):
    """Schedule method type."""
    FORMAL_WORKER = 0  # use the formal worker that receives small data via a queue
    TEMP_WORKER = 1    # use the temp worker that receives big data via the global cache (avoids a copy)
    CACHE_DATA = 2     # cache the data until an idle worker is available to process it
@unique
class WorkerStatus(Enum):
    """Worker status."""
    WORKER_INIT = 0        # data exists but has not been processed yet
    WORKER_PROCESSING = 1  # data is being processed
    WORKER_PROCESSED = 2   # data has already been processed
  48. def _parse_tag_format(tag: str):
  49. """
  50. Parse the tag.
  51. Args:
  52. tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor].
  53. Returns:
  54. Tuple, (SummaryType, summary_tag).
  55. """
  56. summary_type = SummaryType.INVALID
  57. summary_tag = tag
  58. if tag is None:
  59. logger.error("The tag is None")
  60. return summary_type, summary_tag
  61. # search the slice
  62. slice_begin = FORMAT_BEGIN_SLICE
  63. slice_end = FORMAT_END_SLICE
  64. index = tag.rfind(slice_begin)
  65. if index is -1:
  66. logger.error("The tag(%s) have not the key slice.", tag)
  67. return summary_type, summary_tag
  68. # slice the tag
  69. summary_tag = tag[:index]
  70. # check the slice end
  71. if tag[-1:] != slice_end:
  72. logger.error("The tag(%s) end format is error", tag)
  73. return summary_type, summary_tag
  74. # check the type
  75. type_str = tag[index + 2: -1]
  76. logger.debug("The summary_tag is = %r", summary_tag)
  77. logger.debug("The type_str value is = %r", type_str)
  78. if type_str == FORMAT_SCALAR_STR:
  79. summary_type = SummaryType.SCALAR
  80. elif type_str == FORMAT_TENSOR_STR:
  81. summary_type = SummaryType.TENSOR
  82. elif type_str == FORMAT_IMAGE_STR:
  83. summary_type = SummaryType.IMAGE
  84. elif type_str == FORMAT_HISTOGRAM_STR:
  85. summary_type = SummaryType.HISTOGRAM
  86. else:
  87. logger.error("The tag(%s) type is invalid.", tag)
  88. summary_type = SummaryType.INVALID
  89. return summary_type, summary_tag
  90. class SummaryDataManager:
  91. """Manage the summary global data cache."""
  92. def __init__(self):
  93. global g_summary_data_dict
  94. self.size = len(g_summary_data_dict)
  95. @classmethod
  96. def summary_data_save(cls, data):
  97. """Save the global summary cache."""
  98. global g_summary_data_id
  99. data_id = g_summary_data_id
  100. save_summary_data(data_id, data)
  101. g_summary_data_id += 1
  102. return data_id
  103. @classmethod
  104. def summary_file_set(cls, event_writer):
  105. """Support the many event_writer."""
  106. global g_summary_file, g_summary_writer_id
  107. g_summary_writer_id += 1
  108. g_summary_file[g_summary_writer_id] = event_writer
  109. return g_summary_writer_id
  110. @classmethod
  111. def summary_file_get(cls, writer_id=1):
  112. ret = None
  113. global g_summary_file
  114. if writer_id in g_summary_file:
  115. ret = g_summary_file.get(writer_id)
  116. return ret
  117. class WorkerScheduler:
  118. """
  119. Create worker and schedule data to worker.
  120. Args:
  121. writer_id (int): The index of writer.
  122. """
  123. def __init__(self, writer_id):
  124. # Create the process of write event file
  125. self.write_lock = mp.Lock()
  126. # Schedule info for all worker
  127. # Format: {worker: (step, WorkerStatus)}
  128. self.schedule_table = {}
  129. # write id
  130. self.writer_id = writer_id
  131. self.has_graph = False
  132. def dispatch(self, step, data):
  133. """
  134. Select schedule strategy and dispatch data.
  135. Args:
  136. step (Number): The number of step index.
  137. data (Object): The data of recode for summary.
  138. Retruns:
  139. bool, run successfully or not.
  140. """
  141. # save the data to global cache , convert the tensor to numpy
  142. result, size, data = self._data_convert(data)
  143. if result is False:
  144. logger.error("The step(%r) summary data(%r) is invalid.", step, size)
  145. return False
  146. data_id = SummaryDataManager.summary_data_save(data)
  147. self._start_worker(step, data_id)
  148. return True
  149. def _start_worker(self, step, data_id):
  150. """
  151. Start worker.
  152. Args:
  153. step (Number): The index of recode.
  154. data_id (str): The id of work.
  155. Return:
  156. bool, run successfully or not.
  157. """
  158. # assign the worker
  159. policy = self._make_policy()
  160. if policy == ScheduleMethod.TEMP_WORKER:
  161. worker = SummaryDataProcess(step, data_id, self.write_lock, self.writer_id)
  162. # update the schedule table
  163. self.schedule_table[worker] = (step, data_id, WorkerStatus.WORKER_INIT)
  164. # start the worker
  165. worker.start()
  166. else:
  167. logger.error("Do not support the other scheduler policy now.")
  168. # update the scheduler infor
  169. self._update_scheduler()
  170. return True
  171. def _data_convert(self, data_list):
  172. """Convert the data."""
  173. if data_list is None:
  174. logger.warning("The step does not have record data.")
  175. return False, 0, None
  176. # convert the summary to numpy
  177. size = 0
  178. for v_dict in data_list:
  179. tag = v_dict["name"]
  180. data = v_dict["data"]
  181. # confirm the data is valid
  182. summary_type, summary_tag = _parse_tag_format(tag)
  183. if summary_type == SummaryType.INVALID:
  184. logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
  185. return False, 0, None
  186. if isinstance(data, Tensor):
  187. # get the summary type and parse the tag
  188. v_dict["name"] = summary_tag
  189. v_dict["type"] = summary_type
  190. v_dict["data"] = data.asnumpy()
  191. size += v_dict["data"].size
  192. else:
  193. logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
  194. return False, 0, None
  195. return True, size, data_list
  196. def _update_scheduler(self):
  197. """Check the worker status and update schedule table."""
  198. workers = list(self.schedule_table.keys())
  199. for worker in workers:
  200. if not worker.is_alive():
  201. # update the table
  202. worker.join()
  203. del self.schedule_table[worker]
  204. def close(self):
  205. """Confirm all worker is end."""
  206. workers = self.schedule_table.keys()
  207. for worker in workers:
  208. if worker.is_alive():
  209. worker.join()
  210. def _make_policy(self):
  211. """Select the schedule strategy by data."""
  212. # now only support the temp worker
  213. return ScheduleMethod.TEMP_WORKER
  214. class SummaryDataProcess(mp.Process):
  215. """
  216. Process that consume the summarydata.
  217. Args:
  218. step (int): The index of step.
  219. data_id (int): The index of summary data.
  220. write_lock (Lock): The process lock for writer same file.
  221. writer_id (int): The index of writer.
  222. """
  223. def __init__(self, step, data_id, write_lock, writer_id):
  224. super(SummaryDataProcess, self).__init__()
  225. self.daemon = True
  226. self.writer_id = writer_id
  227. self.writer = SummaryDataManager.summary_file_get(self.writer_id)
  228. if self.writer is None:
  229. logger.error("The writer_id(%r) does not have writer", writer_id)
  230. self.step = step
  231. self.data_id = data_id
  232. self.write_lock = write_lock
  233. self.name = "SummaryDataConsumer_" + str(self.step)
  234. def run(self):
  235. """The consumer is process the step data and exit."""
  236. # convert the data to event
  237. # All exceptions need to be caught and end the queue
  238. try:
  239. logger.debug("process(%r) process a data(%r)", self.name, self.step)
  240. # package the summary event
  241. summary_event = package_summary_event(self.data_id, self.step)
  242. # send the event to file
  243. self._write_summary(summary_event)
  244. except Exception as e:
  245. logger.error("Summary data mq consumer exception occurred, value = %r", e)
  246. def _write_summary(self, summary_event):
  247. """
  248. Write the summary to event file.
  249. Note:
  250. The write record format:
  251. 1 uint64 : data length.
  252. 2 uint32 : mask crc value of data length.
  253. 3 bytes : data.
  254. 4 uint32 : mask crc value of data.
  255. Args:
  256. summary_event (Event): The summary event of proto.
  257. """
  258. event_str = summary_event.SerializeToString()
  259. self.write_lock.acquire()
  260. self.writer.write_event_to_file(event_str)
  261. self.writer.flush()
  262. self.write_lock.release()