You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

_summary_adapter.py 14 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate the summary event which conform to proto format."""
  16. import platform
  17. import time
  18. import numpy as np
  19. from PIL import Image
  20. from mindspore import log as logger
  21. from mindspore import context
  22. from mindspore.communication.management import get_rank
  23. from ..._checkparam import Validator
  24. from ..anf_ir_pb2 import DataType, ModelProto
  25. from ..summary_pb2 import Event
  26. # define the MindSpore image format
  27. MS_IMAGE_TENSOR_FORMAT = 'NCHW'
  28. # Set the Event mark
  29. EVENT_FILE_NAME_MARK = ".out.events.summary."
  30. # Set the init event of version and mark
  31. EVENT_FILE_INIT_VERSION_MARK = "MindSpore.Event:"
  32. EVENT_FILE_INIT_VERSION = 1
  33. F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max
  34. def get_event_file_name(prefix, suffix, time_second):
  35. """
  36. Create file name: file_prefix + EVENT_FILE_NAME_MARK + time(seconds) + "." + Hostname + file_suffix.
  37. Args:
  38. prefix (str): The prefix of file name.
  39. suffix (str): The suffix of file name.
  40. time_second (str): The time stamp of file name.
  41. Returns:
  42. String, the name of event log file.
  43. """
  44. Validator.check_str_by_regular(prefix)
  45. Validator.check_str_by_regular(suffix)
  46. file_name = ""
  47. hostname = platform.node()
  48. device_num = context.get_auto_parallel_context('device_num')
  49. device_id = context.get_context('device_id')
  50. if device_num > 1:
  51. # Notice:
  52. # In GPU distribute training scene, get_context('device_id') will not work,
  53. # so we use get_rank instead of get_context.
  54. device_id = get_rank()
  55. file_name = f'{file_name}{EVENT_FILE_NAME_MARK}{time_second}.{device_id}.{hostname}'
  56. if prefix is not None:
  57. file_name = prefix + file_name
  58. if suffix is not None:
  59. file_name = file_name + suffix
  60. return file_name
  61. def package_init_event():
  62. """Package the summary init event."""
  63. init_event = Event()
  64. init_event.wall_time = time.time()
  65. version = EVENT_FILE_INIT_VERSION_MARK + str(EVENT_FILE_INIT_VERSION)
  66. init_event.version = version
  67. return init_event
  68. def package_graph_event(data):
  69. """
  70. Package the summary graph event.
  71. Args:
  72. data (Bytes): Graph bytes string.
  73. Returns:
  74. Event, event log object.
  75. """
  76. graph_event = Event()
  77. graph_event.wall_time = time.time()
  78. modelp = ModelProto()
  79. modelp.ParseFromString(data)
  80. graph_event.graph_def.CopyFrom(modelp.graph)
  81. return graph_event
  82. def package_summary_event(data_list, step, wall_time):
  83. """
  84. Package the summary to event protobuffer.
  85. Args:
  86. data_list (list): Summary data list.
  87. step (Number): The recode step index.
  88. wall_time (float): The wall time.
  89. Returns:
  90. Summary, the summary event.
  91. """
  92. # create the event of summary
  93. summary_event = Event()
  94. summary = summary_event.summary
  95. summary_event.wall_time = wall_time
  96. summary_event.step = int(step)
  97. for value in data_list:
  98. summary_type = value["_type"]
  99. data = value["data"]
  100. tag = value["name"]
  101. logger.debug(f"Now process {summary_type} summary, tag = {tag}")
  102. summary_value = summary.value.add()
  103. summary_value.tag = tag
  104. # get the summary type and parse the tag
  105. if summary_type == 'Scalar':
  106. if not _fill_scalar_summary(tag, data, summary_value):
  107. del summary.value[-1]
  108. elif summary_type == 'Tensor':
  109. _fill_tensor_summary(tag, data, summary_value.tensor)
  110. elif summary_type == 'Image':
  111. if not _fill_image_summary(tag, data, summary_value.image, MS_IMAGE_TENSOR_FORMAT):
  112. del summary.value[-1]
  113. elif summary_type == 'Histogram':
  114. _fill_histogram_summary(tag, data, summary_value.histogram)
  115. else:
  116. # The data is invalid ,jump the data
  117. logger.error(f"Summary type({summary_type}) is error, tag = {tag}")
  118. del summary.value[-1]
  119. return summary_event
  120. def _nptype_to_prototype(np_value):
  121. """
  122. Transform the np type to proto type.
  123. Args:
  124. np_value (Type): Numpy data type.
  125. Returns:
  126. Type, proto data type.
  127. """
  128. np2pt_tbl = {
  129. np.bool_: 'DT_BOOL',
  130. np.int8: 'DT_INT8',
  131. np.int16: 'DT_INT16',
  132. np.int32: 'DT_INT32',
  133. np.int64: 'DT_INT64',
  134. np.uint8: 'DT_UINT8',
  135. np.uint16: 'DT_UINT16',
  136. np.uint32: 'DT_UINT32',
  137. np.uint64: 'DT_UINT64',
  138. np.float16: 'DT_FLOAT16',
  139. np.float: 'DT_FLOAT64',
  140. np.float32: 'DT_FLOAT32',
  141. np.float64: 'DT_FLOAT64',
  142. None: 'DT_UNDEFINED'
  143. }
  144. np_type = None
  145. if np_value is None:
  146. logger.error("The numpy value is none")
  147. else:
  148. np_type = np_value.dtype.type
  149. proto = np2pt_tbl.get(np_type, None)
  150. if proto is None:
  151. raise TypeError("No match for proto data type.")
  152. return proto
  153. def _fill_scalar_summary(tag: str, np_value, summary):
  154. """
  155. Package the scalar summary.
  156. Args:
  157. tag (str): Summary tag describe.
  158. np_value (Object): Scalary object.
  159. Returns:
  160. Summary, return scalar summary content.
  161. """
  162. logger.debug(f"Set({tag}) the scalar summary value")
  163. if np_value.size == 1:
  164. # is scalar
  165. summary.scalar_value = np_value.item()
  166. return True
  167. if np_value.size > 1:
  168. logger.warning(
  169. f"The tensor is not a single scalar, tag = {tag}, ndim = {np_value.ndim}, shape = {np_value.shape}")
  170. summary.scalar_value = next(np_value.flat).item()
  171. return True
  172. logger.error(f"There no values inside tensor, tag = {tag}, size = {np_value.size}")
  173. return False
  174. def _fill_tensor_summary(tag: str, np_value, summary_tensor):
  175. """
  176. Package the tensor summary.
  177. Args:
  178. tag (str): Summary tag describe.
  179. np_value (Type): Summary data type.
  180. summary_tensor (Tensor): The tensor of summary.
  181. Returns:
  182. Summary, return tensor summary content.
  183. """
  184. logger.debug(f"Set({tag}) the tensor summary value")
  185. # get tensor dtype
  186. tensor_dtype = _nptype_to_prototype(np_value)
  187. summary_tensor.data_type = DataType.Value(tensor_dtype)
  188. # get the value list
  189. tensor_value_list = np_value.reshape(-1).tolist()
  190. summary_tensor.float_data.extend(tensor_value_list)
  191. # get the tensor dim
  192. for v in np_value.shape:
  193. summary_tensor.dims.append(v)
  194. return summary_tensor
  195. def _calc_histogram_bins(count):
  196. """
  197. Calculates experience-based optimal bins number for histogram.
  198. There should be enough number in each bin. So we calc bin numbers according to count. For very small count(1 -
  199. 10), we assign carefully chosen number. For large count, we tried to make sure there are 9-10 numbers in each
  200. bucket on average. Too many bins will slow down performance, so we set max number of bins to 90.
  201. Args:
  202. count (int): Valid number count for the tensor.
  203. Returns:
  204. int, number of histogram bins.
  205. """
  206. max_bins, max_per_bin = 90, 10
  207. if not count:
  208. return 1
  209. if count <= 5:
  210. return 2
  211. if count <= 10:
  212. return 3
  213. if count <= 880:
  214. # note that math.ceil(881/10) + 1 equals 90
  215. return count // max_per_bin + 1
  216. return max_bins
  217. def _fill_histogram_summary(tag: str, np_value: np.ndarray, summary) -> None:
  218. """
  219. Package the histogram summary.
  220. Args:
  221. tag (str): Summary tag describe.
  222. np_value (np.ndarray): Summary data.
  223. summary (summary_pb2.Summary.Histogram): Summary histogram data.
  224. """
  225. logger.debug(f"Set({tag}) the histogram summary value")
  226. # Default bucket for tensor with no valid data.
  227. ma_value = np.ma.masked_invalid(np_value)
  228. total, valid = np_value.size, ma_value.count()
  229. invalids = []
  230. for isfn in np.isnan, np.isposinf, np.isneginf:
  231. if total - valid > sum(invalids):
  232. count = np.count_nonzero(isfn(np_value))
  233. invalids.append(count)
  234. else:
  235. invalids.append(0)
  236. summary.count = total
  237. summary.nan_count, summary.pos_inf_count, summary.neg_inf_count = invalids
  238. if not valid:
  239. logger.warning(f'There are no valid values in the ndarray(size={total}, shape={np_value.shape})')
  240. # summary.{min, max, sum} are 0s by default, no need to explicitly set
  241. else:
  242. # BUG: max of a masked array with dtype np.float16 returns inf
  243. # See numpy issue#15077
  244. if issubclass(np_value.dtype.type, np.floating):
  245. summary.min = ma_value.min(fill_value=np.PINF)
  246. summary.max = ma_value.max(fill_value=np.NINF)
  247. if summary.min < F32_MIN or summary.max > F32_MAX:
  248. logger.warning(f'Values({summary.min}, {summary.max}) are too large, '
  249. f'you may encounter some undefined behaviours hereafter.')
  250. else:
  251. summary.min = ma_value.min()
  252. summary.max = ma_value.max()
  253. summary.sum = ma_value.sum(dtype=np.float64)
  254. bins = _calc_histogram_bins(valid)
  255. first_edge, last_edge = summary.min, summary.max
  256. if not first_edge < last_edge:
  257. first_edge -= 0.5
  258. last_edge += 0.5
  259. bins = np.linspace(first_edge, last_edge, bins + 1, dtype=np_value.dtype)
  260. hists, edges = np.histogram(np_value, bins=bins)
  261. for hist, edge1, edge2 in zip(hists, edges, edges[1:]):
  262. bucket = summary.buckets.add()
  263. bucket.width = edge2 - edge1
  264. bucket.count = hist
  265. bucket.left = edge1
  266. def _fill_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
  267. """
  268. Package the image summary.
  269. Args:
  270. tag (str): Summary tag describe.
  271. np_value (Type): Summary data type.
  272. summary_image (Tensor): The tensor of summary.
  273. input_format (str): Data sort order index. Default: 'NCHW'.
  274. Returns:
  275. Summary, return image summary content.
  276. """
  277. logger.debug(f"Set({tag}) the image summary value")
  278. if np_value.ndim != 4 or np_value.shape[1] not in (1, 3):
  279. logger.error(f"The value is not Image, tag = {tag}, ndim = {np_value.ndim}, shape={np_value.shape}")
  280. return False
  281. if np_value.ndim != len(input_format):
  282. logger.error(
  283. f"The tensor with dim({np_value.ndim}) can't convert the format({input_format}) because dim not same")
  284. return False
  285. # convert the tensor format
  286. tensor = _convert_image_format(np_value, input_format)
  287. # convert the tensor dtype
  288. # Do not assume that user passes in values in [0, 255], use data type to detect
  289. scale_factor = 1
  290. if tensor.dtype == np.uint8:
  291. scale_factor = 1
  292. elif np.max(tensor) <= 1 and np.min(tensor) >= 0:
  293. scale_factor = 255
  294. tensor = tensor.astype(np.float32)
  295. tensor = (tensor * scale_factor).astype(np.uint8)
  296. # create the image summary
  297. height, width, channel, image_string = _make_image(tensor)
  298. summary_image.height = height
  299. summary_image.width = width
  300. summary_image.colorspace = channel
  301. summary_image.encoded_image = image_string
  302. return True
  303. def _make_image(tensor, rescale=1):
  304. """
  305. Convert a numpy representation of an image to Image protobuf.
  306. Args:
  307. tensor (Tensor): The image data.
  308. rescale (Number): The rescale value. Default: 1.
  309. Returns:
  310. (Number, Number, Number, Bytes), return the height, width, channel, image string .
  311. """
  312. height, width, channel = tensor.shape
  313. scaled_height = int(height * rescale)
  314. scaled_width = int(width * rescale)
  315. image = Image.fromarray(tensor)
  316. image = image.resize((scaled_width, scaled_height), Image.ANTIALIAS)
  317. import io
  318. output = io.BytesIO()
  319. image.save(output, format='PNG')
  320. image_string = output.getvalue()
  321. output.close()
  322. return height, width, channel, image_string
  323. def _convert_image_format(np_tensor, input_format, out_format='HWC'):
  324. """
  325. Convert the image format.
  326. Args:
  327. np_tensor (Tensor): The image data.
  328. input_format (str): Input data format.
  329. out_format (str): The output data format. Default: 'HWC'.
  330. Returns:
  331. Tensor, return format image.
  332. """
  333. input_format = input_format.upper()
  334. # convert the NCHW
  335. if input_format != 'NCHW':
  336. index = [input_format.find(c) for c in 'NCHW']
  337. tensor_nchw = np_tensor.transpose(index)
  338. else:
  339. tensor_nchw = np_tensor
  340. # make grid to expand N
  341. tensor_chw = _make_canvas_for_imgs(tensor_nchw)
  342. # convert to out format
  343. out_index = ['CHW'.find(c) for c in out_format]
  344. out_tensor = tensor_chw.transpose(out_index)
  345. return out_tensor
  346. def _make_canvas_for_imgs(tensor, col_imgs=8):
  347. """
  348. Expand the N, show imgs on a canvs.
  349. Args:
  350. tensor (Tensor): The canvas value.
  351. col_imgs (Number): The image colume number. Default: 8.
  352. Returns:
  353. Tensor, return canvas of image.
  354. """
  355. # expand the N1HW to N3HW
  356. if tensor.shape[1] == 1:
  357. tensor = np.concatenate([tensor, tensor, tensor], 1)
  358. # expand the N
  359. n = tensor.shape[0]
  360. h = tensor.shape[2]
  361. w = tensor.shape[3]
  362. cols = min(n, col_imgs)
  363. rows = int(np.ceil(float(n) / cols))
  364. # create the canvas: expand the n
  365. out_canvas = np.zeros((3, h * rows, w * cols))
  366. i = 0
  367. for y in range(rows):
  368. for x in range(cols):
  369. if i >= n:
  370. break
  371. out_canvas[:, y * h:(y + 1) * h, x * w:(x + 1) * w] = tensor[i]
  372. i = i + 1
  373. return out_canvas