You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

_summary_adapter.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate the summary event which conform to proto format."""
  16. import time
  17. import socket
  18. from enum import Enum, unique
  19. import numpy as np
  20. from PIL import Image
  21. from mindspore import log as logger
  22. from ..summary_pb2 import Event
  23. from ..anf_ir_pb2 import ModelProto, DataType
  24. from ..._checkparam import _check_str_by_regular
# Layout of the image tensors handed to the image summary (passed to _get_image_summary).
MS_IMAGE_TENSOR_FORMAT = 'NCHW'
# Marker placed between the file prefix and the timestamp in event file names.
EVENT_FILE_NAME_MARK = ".out.events.summary."
# The init event version string is EVENT_FILE_INIT_VERSION_MARK + str(EVENT_FILE_INIT_VERSION).
EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:"
EVENT_FILE_INIT_VERSION = 1
# Cache of the summary data, keyed by data id:
# {id: SummaryData}
#   |---[{"name": tag_name, "data": numpy, "type": SummaryType}, ...]
# The "name", "data" and "type" keys are the ones read by package_summary_event.
g_summary_data_dict = {}
  36. def save_summary_data(data_id, data):
  37. """Save the global summary cache."""
  38. global g_summary_data_dict
  39. g_summary_data_dict[data_id] = data
  40. def del_summary_data(data_id):
  41. """Save the global summary cache."""
  42. global g_summary_data_dict
  43. if data_id in g_summary_data_dict:
  44. del g_summary_data_dict[data_id]
  45. else:
  46. logger.warning("Can't del the data because data_id(%r) "
  47. "does not have data in g_summary_data_dict", data_id)
  48. def get_summary_data(data_id):
  49. """Save the global summary cache."""
  50. ret = None
  51. global g_summary_data_dict
  52. if data_id in g_summary_data_dict:
  53. ret = g_summary_data_dict.get(data_id)
  54. else:
  55. logger.warning("The data_id(%r) does not have data in g_summary_data_dict", data_id)
  56. return ret
  57. @unique
  58. class SummaryType(Enum):
  59. """
  60. Summary type.
  61. Args:
  62. SCALAR (Number): Summary Scalar enum.
  63. TENSOR (Number): Summary TENSOR enum.
  64. IMAGE (Number): Summary image enum.
  65. GRAPH (Number): Summary graph enum.
  66. HISTOGRAM (Number): Summary histogram enum.
  67. INVALID (Number): Unknow type.
  68. """
  69. SCALAR = 1 # Scalar summary
  70. TENSOR = 2 # Tensor summary
  71. IMAGE = 3 # Image summary
  72. GRAPH = 4 # graph
  73. HISTOGRAM = 5 # Histogram Summary
  74. INVALID = 0xFF # unknow type
  75. def get_event_file_name(prefix, suffix):
  76. """
  77. Create file name: file_prefix + EVENT_FILE_NAME_MARK + time(seconds) + "." + Hostname + file_suffix.
  78. Args:
  79. prefix (str): The prefix of file name.
  80. suffix (str): The suffix of file name.
  81. Returns:
  82. String, the name of event log file.
  83. """
  84. _check_str_by_regular(prefix)
  85. _check_str_by_regular(suffix)
  86. file_name = ""
  87. time_second = str(int(time.time()))
  88. hostname = socket.gethostname()
  89. if prefix is not None:
  90. file_name = file_name + prefix
  91. file_name = file_name + EVENT_FILE_NAME_MARK + time_second + "." + hostname
  92. if suffix is not None:
  93. file_name = file_name + suffix
  94. return file_name
  95. def package_init_event():
  96. """Package the summary init event."""
  97. init_event = Event()
  98. init_event.wall_time = time.time()
  99. version = EVENT_FILE_INIT_VERSION_MARK + str(EVENT_FILE_INIT_VERSION)
  100. init_event.version = version
  101. return init_event
  102. def package_graph_event(data):
  103. """
  104. Package the summary graph event.
  105. Args:
  106. data (Bytes): Graph bytes string.
  107. Retruns:
  108. Event, event log object.
  109. """
  110. graph_event = Event()
  111. graph_event.wall_time = time.time()
  112. modelp = ModelProto()
  113. modelp.ParseFromString(data)
  114. graph_event.graph_def.CopyFrom(modelp.graph)
  115. return graph_event
  116. def package_summary_event(data_id, step):
  117. """
  118. Package the summary to event protobuffer.
  119. Args:
  120. data_id (Number): Summary data id.
  121. step (Number): The recode step index.
  122. Returns:
  123. Summary, the summary event.
  124. """
  125. data_list = get_summary_data(data_id)
  126. if data_list is None:
  127. logger.error("The step(%r) does not have record data.", step)
  128. del_summary_data(data_id)
  129. # create the event of summary
  130. summary_event = Event()
  131. summary = summary_event.summary
  132. for value in data_list:
  133. tag = value["name"]
  134. data = value["data"]
  135. summary_type = value["type"]
  136. # get the summary type and parse the tag
  137. if summary_type is SummaryType.SCALAR:
  138. logger.debug("Now process Scalar summary, tag = %r", tag)
  139. summary_value = summary.value.add()
  140. summary_value.tag = tag
  141. summary_value.scalar_value = _get_scalar_summary(tag, data)
  142. elif summary_type is SummaryType.TENSOR:
  143. logger.debug("Now process Tensor summary, tag = %r", tag)
  144. summary_value = summary.value.add()
  145. summary_value.tag = tag
  146. summary_tensor = summary_value.tensor
  147. _get_tensor_summary(tag, data, summary_tensor)
  148. elif summary_type is SummaryType.IMAGE:
  149. logger.debug("Now process Image summary, tag = %r", tag)
  150. summary_value = summary.value.add()
  151. summary_value.tag = tag
  152. summary_image = summary_value.image
  153. _get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT)
  154. elif summary_type is SummaryType.HISTOGRAM:
  155. logger.debug("Now process Histogram summary, tag = %r", tag)
  156. summary_value = summary.value.add()
  157. summary_value.tag = tag
  158. summary_histogram = summary_value.histogram
  159. _fill_histogram_summary(tag, data, summary_histogram)
  160. else:
  161. # The data is invalid ,jump the data
  162. logger.error("Summary type is error, tag = %r", tag)
  163. continue
  164. summary_event.wall_time = time.time()
  165. summary_event.step = int(step)
  166. return summary_event
  167. def _nptype_to_prototype(np_value):
  168. """
  169. Transform the np type to proto type.
  170. Args:
  171. np_value (Type): Numpy data type.
  172. Returns:
  173. Type, proto data type.
  174. """
  175. np2pt_tbl = {
  176. np.bool_: 'DT_BOOL',
  177. np.int8: 'DT_INT8',
  178. np.int16: 'DT_INT16',
  179. np.int32: 'DT_INT32',
  180. np.int64: 'DT_INT64',
  181. np.uint8: 'DT_UINT8',
  182. np.uint16: 'DT_UINT16',
  183. np.uint32: 'DT_UINT32',
  184. np.uint64: 'DT_UINT64',
  185. np.float16: 'DT_FLOAT16',
  186. np.float: 'DT_FLOAT64',
  187. np.float32: 'DT_FLOAT32',
  188. np.float64: 'DT_FLOAT64',
  189. None: 'DT_UNDEFINED'
  190. }
  191. np_type = None
  192. if np_value is None:
  193. logger.error("The numpy value is none")
  194. else:
  195. np_type = np_value.dtype.type
  196. proto = np2pt_tbl.get(np_type, None)
  197. if proto is None:
  198. raise TypeError("No match for proto data type.")
  199. return proto
  200. def _get_scalar_summary(tag: str, np_value):
  201. """
  202. Package the scalar summary.
  203. Args:
  204. tag (str): Summary tag describe.
  205. np_value (Object): Scalary object.
  206. Returns:
  207. Summary, return scalar summary content.
  208. """
  209. logger.debug("Set(%r) the scalar summary value", tag)
  210. if np_value.ndim == 0:
  211. # is scalar
  212. scalar_value = np_value.item()
  213. elif np_value.ndim == 1:
  214. # Because now GE can't providesumm the real shape info to convert the Tensor
  215. # So consider the dim = 1, shape = (1,) tensor is scalar
  216. scalar_value = np_value[0]
  217. if np_value.shape != (1,):
  218. logger.error("The tensor is not Scalar, tag = %r, Value = %r", tag, np_value)
  219. else:
  220. np_list = np_value.reshape(-1).tolist()
  221. scalar_value = np_list[0]
  222. logger.error("The value is not Scalar, tag = %r, Value = %r", tag, np_value)
  223. logger.debug("The tag(%r) value is: %r", tag, scalar_value)
  224. return scalar_value
  225. def _get_tensor_summary(tag: str, np_value, summary_tensor):
  226. """
  227. Package the tensor summary.
  228. Args:
  229. tag (str): Summary tag describe.
  230. np_value (Type): Summary data type.
  231. summary_tensor (Tensor): The tensor of summary.
  232. Retruns:
  233. Summary, return tensor summary content.
  234. """
  235. logger.debug("Set(%r) the tensor summary value", tag)
  236. # get tensor dtype
  237. tensor_dtype = _nptype_to_prototype(np_value)
  238. summary_tensor.data_type = DataType.Value(tensor_dtype)
  239. # get the value list
  240. tensor_value_list = np_value.reshape(-1).tolist()
  241. summary_tensor.float_data.extend(tensor_value_list)
  242. # get the tensor dim
  243. for v in np_value.shape:
  244. summary_tensor.dims.append(v)
  245. return summary_tensor
  246. def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None:
  247. """
  248. Package the histogram summary.
  249. Args:
  250. tag (str): Summary tag describe.
  251. np_value (np.array): Summary data.
  252. summary_histogram (summary_pb2.Summary.Histogram): Summary histogram data.
  253. """
  254. logger.debug("Set(%r) the histogram summary value", tag)
  255. # Default bucket for tensor with no valid data.
  256. default_bucket_left = -0.5
  257. default_bucket_width = 1.0
  258. if np_value.size == 0:
  259. bucket = summary_histogram.buckets.add()
  260. bucket.left = default_bucket_left
  261. bucket.width = default_bucket_width
  262. bucket.count = 0
  263. summary_histogram.nan_count = 0
  264. summary_histogram.pos_inf_count = 0
  265. summary_histogram.neg_inf_count = 0
  266. summary_histogram.max = 0
  267. summary_histogram.min = 0
  268. summary_histogram.sum = 0
  269. summary_histogram.count = 0
  270. return
  271. summary_histogram.nan_count = np.count_nonzero(np.isnan(np_value))
  272. summary_histogram.pos_inf_count = np.count_nonzero(np.isposinf(np_value))
  273. summary_histogram.neg_inf_count = np.count_nonzero(np.isneginf(np_value))
  274. summary_histogram.count = np_value.size
  275. masked_value = np.ma.masked_invalid(np_value)
  276. tensor_max = masked_value.max()
  277. tensor_min = masked_value.min()
  278. tensor_sum = masked_value.sum()
  279. # No valid value in tensor.
  280. if tensor_max is np.ma.masked:
  281. bucket = summary_histogram.buckets.add()
  282. bucket.left = default_bucket_left
  283. bucket.width = default_bucket_width
  284. bucket.count = 0
  285. summary_histogram.max = np.nan
  286. summary_histogram.min = np.nan
  287. summary_histogram.sum = 0
  288. return
  289. counts, edges = np.histogram(np_value, bins='auto', range=(tensor_min, tensor_max))
  290. for ind, count in enumerate(counts):
  291. bucket = summary_histogram.buckets.add()
  292. bucket.left = edges[ind]
  293. bucket.width = edges[ind + 1] - edges[ind]
  294. bucket.count = count
  295. summary_histogram.max = tensor_max
  296. summary_histogram.min = tensor_min
  297. summary_histogram.sum = tensor_sum
  298. def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
  299. """
  300. Package the image summary.
  301. Args:
  302. tag (str): Summary tag describe.
  303. np_value (Type): Summary data type.
  304. summary_image (Tensor): The tensor of summary.
  305. input_format (str): Data sort order index. Default: 'NCHW'.
  306. Returns:
  307. Summary, return image summary content.
  308. """
  309. logger.debug("Set(%r) the image summary value", tag)
  310. if np_value.ndim != 4:
  311. logger.error("The value is not Image, tag = %r, Value = %r", tag, np_value)
  312. # convert the tensor format
  313. tensor = _convert_image_format(np_value, input_format)
  314. # convert the tensor dtype
  315. # Do not assume that user passes in values in [0, 255], use data type to detect
  316. scale_factor = 1
  317. if tensor.dtype == np.uint8:
  318. scale_factor = 1
  319. elif np.max(tensor) <= 1 and np.min(tensor) >= 0:
  320. scale_factor = 255
  321. tensor = tensor.astype(np.float32)
  322. tensor = (tensor * scale_factor).astype(np.uint8)
  323. # create the image summary
  324. height, width, channel, image_string = _make_image(tensor)
  325. summary_image.height = height
  326. summary_image.width = width
  327. summary_image.colorspace = channel
  328. summary_image.encoded_image = image_string
  329. return summary_image
  330. def _make_image(tensor, rescale=1):
  331. """
  332. Convert a numpy representation of an image to Image protobuf.
  333. Args:
  334. tensor (Tensor): The image data.
  335. rescale (Number): The rescale value. Default: 1.
  336. Returns:
  337. (Number, Number, Number, Bytes), return the height, width, channel, image string .
  338. """
  339. height, width, channel = tensor.shape
  340. scaled_height = int(height * rescale)
  341. scaled_width = int(width * rescale)
  342. image = Image.fromarray(tensor)
  343. image = image.resize((scaled_width, scaled_height), Image.ANTIALIAS)
  344. import io
  345. output = io.BytesIO()
  346. image.save(output, format='PNG')
  347. image_string = output.getvalue()
  348. output.close()
  349. return height, width, channel, image_string
  350. def _convert_image_format(np_tensor, input_format, out_format='HWC'):
  351. """
  352. Convert the image format.
  353. Args:
  354. np_tensor (Tensor): The image data.
  355. input_format (str): Input data format.
  356. out_format (str): The output data format. Default: 'HWC'.
  357. Returns:
  358. Tensor, return format image.
  359. """
  360. out_tensor = None
  361. if np_tensor.ndim != len(input_format):
  362. logger.error("The tensor(%r) can't convert the format(%r) because dim not same",
  363. np_tensor, input_format)
  364. return out_tensor
  365. input_format = input_format.upper()
  366. if len(input_format) == 4:
  367. # convert the NCHW
  368. if input_format != 'NCHW':
  369. index = [input_format.find(c) for c in 'NCHW']
  370. tensor_nchw = np_tensor.transpose(index)
  371. else:
  372. tensor_nchw = np_tensor
  373. # make grid to expand N
  374. tensor_chw = _make_canvas_for_imgs(tensor_nchw)
  375. # convert to out format
  376. out_index = ['CHW'.find(c) for c in out_format]
  377. out_tensor = tensor_chw.transpose(out_index)
  378. else:
  379. logger.error("Don't support the format(%r) convert", input_format)
  380. return out_tensor
  381. def _make_canvas_for_imgs(tensor, col_imgs=8):
  382. """
  383. Expand the N, show imgs on a canvs.
  384. Args:
  385. tensor (Tensor): The canvas value.
  386. col_imgs (Number): The image colume number. Default: 8.
  387. Returns:
  388. Tensor, retrun canvas of image.
  389. """
  390. # expand the N1HW to N3HW
  391. out_canvas = None
  392. if tensor.shape[1] == 1:
  393. tensor = np.concatenate([tensor, tensor, tensor], 1)
  394. # check the tensor format
  395. if tensor.ndim != 4 or tensor.shape[1] != 3:
  396. logger.error("The image tensor(%r) is not 'NCHW' format", tensor)
  397. return out_canvas
  398. # expand the N
  399. n = tensor.shape[0]
  400. h = tensor.shape[2]
  401. w = tensor.shape[3]
  402. cols = min(n, col_imgs)
  403. rows = int(np.ceil(float(n) / cols))
  404. # creat the canvas: expand the n
  405. out_canvas = np.zeros((3, h * rows, w * cols))
  406. i = 0
  407. for y in range(rows):
  408. for x in range(cols):
  409. if i >= n:
  410. break
  411. out_canvas[:, y * h:(y + 1) * h, x * w:(x + 1) * w] = tensor[i]
  412. i = i + 1
  413. return out_canvas