You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

profiler.py 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import json
  10. import os
  11. import re
  12. from contextlib import ContextDecorator, contextmanager
  13. from functools import wraps
  14. from typing import List
  15. from weakref import WeakSet
  16. from .. import _atexit
  17. from ..core._imperative_rt.core2 import (
  18. cupti_available,
  19. disable_cupti,
  20. enable_cupti,
  21. full_sync,
  22. pop_scope,
  23. push_scope,
  24. start_profile,
  25. stop_profile,
  26. sync,
  27. )
  28. from ..logger import get_logger
# The Profiler currently between start() and stop(), or None when no
# profiler is running. Profiler.start()/stop() assert on and update it.
_running_profiler = None
# Profilers that have been stopped but whose results are not dumped yet.
# Entries are held weakly; Profiler.stop() adds them and Profiler.dump()
# removes them.
_living_profilers = WeakSet()
  31. class Profiler(ContextDecorator):
  32. r"""Profile graph execution in imperative mode.
  33. Args:
  34. path: default path prefix for profiler to dump.
  35. Examples:
  36. .. code-block::
  37. import megengine as mge
  38. import megengine.module as M
  39. from megengine.utils.profiler import Profiler
  40. # With Learnable Parameters
  41. profiler = Profiler()
  42. for iter in range(0, 10):
  43. # Only profile record of last iter would be saved
  44. with profiler:
  45. # your code here
  46. # Then open the profile file in chrome timeline window
  47. """
  48. CHROME_TIMELINE = "chrome_timeline.json"
  49. valid_options = {
  50. "sample_rate": 0,
  51. "profile_device": 1,
  52. "num_tensor_watch": 10,
  53. "enable_cupti": 0,
  54. }
  55. valid_formats = {"chrome_timeline.json", "memory_flow.svg"}
  56. def __init__(
  57. self,
  58. path: str = "profile",
  59. format: str = "chrome_timeline.json",
  60. formats: List[str] = None,
  61. **kwargs
  62. ) -> None:
  63. if not formats:
  64. formats = [format]
  65. assert not isinstance(formats, str), "formats excepts list, got str"
  66. for format in formats:
  67. assert format in Profiler.valid_formats, "unsupported format {}".format(
  68. format
  69. )
  70. self._path = path
  71. self._formats = formats
  72. self._options = {}
  73. for opt, optval in Profiler.valid_options.items():
  74. self._options[opt] = int(kwargs.pop(opt, optval))
  75. self._pid = "<PID>"
  76. self._dump_callback = None
  77. if self._options.get("enable_cupti", 0):
  78. if cupti_available():
  79. enable_cupti()
  80. else:
  81. get_logger().warning("CuPTI unavailable")
  82. @property
  83. def path(self):
  84. if len(self._formats) == 0:
  85. format = "<FORMAT>"
  86. elif len(self._formats) == 1:
  87. format = self._formats[0]
  88. else:
  89. format = "{" + ",".join(self._formats) + "}"
  90. return self.format_path(self._path, self._pid, format)
  91. @property
  92. def directory(self):
  93. return self._path
  94. @property
  95. def formats(self):
  96. return list(self._formats)
  97. def start(self):
  98. global _running_profiler
  99. assert _running_profiler is None
  100. _running_profiler = self
  101. self._pid = os.getpid()
  102. start_profile(self._options)
  103. return self
  104. def stop(self):
  105. global _running_profiler
  106. assert _running_profiler is self
  107. _running_profiler = None
  108. full_sync()
  109. self._dump_callback = stop_profile()
  110. self._pid = os.getpid()
  111. _living_profilers.add(self)
  112. def dump(self):
  113. if self._dump_callback is not None:
  114. if not os.path.exists(self._path):
  115. os.makedirs(self._path)
  116. if not os.path.isdir(self._path):
  117. get_logger().warning(
  118. "{} is not a directory, cannot write profiling results".format(
  119. self._path
  120. )
  121. )
  122. return
  123. for format in self._formats:
  124. path = self.format_path(self._path, self._pid, format)
  125. get_logger().info("process {} generating {}".format(self._pid, format))
  126. self._dump_callback(path, format)
  127. get_logger().info("profiling results written to {}".format(path))
  128. if os.path.getsize(path) > 64 * 1024 * 1024:
  129. get_logger().warning(
  130. "profiling results too large, maybe you are profiling multi iters,"
  131. "consider attach profiler in each iter separately"
  132. )
  133. self._dump_callback = None
  134. _living_profilers.remove(self)
  135. def format_path(self, path, pid, format):
  136. return os.path.join(path, "{}.{}".format(pid, format))
  137. def __enter__(self):
  138. self.start()
  139. def __exit__(self, val, tp, trace):
  140. self.stop()
  141. def __call__(self, func):
  142. func = super().__call__(func)
  143. func.__profiler__ = self
  144. return func
  145. def __del__(self):
  146. if self._options.get("enable_cupti", 0):
  147. if cupti_available():
  148. disable_cupti()
  149. self.dump()
  150. @contextmanager
  151. def scope(name):
  152. push_scope(name)
  153. yield
  154. pop_scope(name)
  155. def profile(*args, **kwargs):
  156. if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
  157. return Profiler()(args[0])
  158. return Profiler(*args, **kwargs)
  159. def merge_trace_events(directory: str):
  160. names = filter(
  161. lambda x: re.match(r"\d+\.chrome_timeline\.json", x), os.listdir(directory)
  162. )
  163. def load_trace_events(name):
  164. with open(os.path.join(directory, name), "r", encoding="utf-8") as f:
  165. return json.load(f)
  166. def find_metadata(content):
  167. if isinstance(content, dict):
  168. assert "traceEvents" in content
  169. content = content["traceEvents"]
  170. if len(content) == 0:
  171. return None
  172. assert content[0]["name"] == "Metadata"
  173. return content[0]["args"]
  174. contents = list(map(load_trace_events, names))
  175. metadata_list = list(map(find_metadata, contents))
  176. min_local_time = min(
  177. map(lambda x: x["localTime"], filter(lambda x: x is not None, metadata_list))
  178. )
  179. events = []
  180. for content, metadata in zip(contents, metadata_list):
  181. local_events = content["traceEvents"]
  182. if len(local_events) == 0:
  183. continue
  184. local_time = metadata["localTime"]
  185. time_shift = local_time - min_local_time
  186. for event in local_events:
  187. if "ts" in event:
  188. event["ts"] = int(event["ts"] + time_shift)
  189. events.extend(filter(lambda x: x["name"] != "Metadata", local_events))
  190. result = {
  191. "traceEvents": events,
  192. }
  193. path = os.path.join(directory, "merge.chrome_timeline.json")
  194. with open(path, "w") as f:
  195. json.dump(result, f, ensure_ascii=False, separators=(",", ":"))
  196. get_logger().info("profiling results written to {}".format(path))
  197. def is_profiling():
  198. return _running_profiler is not None
  199. def _stop_current_profiler():
  200. global _running_profiler
  201. if _running_profiler is not None:
  202. _running_profiler.stop()
  203. living_profilers = [*_living_profilers]
  204. for profiler in living_profilers:
  205. profiler.dump()
# Hand the cleanup hook to the package-level ``_atexit`` helper.
# NOTE(review): presumably this runs the hook at interpreter exit so pending
# results are still written; confirm against the ``_atexit`` definition.
_atexit(_stop_current_profiler)