You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

_summary_adapter.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate the summary event which conform to proto format."""
  16. import platform
  17. import time
  18. import numpy as np
  19. from PIL import Image
  20. from mindspore import log as logger
  21. from ..._checkparam import Validator
  22. from ..anf_ir_pb2 import DataType, ModelProto
  23. from ..summary_pb2 import Event
  24. # define the MindSpore image format
  25. MS_IMAGE_TENSOR_FORMAT = 'NCHW'
  26. # Set the Event mark
  27. EVENT_FILE_NAME_MARK = ".out.events.summary."
  28. # Set the init event of version and mark
  29. EVENT_FILE_INIT_VERSION_MARK = "MindSpore.Event:"
  30. EVENT_FILE_INIT_VERSION = 1
  31. F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max
  32. def get_event_file_name(prefix, suffix):
  33. """
  34. Create file name: file_prefix + EVENT_FILE_NAME_MARK + time(seconds) + "." + Hostname + file_suffix.
  35. Args:
  36. prefix (str): The prefix of file name.
  37. suffix (str): The suffix of file name.
  38. Returns:
  39. String, the name of event log file.
  40. """
  41. Validator.check_str_by_regular(prefix)
  42. Validator.check_str_by_regular(suffix)
  43. file_name = ""
  44. time_second = str(int(time.time()))
  45. hostname = platform.node()
  46. if prefix is not None:
  47. file_name = file_name + prefix
  48. file_name = file_name + EVENT_FILE_NAME_MARK + time_second + "." + hostname
  49. if suffix is not None:
  50. file_name = file_name + suffix
  51. return file_name
  52. def package_init_event():
  53. """Package the summary init event."""
  54. init_event = Event()
  55. init_event.wall_time = time.time()
  56. version = EVENT_FILE_INIT_VERSION_MARK + str(EVENT_FILE_INIT_VERSION)
  57. init_event.version = version
  58. return init_event
  59. def package_graph_event(data):
  60. """
  61. Package the summary graph event.
  62. Args:
  63. data (Bytes): Graph bytes string.
  64. Retruns:
  65. Event, event log object.
  66. """
  67. graph_event = Event()
  68. graph_event.wall_time = time.time()
  69. modelp = ModelProto()
  70. modelp.ParseFromString(data)
  71. graph_event.graph_def.CopyFrom(modelp.graph)
  72. return graph_event
  73. def package_summary_event(data_list, step, wall_time):
  74. """
  75. Package the summary to event protobuffer.
  76. Args:
  77. data_id (Number): Summary data id.
  78. step (Number): The recode step index.
  79. Returns:
  80. Summary, the summary event.
  81. """
  82. # create the event of summary
  83. summary_event = Event()
  84. summary = summary_event.summary
  85. summary_event.wall_time = wall_time
  86. summary_event.step = int(step)
  87. for value in data_list:
  88. summary_type = value["_type"]
  89. data = value["data"]
  90. tag = value["name"]
  91. logger.debug(f"Now process {summary_type} summary, tag = {tag}")
  92. summary_value = summary.value.add()
  93. summary_value.tag = tag
  94. # get the summary type and parse the tag
  95. if summary_type == 'Scalar':
  96. if not _fill_scalar_summary(tag, data, summary_value):
  97. del summary.value[-1]
  98. elif summary_type == 'Tensor':
  99. _fill_tensor_summary(tag, data, summary_value.tensor)
  100. elif summary_type == 'Image':
  101. if not _fill_image_summary(tag, data, summary_value.image, MS_IMAGE_TENSOR_FORMAT):
  102. del summary.value[-1]
  103. elif summary_type == 'Histogram':
  104. _fill_histogram_summary(tag, data, summary_value.histogram)
  105. else:
  106. # The data is invalid ,jump the data
  107. logger.error(f"Summary type({summary_type}) is error, tag = {tag}")
  108. del summary.value[-1]
  109. return summary_event
  110. def _nptype_to_prototype(np_value):
  111. """
  112. Transform the np type to proto type.
  113. Args:
  114. np_value (Type): Numpy data type.
  115. Returns:
  116. Type, proto data type.
  117. """
  118. np2pt_tbl = {
  119. np.bool_: 'DT_BOOL',
  120. np.int8: 'DT_INT8',
  121. np.int16: 'DT_INT16',
  122. np.int32: 'DT_INT32',
  123. np.int64: 'DT_INT64',
  124. np.uint8: 'DT_UINT8',
  125. np.uint16: 'DT_UINT16',
  126. np.uint32: 'DT_UINT32',
  127. np.uint64: 'DT_UINT64',
  128. np.float16: 'DT_FLOAT16',
  129. np.float: 'DT_FLOAT64',
  130. np.float32: 'DT_FLOAT32',
  131. np.float64: 'DT_FLOAT64',
  132. None: 'DT_UNDEFINED'
  133. }
  134. np_type = None
  135. if np_value is None:
  136. logger.error("The numpy value is none")
  137. else:
  138. np_type = np_value.dtype.type
  139. proto = np2pt_tbl.get(np_type, None)
  140. if proto is None:
  141. raise TypeError("No match for proto data type.")
  142. return proto
  143. def _fill_scalar_summary(tag: str, np_value, summary):
  144. """
  145. Package the scalar summary.
  146. Args:
  147. tag (str): Summary tag describe.
  148. np_value (Object): Scalary object.
  149. Returns:
  150. Summary, return scalar summary content.
  151. """
  152. logger.debug(f"Set({tag}) the scalar summary value")
  153. if np_value.size == 1:
  154. # is scalar
  155. summary.scalar_value = np_value.item()
  156. return True
  157. if np_value.size > 1:
  158. logger.warning(
  159. f"The tensor is not a single scalar, tag = {tag}, ndim = {np_value.ndim}, shape = {np_value.shape}")
  160. summary.scalar_value = next(np_value.flat).item()
  161. return True
  162. logger.error(f"There no values inside tensor, tag = {tag}, size = {np_value.size}")
  163. return False
  164. def _fill_tensor_summary(tag: str, np_value, summary_tensor):
  165. """
  166. Package the tensor summary.
  167. Args:
  168. tag (str): Summary tag describe.
  169. np_value (Type): Summary data type.
  170. summary_tensor (Tensor): The tensor of summary.
  171. Retruns:
  172. Summary, return tensor summary content.
  173. """
  174. logger.debug(f"Set({tag}) the tensor summary value")
  175. # get tensor dtype
  176. tensor_dtype = _nptype_to_prototype(np_value)
  177. summary_tensor.data_type = DataType.Value(tensor_dtype)
  178. # get the value list
  179. tensor_value_list = np_value.reshape(-1).tolist()
  180. summary_tensor.float_data.extend(tensor_value_list)
  181. # get the tensor dim
  182. for v in np_value.shape:
  183. summary_tensor.dims.append(v)
  184. return summary_tensor
  185. def _calc_histogram_bins(count):
  186. """
  187. Calculates experience-based optimal bins number for histogram.
  188. There should be enough number in each bin. So we calc bin numbers according to count. For very small count(1 -
  189. 10), we assign carefully chosen number. For large count, we tried to make sure there are 9-10 numbers in each
  190. bucket on average. Too many bins will slow down performance, so we set max number of bins to 90.
  191. Args:
  192. count (int): Valid number count for the tensor.
  193. Returns:
  194. int, number of histogram bins.
  195. """
  196. max_bins, max_per_bin = 90, 10
  197. if not count:
  198. return 1
  199. if count <= 5:
  200. return 2
  201. if count <= 10:
  202. return 3
  203. if count <= 880:
  204. # note that math.ceil(881/10) + 1 equals 90
  205. return count // max_per_bin + 1
  206. return max_bins
  207. def _fill_histogram_summary(tag: str, np_value: np.ndarray, summary) -> None:
  208. """
  209. Package the histogram summary.
  210. Args:
  211. tag (str): Summary tag describe.
  212. np_value (np.ndarray): Summary data.
  213. summary (summary_pb2.Summary.Histogram): Summary histogram data.
  214. """
  215. logger.debug(f"Set({tag}) the histogram summary value")
  216. # Default bucket for tensor with no valid data.
  217. ma_value = np.ma.masked_invalid(np_value)
  218. total, valid = np_value.size, ma_value.count()
  219. invalids = []
  220. for isfn in np.isnan, np.isposinf, np.isneginf:
  221. if total - valid > sum(invalids):
  222. count = np.count_nonzero(isfn(np_value))
  223. invalids.append(count)
  224. else:
  225. invalids.append(0)
  226. summary.count = total
  227. summary.nan_count, summary.pos_inf_count, summary.neg_inf_count = invalids
  228. if not valid:
  229. logger.warning(f'There are no valid values in the ndarray(size={total}, shape={np_value.shape})')
  230. # summary.{min, max, sum} are 0s by default, no need to explicitly set
  231. else:
  232. # BUG: max of a masked array with dtype np.float16 returns inf
  233. # See numpy issue#15077
  234. if issubclass(np_value.dtype.type, np.floating):
  235. summary.min = ma_value.min(fill_value=np.PINF)
  236. summary.max = ma_value.max(fill_value=np.NINF)
  237. if summary.min < F32_MIN or summary.max > F32_MAX:
  238. logger.warning(f'Values({summary.min}, {summary.max}) are too large, '
  239. f'you may encounter some undefined behaviours hereafter.')
  240. else:
  241. summary.min = ma_value.min()
  242. summary.max = ma_value.max()
  243. summary.sum = ma_value.sum(dtype=np.float64)
  244. bins = _calc_histogram_bins(valid)
  245. first_edge, last_edge = summary.min, summary.max
  246. if not first_edge < last_edge:
  247. first_edge -= 0.5
  248. last_edge += 0.5
  249. bins = np.linspace(first_edge, last_edge, bins + 1, dtype=np_value.dtype)
  250. hists, edges = np.histogram(np_value, bins=bins)
  251. for hist, edge1, edge2 in zip(hists, edges, edges[1:]):
  252. bucket = summary.buckets.add()
  253. bucket.width = edge2 - edge1
  254. bucket.count = hist
  255. bucket.left = edge1
  256. def _fill_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
  257. """
  258. Package the image summary.
  259. Args:
  260. tag (str): Summary tag describe.
  261. np_value (Type): Summary data type.
  262. summary_image (Tensor): The tensor of summary.
  263. input_format (str): Data sort order index. Default: 'NCHW'.
  264. Returns:
  265. Summary, return image summary content.
  266. """
  267. logger.debug(f"Set({tag}) the image summary value")
  268. if np_value.ndim != 4 or np_value.shape[1] not in (1, 3):
  269. logger.error(f"The value is not Image, tag = {tag}, ndim = {np_value.ndim}, shape={np_value.shape}")
  270. return False
  271. if np_value.ndim != len(input_format):
  272. logger.error(
  273. f"The tensor with dim({np_value.ndim}) can't convert the format({input_format}) because dim not same")
  274. return False
  275. # convert the tensor format
  276. tensor = _convert_image_format(np_value, input_format)
  277. # convert the tensor dtype
  278. # Do not assume that user passes in values in [0, 255], use data type to detect
  279. scale_factor = 1
  280. if tensor.dtype == np.uint8:
  281. scale_factor = 1
  282. elif np.max(tensor) <= 1 and np.min(tensor) >= 0:
  283. scale_factor = 255
  284. tensor = tensor.astype(np.float32)
  285. tensor = (tensor * scale_factor).astype(np.uint8)
  286. # create the image summary
  287. height, width, channel, image_string = _make_image(tensor)
  288. summary_image.height = height
  289. summary_image.width = width
  290. summary_image.colorspace = channel
  291. summary_image.encoded_image = image_string
  292. return True
  293. def _make_image(tensor, rescale=1):
  294. """
  295. Convert a numpy representation of an image to Image protobuf.
  296. Args:
  297. tensor (Tensor): The image data.
  298. rescale (Number): The rescale value. Default: 1.
  299. Returns:
  300. (Number, Number, Number, Bytes), return the height, width, channel, image string .
  301. """
  302. height, width, channel = tensor.shape
  303. scaled_height = int(height * rescale)
  304. scaled_width = int(width * rescale)
  305. image = Image.fromarray(tensor)
  306. image = image.resize((scaled_width, scaled_height), Image.ANTIALIAS)
  307. import io
  308. output = io.BytesIO()
  309. image.save(output, format='PNG')
  310. image_string = output.getvalue()
  311. output.close()
  312. return height, width, channel, image_string
  313. def _convert_image_format(np_tensor, input_format, out_format='HWC'):
  314. """
  315. Convert the image format.
  316. Args:
  317. np_tensor (Tensor): The image data.
  318. input_format (str): Input data format.
  319. out_format (str): The output data format. Default: 'HWC'.
  320. Returns:
  321. Tensor, return format image.
  322. """
  323. input_format = input_format.upper()
  324. # convert the NCHW
  325. if input_format != 'NCHW':
  326. index = [input_format.find(c) for c in 'NCHW']
  327. tensor_nchw = np_tensor.transpose(index)
  328. else:
  329. tensor_nchw = np_tensor
  330. # make grid to expand N
  331. tensor_chw = _make_canvas_for_imgs(tensor_nchw)
  332. # convert to out format
  333. out_index = ['CHW'.find(c) for c in out_format]
  334. out_tensor = tensor_chw.transpose(out_index)
  335. return out_tensor
  336. def _make_canvas_for_imgs(tensor, col_imgs=8):
  337. """
  338. Expand the N, show imgs on a canvs.
  339. Args:
  340. tensor (Tensor): The canvas value.
  341. col_imgs (Number): The image colume number. Default: 8.
  342. Returns:
  343. Tensor, retrun canvas of image.
  344. """
  345. # expand the N1HW to N3HW
  346. if tensor.shape[1] == 1:
  347. tensor = np.concatenate([tensor, tensor, tensor], 1)
  348. # expand the N
  349. n = tensor.shape[0]
  350. h = tensor.shape[2]
  351. w = tensor.shape[3]
  352. cols = min(n, col_imgs)
  353. rows = int(np.ceil(float(n) / cols))
  354. # creat the canvas: expand the n
  355. out_canvas = np.zeros((3, h * rows, w * cols))
  356. i = 0
  357. for y in range(rows):
  358. for x in range(cols):
  359. if i >= n:
  360. break
  361. out_canvas[:, y * h:(y + 1) * h, x * w:(x + 1) * w] = tensor[i]
  362. i = i + 1
  363. return out_canvas