
_writer_pool.py 5.5 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Write events to disk in a base directory."""
import os
import time
from collections import deque

import mindspore.log as logger

from ._lineage_adapter import serialize_to_lineage_event
from ._summary_adapter import package_graph_event, package_summary_event
from ._explain_adapter import package_explain_event
from .writer import LineageWriter, SummaryWriter, ExplainWriter

try:
    from multiprocessing import get_context
    ctx = get_context('forkserver')
except ValueError:
    import multiprocessing as ctx


def _pack_data(datadict, wall_time):
    """Pack data according to which plugin."""
    result, summaries, step = [], [], None
    for plugin, datalist in datadict.items():
        for data in datalist:
            if plugin == 'graph':
                result.append([plugin, package_graph_event(data.get('value')).SerializeToString()])
            elif plugin in ('train_lineage', 'eval_lineage', 'custom_lineage_data', 'dataset_graph'):
                result.append([plugin, serialize_to_lineage_event(plugin, data.get('value'))])
            elif plugin in ('scalar', 'tensor', 'histogram', 'image'):
                summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')})
                step = data.get('step')
            elif plugin == 'explainer':
                result.append([plugin, package_explain_event(data.get('value'))])
    if summaries:
        result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()])
    return result
class WriterPool(ctx.Process):
    """
    Use a set of pooled resident processes for writing a list of files.

    Args:
        base_dir (str): The base directory to hold all the files.
        max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes.
        filedict (dict): The mapping from plugin to filename.
    """
    def __init__(self, base_dir, max_file_size, **filedict) -> None:
        super().__init__()
        self._base_dir, self._filedict = base_dir, filedict
        self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None
        self._max_file_size = max_file_size
        self.start()

    def run(self):
        with ctx.Pool(min(ctx.cpu_count(), 32)) as pool:
            deq = deque()
            while True:
                while deq and deq[0].ready():
                    for plugin, data in deq.popleft().get():
                        self._write(plugin, data)

                if not self._queue.empty():
                    action, data = self._queue.get()
                    if action == 'WRITE':
                        deq.append(pool.apply_async(_pack_data, (data, time.time())))
                    elif action == 'FLUSH':
                        self._flush()
                    elif action == 'END':
                        break

            for result in deq:
                for plugin, data in result.get():
                    self._write(plugin, data)

            self._close()

    @property
    def _writers(self):
        """Get the writers in the subprocess."""
        if self._writers_ is not None:
            return self._writers_
        self._writers_ = []
        for plugin, filename in self._filedict.items():
            filepath = os.path.join(self._base_dir, filename)
            if plugin == 'summary':
                self._writers_.append(SummaryWriter(filepath, self._max_file_size))
            elif plugin == 'lineage':
                self._writers_.append(LineageWriter(filepath, self._max_file_size))
            elif plugin == 'explainer':
                self._writers_.append(ExplainWriter(filepath, self._max_file_size))
        return self._writers_

    def _write(self, plugin, data):
        """Write the data in the subprocess."""
        for writer in self._writers[:]:
            try:
                writer.write(plugin, data)
            except RuntimeError as e:
                logger.warning(e.args[0])
                self._writers.remove(writer)
                writer.close()

    def _flush(self):
        """Flush the writers in the subprocess."""
        for writer in self._writers:
            writer.flush()

    def _close(self):
        """Close the writers in the subprocess."""
        for writer in self._writers:
            writer.close()

    def write(self, data) -> None:
        """
        Write the event to file.

        Args:
            data (Optional[str, Tuple[list, int]]): The data to write.
        """
        self._queue.put(('WRITE', data))

    def flush(self):
        """Flush the writer and sync data to disk."""
        self._queue.put(('FLUSH', None))

    def close(self) -> None:
        """Close the writer."""
        self._queue.put(('END', None))
        self.join()
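
A minimal usage sketch follows. It is an illustration under stated assumptions rather than part of the file: the import path is assumed from the module's location in the MindSpore source tree, `summary='events.demo'` is a hypothetical plugin-to-filename mapping, and the record handed to `write` follows the dict layout that `_pack_data` consumes (a plugin name mapped to a list of records carrying 'tag', 'value' and 'step'); the scalar value itself is placeholder data.

# --- Usage sketch (illustrative only; see assumptions above) ---
from mindspore.train.summary._writer_pool import WriterPool  # import path assumed from the source tree layout

# Start a writer pool that owns a single 'summary' event file under ./summary_dir.
pool = WriterPool('./summary_dir', max_file_size=None, summary='events.demo')

# Queue one scalar record in the datadict layout consumed by _pack_data.
pool.write({'scalar': [{'tag': 'loss', 'value': 0.5, 'step': 1}]})

pool.flush()   # ask the subprocess to sync buffered events to disk
pool.close()   # send 'END' and join the writer process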