You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_summary_collector.py 43 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Summary collector callback."""
  16. import os
  17. import re
  18. import json
  19. from json.decoder import JSONDecodeError
  20. from importlib import import_module
  21. import numpy as np
  22. from mindspore import log as logger
  23. from mindspore import context
  24. from mindspore.common.tensor import Tensor
  25. from mindspore.common.parameter import Parameter
  26. from mindspore.train.summary.summary_record import SummaryRecord, process_export_options
  27. from mindspore.train.summary.enums import PluginEnum, ModeEnum
  28. from mindspore.train.callback import Callback, ModelCheckpoint
  29. from mindspore.train import lineage_pb2
  30. from mindspore.train.callback._dataset_graph import DatasetGraph
  31. from mindspore.nn.optim.optimizer import Optimizer
  32. from mindspore.nn.loss.loss import _Loss
  33. from mindspore.train._utils import check_value_type
# Name of the environment variable through which mindoptimizer passes its
# hyper config as a JSON string (read in _collect_optimizer_custom_lineage_data).
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
# Maximum accepted length of the hyper config string; longer values are ignored.
HYPER_CONFIG_LEN_LIMIT = 100000
class LineageMetadata:
    """Initialize parameters used in model lineage management."""
    # Each attribute is the literal field name used when packaging lineage
    # data for MindInsight's lineage page (see _collect_train_lineage).
    train_dataset_path = 'train_dataset_path'
    valid_dataset_path = 'valid_dataset_path'
    train_network = 'train_network'
    loss_function = 'loss_function'
    loss = 'loss'
    optimizer = 'optimizer'
    learning_rate = 'learning_rate'
    epoch = 'epoch'
    step_num = 'step_num'
    parallel_mode = 'parallel_mode'
    device_num = 'device_num'
    batch_size = 'batch_size'
    model_path = 'model_path'
    model_ckpt = 'model_ckpt'
    model_size = 'model_size'
    metrics = 'metrics'
    train_dataset_size = 'train_dataset_size'
    valid_dataset_size = 'valid_dataset_size'
  56. class SummaryCollector(Callback):
  57. """
  58. SummaryCollector can help you to collect some common information.
  59. It can help you to collect loss, learning late, computational graph and so on.
  60. SummaryCollector also enables the summary operator to collect data from a summary file.
  61. Note:
  62. 1. Multiple SummaryCollector instances in callback list are not allowed.
  63. 2. Not all information is collected at the training phase or at the eval phase.
  64. 3. SummaryCollector always record the data collected by the summary operator.
  65. 4. SummaryCollector only supports Linux systems.
  66. Args:
  67. summary_dir (str): The collected data will be persisted to this directory.
  68. If the directory does not exist, it will be created automatically.
  69. collect_freq (int): Set the frequency of data collection, it should be greater then zero,
  70. and the unit is `step`. Default: 10. If a frequency is set, we will collect data
  71. when (current steps % freq) equals to 0, and the first step will be collected at any time.
  72. It is important to note that if the data sink mode is used, the unit will become the `epoch`.
  73. It is not recommended to collect data too frequently, which can affect performance.
  74. collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
  75. By default, if set to None, all data is collected as the default behavior.
  76. You can customize the collected data with a dictionary.
  77. For example, you can set {'collect_metric': False} to control not collecting metrics.
  78. The data that supports control is shown below.
  79. - collect_metric (bool): Whether to collect training metrics, currently only the loss is collected.
  80. The first output will be treated as the loss and it will be averaged.
  81. Optional: True/False. Default: True.
  82. - collect_graph (bool): Whether to collect the computational graph. Currently, only
  83. training computational graph is collected. Optional: True/False. Default: True.
  84. - collect_train_lineage (bool): Whether to collect lineage data for the training phase,
  85. this field will be displayed on the lineage page of Mindinsight. Optional: True/False. Default: True.
  86. - collect_eval_lineage (bool): Whether to collect lineage data for the evaluation phase,
  87. this field will be displayed on the lineage page of Mindinsight. Optional: True/False. Default: True.
  88. - collect_input_data (bool): Whether to collect dataset for each training.
  89. Currently only image data is supported.
  90. If there are multiple columns of data in the dataset, the first column should be image data.
  91. Optional: True/False. Default: True.
  92. - collect_dataset_graph (bool): Whether to collect dataset graph for the training phase.
  93. Optional: True/False. Default: True.
  94. - histogram_regular (Union[str, None]): Collect weight and bias for parameter distribution page
  95. and displayed in MindInsight. This field allows regular strings to control which parameters to collect.
  96. Default: None, it means only the first five parameters are collected.
  97. It is not recommended to collect too many parameters at once, as it can affect performance.
  98. Note that if you collect too many parameters and run out of memory, the training will fail.
  99. keep_default_action (bool): This field affects the collection behavior of the 'collect_specified_data' field.
  100. Optional: True/False, Default: True.
  101. True: it means that after specified data is set, non-specified data is collected as the default behavior.
  102. False: it means that after specified data is set, only the specified data is collected,
  103. and the others are not collected.
  104. custom_lineage_data (Union[dict, None]): Allows you to customize the data and present it on the MingInsight
  105. lineage page. In the custom data, the type of the key supports str, and the type of value supports str, int
  106. and float. Default: None, it means there is no custom data.
  107. collect_tensor_freq (Optional[int]): The same semantics as the `collect_freq`, but controls TensorSummary only.
  108. Because TensorSummary data is too large to be compared with other summary data, this parameter is used to
  109. reduce its collection. By default, The maximum number of steps for collecting TensorSummary data is 20,
  110. but it will not exceed the number of steps for collecting other summary data.
  111. Default: None, which means to follow the behavior as described above. For example, given `collect_freq=10`,
  112. when the total steps is 600, TensorSummary will be collected 20 steps, while other summary data 61 steps,
  113. but when the total steps is 20, both TensorSummary and other summary will be collected 3 steps.
  114. Also note that when in parallel mode, the total steps will be split evenly, which will
  115. affect the number of steps TensorSummary will be collected.
  116. max_file_size (Optional[int]): The maximum size in bytes of each file that can be written to the disk.
  117. Default: None, which means no limit. For example, to write not larger than 4GB,
  118. specify `max_file_size=4 * 1024**3`.
  119. export_options (Union[None, dict]): Perform custom operations on the export data.
  120. Default: None, it means that the data is not exported.
  121. Note that the size of export files is not limited by the max_file_size.
  122. You can customize the export data with a dictionary. For example, you can set {'tensor_format': 'npy'}
  123. to export tensor as npy file. The data that supports control is shown below.
  124. - tensor_format (Union[str, None]): Customize the export tensor format. Supports ["npy", None].
  125. Default: None, it means that the tensor is not exported.
  126. - npy: export tensor as npy file.
  127. Raises:
  128. ValueError: If the parameter value is not expected.
  129. TypeError: If the parameter type is not expected.
  130. RuntimeError: If an error occurs during data collection.
  131. Examples:
  132. >>> # Simple usage:
  133. >>> from mindspore.train import Model
  134. >>> summary_collector = SummaryCollector(summary_dir='./summary_dir')
  135. >>> dataset = get_dataset('/path/to/MNIST')
  136. >>> network = LeNet5()
  137. >>> model = Model(network)
  138. >>> model.train(epoch=1, dataset=dataset, callbacks=summary_collector)
  139. >>>
  140. >>> # Do not collect metric and collect the first layer parameter, others are collected by default
  141. >>> specified={'collect_metric': False, 'histogram_regular': '^conv1.*'}
  142. >>> summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_specified_data=specified)
  143. >>> model.train(epoch=1, dataset=dataset, callbacks=summary_collector)
  144. >>>
  145. >>> # Only collect metric, custom lineage data and record data that collected by the summary operator,
  146. >>> # others are not collected
  147. >>> specified = {'collect_metric': True}
  148. >>> summary_collector = SummaryCollector('./summary_dir',
  149. >>> collect_specified_data=specified,
  150. >>> keep_default_action=False,
  151. >>> custom_lineage_data={'version': 'resnet50_v1'}
  152. >>> )
  153. >>> model.train(epoch=1, dataset=dataset, callbacks=summary_collector)
  154. """
  155. _DEFAULT_SPECIFIED_DATA = {
  156. 'collect_metric': True,
  157. 'collect_graph': True,
  158. 'collect_train_lineage': True,
  159. 'collect_eval_lineage': True,
  160. 'collect_input_data': True,
  161. 'collect_dataset_graph': True,
  162. 'histogram_regular': None
  163. }
    def __init__(self,
                 summary_dir,
                 collect_freq=10,
                 collect_specified_data=None,
                 keep_default_action=True,
                 custom_lineage_data=None,
                 collect_tensor_freq=None,
                 max_file_size=None,
                 export_options=None):
        """Validate all arguments and initialize collector state (see class docstring for argument semantics)."""
        super(SummaryCollector, self).__init__()
        self._summary_dir = self._process_summary_dir(summary_dir)
        # SummaryRecord instance; created lazily in __enter__.
        self._record = None
        self._check_positive('collect_freq', collect_freq)
        self._collect_freq = collect_freq
        self._check_positive('collect_tensor_freq', collect_tensor_freq, allow_none=True)
        self._collect_tensor_freq = collect_tensor_freq
        # Steps at which TensorSummary is collected; computed on the first step_end.
        self._tensor_collect_range = None
        self._check_positive('max_file_size', max_file_size, allow_none=True)
        self._max_file_size = max_file_size
        self._export_options = process_export_options(export_options)
        self._check_action(keep_default_action)
        self._collect_specified_data = self._process_specified_data(collect_specified_data, keep_default_action)
        msg = f"For 'collect_specified_data' the value after processing is: {self._collect_specified_data}."
        logger.info(msg)
        self._custom_lineage_data = self._process_custom_lineage_data(custom_lineage_data)
        # Cached optimizer lookup result (an Optimizer or the string 'Failed'); see _get_optimizer.
        self._temp_optimizer = None
        self._has_saved_graph = False
        self._has_saved_custom_data = False
        self._is_parse_loss_success = True
        self._first_step = True
        # Assumed True until the first step_end determines the real value.
        self._dataset_sink_mode = True
  195. def __enter__(self):
  196. self._record = SummaryRecord(log_dir=self._summary_dir,
  197. max_file_size=self._max_file_size,
  198. raise_exception=False,
  199. export_options=self._export_options)
  200. self._first_step, self._dataset_sink_mode = True, True
  201. return self
    def __exit__(self, *err):
        """Close the summary record on context exit; exceptions are not suppressed."""
        self._record.close()
  204. @staticmethod
  205. def _process_summary_dir(summary_dir):
  206. """Check the summary dir, and create a new directory if it not exists."""
  207. check_value_type('summary_dir', summary_dir, str)
  208. summary_dir = summary_dir.strip()
  209. if not summary_dir:
  210. raise ValueError('For `summary_dir` the value should be a valid string of path, but got empty string.')
  211. summary_dir = os.path.realpath(summary_dir)
  212. if not os.path.exists(summary_dir):
  213. os.makedirs(summary_dir, exist_ok=True)
  214. else:
  215. if not os.path.isdir(summary_dir):
  216. raise NotADirectoryError('For `summary_dir` it should be a directory path.')
  217. return summary_dir
  218. @staticmethod
  219. def _check_positive(name, value, allow_none=False):
  220. """Check if the value to be int type and positive."""
  221. if allow_none and value is None:
  222. return
  223. check_value_type(name, value, int)
  224. if value <= 0:
  225. raise ValueError(f'For `{name}` the value should be greater than 0, but got `{value}`.')
  226. def _process_custom_lineage_data(self, custom_lineage_data):
  227. """
  228. Check user custom lineage data.
  229. Args:
  230. custom_lineage_data (dict): The user custom defined data.
  231. Raises:
  232. TypeError: If the type of parameters is invalid.
  233. """
  234. if custom_lineage_data is None:
  235. custom_lineage_data = {}
  236. self._check_custom_lineage_type('custom_lineage_data', custom_lineage_data)
  237. auto_custom_lineage_data = self._collect_optimizer_custom_lineage_data()
  238. self._check_custom_lineage_type('auto_custom_lineage_data', auto_custom_lineage_data)
  239. # the priority of user defined info is higher than auto collected info
  240. auto_custom_lineage_data.update(custom_lineage_data)
  241. custom_lineage_data = auto_custom_lineage_data
  242. return custom_lineage_data
  243. def _check_custom_lineage_type(self, param_name, custom_lineage):
  244. """Check custom lineage type."""
  245. check_value_type(param_name, custom_lineage, [dict, type(None)])
  246. for key, value in custom_lineage.items():
  247. check_value_type(f'{param_name} -> {key}', key, str)
  248. check_value_type(f'the value of {param_name} -> {key}', value, (int, str, float))
  249. def _collect_optimizer_custom_lineage_data(self):
  250. """Collect custom lineage data if mindoptimizer has set the hyper config"""
  251. auto_custom_lineage_data = {}
  252. hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME)
  253. if hyper_config is None:
  254. logger.debug("Hyper config is not in system environment.")
  255. return auto_custom_lineage_data
  256. if len(hyper_config) > HYPER_CONFIG_LEN_LIMIT:
  257. logger.warning("Hyper config is too long. The length limit is %s, the length of "
  258. "hyper_config is %s." % (HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))
  259. return auto_custom_lineage_data
  260. try:
  261. hyper_config = json.loads(hyper_config)
  262. except (TypeError, JSONDecodeError) as exc:
  263. logger.warning("Hyper config decode error. Detail: %s." % str(exc))
  264. return auto_custom_lineage_data
  265. custom_lineage_data = hyper_config.get("custom_lineage_data")
  266. if custom_lineage_data is None:
  267. logger.info("No custom lineage data in hyper config. Please check the custom lineage data "
  268. "if custom parameters exist in the configuration file.")
  269. auto_custom_lineage_data = custom_lineage_data if custom_lineage_data is not None else {}
  270. return auto_custom_lineage_data
    @staticmethod
    def _check_action(action):
        """Check action type."""
        # `keep_default_action` must be a strict bool; check_value_type raises otherwise.
        check_value_type('keep_default_action', action, bool)
  275. def _process_specified_data(self, specified_data, action):
  276. """Check specified data type and value."""
  277. if specified_data is None:
  278. if action:
  279. return dict(self._DEFAULT_SPECIFIED_DATA)
  280. return dict()
  281. check_value_type('collect_specified_data', specified_data, [dict, type(None)])
  282. for param_name in specified_data:
  283. check_value_type(param_name, param_name, [str])
  284. unexpected_params = set(specified_data) - set(self._DEFAULT_SPECIFIED_DATA)
  285. if unexpected_params:
  286. raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported, '
  287. f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}')
  288. if 'histogram_regular' in specified_data:
  289. regular = specified_data.get('histogram_regular')
  290. check_value_type('histogram_regular', regular, (str, type(None)))
  291. if isinstance(regular, str):
  292. try:
  293. re.match(regular, '')
  294. except re.error as exc:
  295. raise ValueError(f'For `collect_specified_data`, the value of `histogram_regular` '
  296. f'is not a valid regular expression. Detail: {str(exc)}.')
  297. bool_items = set(self._DEFAULT_SPECIFIED_DATA) - {'histogram_regular'}
  298. for item in bool_items:
  299. if item in specified_data:
  300. check_value_type(item, specified_data.get(item), bool)
  301. if action:
  302. result = dict(self._DEFAULT_SPECIFIED_DATA)
  303. result.update(specified_data)
  304. else:
  305. result = specified_data
  306. return result
  307. def begin(self, run_context):
  308. cb_params = run_context.original_args()
  309. self._check_callbacks(cb_params)
  310. if cb_params.mode not in ModeEnum.to_list():
  311. raise ValueError('Only support `train` (model.train) and `eval` (model.eval) mode, '
  312. 'but got `{cb_params.mode}` mode.')
  313. self._record.set_mode(cb_params.mode)
    def step_end(self, run_context):
        """Collect graph, lineage and per-step summary data at the end of each training step."""
        cb_params = run_context.original_args()
        # Step-wise collection only applies to training; eval data is handled in `end`.
        if cb_params.mode != ModeEnum.TRAIN.value:
            return
        if not self._has_saved_graph:
            # Graphs are collected exactly once, on the first step.
            self._collect_graphs(cb_params)
            self._collect_dataset_graph(cb_params)
            self._has_saved_graph = True
            self._record.record(cb_params.cur_step_num)
        if self._custom_lineage_data and not self._has_saved_custom_data:
            # Custom lineage data is likewise written only once.
            packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
            self._record.add_value('custom_lineage_data', 'custom_lineage_data', packaged_custom_data)
            self._has_saved_custom_data = True
            self._record.record(cb_params.cur_step_num)
        if self._first_step:
            # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
            self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num
            self._tensor_collect_range = self._get_tensor_collect_range(cb_params, self._dataset_sink_mode)
            # The first step is always collected regardless of collect_freq.
            self._collect_at_step_end(cb_params, plugin_filter=None)
            self._first_step = False
            self._record.flush()
        else:
            # In sink mode the collection unit is the epoch, otherwise the step.
            current = cb_params.cur_epoch_num if self._dataset_sink_mode else cb_params.cur_step_num
            if current % self._collect_freq == 0 and current in self._tensor_collect_range:
                self._collect_at_step_end(cb_params, plugin_filter=None)
            elif current in self._tensor_collect_range:
                # Only TensorSummary is due at this step.
                self._collect_at_step_end(cb_params, lambda plugin: plugin == PluginEnum.TENSOR.value)
            elif current % self._collect_freq == 0:
                # Everything except TensorSummary is due at this step.
                self._collect_at_step_end(cb_params, lambda plugin: plugin != PluginEnum.TENSOR.value)
  343. def _get_tensor_collect_range(self, cb_params, dataset_sink_mode):
  344. """Get tensor collect range."""
  345. total_step = cb_params.epoch_num
  346. if not dataset_sink_mode:
  347. total_step *= cb_params.batch_num
  348. if self._collect_tensor_freq is not None:
  349. # `total_step + 1`: `total_step` would be a value of `cb_params.cur_step_num`.
  350. return range(0, total_step + 1, self._collect_tensor_freq)
  351. summary_to_collect = len(range(0, total_step + 1, self._collect_freq))
  352. default_tensor_summary_limit = 20
  353. if summary_to_collect > default_tensor_summary_limit:
  354. tensor_freq = total_step // (default_tensor_summary_limit - 1)
  355. if tensor_freq > 1:
  356. return range(0, total_step + 1, tensor_freq)[:default_tensor_summary_limit]
  357. # `cb_params.cur_step_num` counting from `1`, when `1` is in the range, take `1` more steps.
  358. return range(0, total_step + 1)[:default_tensor_summary_limit + 1]
  359. return range(0, total_step + 1, self._collect_freq)
    def _collect_at_step_end(self, cb_params, plugin_filter):
        """Collect input data, metric and histograms, then write a record for the current step."""
        self._collect_input_data(cb_params)
        self._collect_metric(cb_params)
        self._collect_histogram(cb_params)
        # plugin_filter restricts which plugins' data the record actually writes.
        self._record.record(cb_params.cur_step_num, plugin_filter=plugin_filter)
    def epoch_end(self, run_context):
        """Flush buffered summary data to disk at the end of each epoch."""
        self._record.flush()
    def end(self, run_context):
        """Collect lineage data when the train/eval run ends."""
        cb_params = run_context.original_args()
        if cb_params.mode == ModeEnum.TRAIN.value:
            self._collect_train_lineage(cb_params)
        else:
            self._collect_eval_lineage(cb_params)
        # This is a workaround to avoid record '_summary_tensor_cache'.
        self._record.set_mode('eval')
        # There's nothing special about setting step to 0 here, just to satisfy the interface call
        self._record.record(step=0)
  377. def _check_callbacks(self, cb_params):
  378. """Check there if there are duplicate instances of SummaryCollector."""
  379. callbacks = cb_params.list_callback
  380. is_find = False
  381. for callback in callbacks:
  382. if type(callback).__name__ == self.__class__.__name__:
  383. if not is_find:
  384. is_find = True
  385. continue
  386. raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
  387. f"but expected only one {self.__class__.__name__} instance.")
  388. @staticmethod
  389. def _package_custom_lineage_data(custom_lineage_data):
  390. """
  391. Package user-defined lineage data into binary data.
  392. Args:
  393. custom_lineage_data (dict): User custom lineage data.
  394. Returns:
  395. UserDefinedInfo, a object of lineage_pb2.UserDefinedInfo.
  396. """
  397. user_defined_info = lineage_pb2.UserDefinedInfo()
  398. for key, value in custom_lineage_data.items():
  399. if isinstance(value, int):
  400. attr_name = "map_int32"
  401. elif isinstance(value, float):
  402. attr_name = "map_double"
  403. else:
  404. attr_name = "map_str"
  405. user_info = user_defined_info.user_info.add()
  406. getattr(user_info, attr_name)[key] = value
  407. return user_defined_info
  408. def _collect_input_data(self, cb_params):
  409. """Only support to collect image data."""
  410. if not self._collect_specified_data.get('collect_input_data'):
  411. return
  412. input_data = getattr(cb_params, 'train_dataset_element', None)
  413. if not isinstance(input_data, (Tensor, list, tuple)):
  414. self._collect_specified_data['collect_input_data'] = False
  415. logger.warning("The type of input data is not Tensor/list/tuple, "
  416. "so SummaryCollector will not collect input data.")
  417. return
  418. if not isinstance(input_data, Tensor) and not input_data:
  419. self._collect_specified_data['collect_input_data'] = False
  420. logger.warning("The 'train_dataset_element' in cb_params is empty, "
  421. "so SummaryCollector will not record the input data.")
  422. if self._dataset_sink_mode and context.get_context('device_target') == 'Ascend':
  423. logger.warning('On Ascend device, SummaryCollector is not supported to record input data '
  424. 'in dataset sink mode.')
  425. return
  426. if isinstance(input_data, (list, tuple)) and input_data:
  427. input_data = input_data[0]
  428. try:
  429. self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data)
  430. except (TypeError, ValueError):
  431. logger.warning('The input data of network are not image, so will not collect by SummaryCollector.')
  432. self._collect_specified_data['collect_input_data'] = False
  433. return
  434. def _collect_dataset_graph(self, cb_params):
  435. """Only collect train dataset graph."""
  436. if not self._collect_specified_data.get('collect_dataset_graph'):
  437. return
  438. # After analysis, we think that the validated dataset graph and the training dataset graph
  439. # should be consistent under normal scenarios, so only the training dataset graph is collected.
  440. if cb_params.mode == ModeEnum.TRAIN.value:
  441. train_dataset = cb_params.train_dataset
  442. dataset_graph = DatasetGraph()
  443. graph_bytes = dataset_graph.package_dataset_graph(train_dataset)
  444. if graph_bytes is None:
  445. return
  446. self._record.add_value('dataset_graph', 'train_dataset', graph_bytes)
  447. def _collect_graphs(self, cb_params):
  448. """Collect the graph of train network and eval network."""
  449. if not self._collect_specified_data.get('collect_graph'):
  450. return
  451. network = cb_params.train_network if cb_params.mode == ModeEnum.TRAIN.value else cb_params.eval_network
  452. graph_proto = network.get_func_graph_proto()
  453. if graph_proto is None:
  454. logger.warning("Can not get graph proto, it may not be 'GRAPH_MODE' in context currently, "
  455. "so SummaryCollector will not collect graph.")
  456. return
  457. self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)
  458. def _collect_metric(self, cb_params):
  459. """Collect metric, currently only collection Loss is supported."""
  460. if not self._collect_specified_data.get('collect_metric'):
  461. return
  462. loss = self._get_loss(cb_params)
  463. if loss is None:
  464. return
  465. try:
  466. self._record.add_value(PluginEnum.SCALAR.value, 'loss/auto', loss)
  467. except ValueError:
  468. logger.warning("The output of network is not a scalar, so SummaryCollector will not collect loss.")
  469. self._collect_specified_data['collect_metric'] = False
  470. def _get_loss(self, cb_params):
  471. """
  472. Get loss from the network output.
  473. Args:
  474. cb_params (_InternalCallbackParam): Callback parameters.
  475. Returns:
  476. Union[Tensor, None], if parse loss success, will return a Tensor value(shape is [1]), else return None.
  477. """
  478. if not self._is_parse_loss_success:
  479. # If parsing has failed before, avoid repeating it
  480. return None
  481. output = cb_params.net_outputs
  482. if output is None:
  483. logger.warning("Can not find any output by this network, so SummaryCollector will not collect loss.")
  484. self._is_parse_loss_success = False
  485. return None
  486. if isinstance(output, (int, float, Tensor)):
  487. loss = output
  488. elif isinstance(output, (list, tuple)) and output:
  489. # If the output is a list, since the default network returns loss first,
  490. # we assume that the first one is loss.
  491. loss = output[0]
  492. else:
  493. logger.warning("The output type could not be identified, so no loss was recorded in SummaryCollector.")
  494. self._is_parse_loss_success = False
  495. return None
  496. if not isinstance(loss, Tensor):
  497. loss = Tensor(loss)
  498. loss = Tensor(np.mean(loss.asnumpy()))
  499. return loss
  500. def _get_optimizer(self, cb_params):
  501. """
  502. Get optimizer from the cb_params or parse from the network.
  503. Args:
  504. cb_params (_InternalCallbackParam): Callback parameters.
  505. Returns:
  506. Union[Optimizer, None], if parse optimizer success, will return a optimizer, else return None.
  507. """
  508. # 'optimizer_failed' means find optimizer failed, so we will not collect data about optimizer.
  509. optimizer_failed = 'Failed'
  510. if self._temp_optimizer == optimizer_failed:
  511. return None
  512. if self._temp_optimizer is not None:
  513. return self._temp_optimizer
  514. optimizer = cb_params.optimizer
  515. if optimizer is None:
  516. network = cb_params.train_network if cb_params.mode == 'train' else cb_params.eval_network
  517. optimizer = self._parse_optimizer_by_network(network)
  518. if optimizer is None or not isinstance(optimizer, Optimizer):
  519. logger.warning("Can not find optimizer in network, or the optimizer does not inherit MindSpore's "
  520. "optimizer, so we will not collect data about optimizer in SummaryCollector.")
  521. optimizer = None
  522. self._temp_optimizer = optimizer if optimizer is not None else optimizer_failed
  523. return optimizer
  524. @staticmethod
  525. def _parse_optimizer_by_network(network):
  526. """Parse optimizer from network, if parse success will return a optimizer, else return None."""
  527. optimizer = None
  528. for _, cell in network.cells_and_names():
  529. if isinstance(cell, Optimizer):
  530. return cell
  531. try:
  532. optimizer = getattr(cell, 'optimizer')
  533. except AttributeError:
  534. continue
  535. if not isinstance(optimizer, Optimizer):
  536. continue
  537. # Optimizer found successfully
  538. break
  539. return optimizer
  540. def _collect_histogram(self, cb_params):
  541. """Collect histogram data, contain the parameter weight and bias."""
  542. # Note: if there is not a key named `histogram_regular` in `self._collect_specified_data`,
  543. # it means we will not collect histogram data.
  544. if 'histogram_regular' not in self._collect_specified_data:
  545. return
  546. optimizer = self._get_optimizer(cb_params)
  547. if optimizer is None:
  548. return
  549. parameters = optimizer.parameters
  550. regular = self._collect_specified_data.get('histogram_regular')
  551. if regular is not None:
  552. for parameter in parameters:
  553. if re.match(regular, parameter.name):
  554. self._record.add_value(PluginEnum.HISTOGRAM.value, parameter.name+'/auto', parameter.data)
  555. return
  556. # Note: If `histogram_regular` in `self._collect_specified_data` and the value is None,
  557. # we will collect the first five parameters.
  558. default_parameter_count = 5
  559. for parameter in parameters[:default_parameter_count]:
  560. self._record.add_value(PluginEnum.HISTOGRAM.value, parameter.name+'/auto', parameter.data)
  561. @staticmethod
  562. def _get_learning_rate(optimizer):
  563. """
  564. parse the learning rate from optimizer.
  565. Args:
  566. optimizer (Optimizer): A optimizer which inherit the MindSpore Optimizer class.
  567. Returns:
  568. Union[Tensor, None], if parse learning rate success, will return a Tensor, else return None.
  569. """
  570. learning_rate = optimizer.learning_rate
  571. if not isinstance(learning_rate, Parameter):
  572. logger.warning("The learning rate detected in the optimizer is not a Parameter type, "
  573. "so it is not recorded. Its type is %r.", type(learning_rate).__name__)
  574. return None
  575. return learning_rate.data
  576. def _collect_train_lineage(self, cb_params):
  577. """Collect train lineage data, the detail refer to lineage_pb2.TrainLineage."""
  578. if not self._collect_specified_data.get('collect_train_lineage'):
  579. return
  580. train_lineage = {}
  581. loss = self._get_loss(cb_params)
  582. if loss is not None:
  583. loss_numpy = loss.asnumpy()
  584. loss = float(np.atleast_1d(loss_numpy)[0])
  585. train_lineage[LineageMetadata.loss] = loss
  586. else:
  587. train_lineage[LineageMetadata.loss] = None
  588. optimizer = self._get_optimizer(cb_params)
  589. learning_rate = self._get_learning_rate(optimizer) if optimizer is not None else None
  590. if learning_rate is not None:
  591. train_lineage[LineageMetadata.learning_rate] = list(np.atleast_1d(learning_rate.asnumpy()))[0]
  592. else:
  593. train_lineage[LineageMetadata.learning_rate] = None
  594. train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None
  595. train_lineage[LineageMetadata.train_network] = type(cb_params.network).__name__
  596. loss_fn = self._get_loss_fn(cb_params)
  597. train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None
  598. train_lineage[LineageMetadata.epoch] = cb_params.epoch_num
  599. train_lineage[LineageMetadata.step_num] = cb_params.cur_step_num
  600. train_lineage[LineageMetadata.parallel_mode] = cb_params.parallel_mode
  601. train_lineage[LineageMetadata.device_num] = cb_params.device_number
  602. ckpt_file_path = self._get_ckpt_file_path(cb_params)
  603. train_lineage[LineageMetadata.model_path] = json.dumps(dict(ckpt=ckpt_file_path))
  604. model_size = os.path.getsize(ckpt_file_path) if ckpt_file_path else 0
  605. train_lineage[LineageMetadata.model_size] = model_size
  606. self._parse_dataset(cb_params, train_lineage)
  607. train_lineage_message = self._package_train_lineage_message(train_lineage)
  608. self._record.add_value(PluginEnum.TRAIN_LINEAGE.value, 'train_lineage', train_lineage_message)
  609. @staticmethod
  610. def _package_train_lineage_message(train_lineage):
  611. """
  612. Package train lineage data into binary data.
  613. Args:
  614. train_lineage (dict): The train lineage dict, refer to the attribute of `_collect_train_lineage` method.
  615. Returns:
  616. TrainLineage, a object of lineage_pb2.TrainLineage.
  617. """
  618. lineage_message = lineage_pb2.TrainLineage()
  619. if train_lineage.get(LineageMetadata.train_network) is not None:
  620. lineage_message.algorithm.network = train_lineage.get(LineageMetadata.train_network)
  621. if train_lineage.get(LineageMetadata.loss) is not None:
  622. lineage_message.algorithm.loss = train_lineage.get(LineageMetadata.loss)
  623. # Construct train_dataset message.
  624. if train_lineage.get(LineageMetadata.train_dataset_path) is not None:
  625. lineage_message.train_dataset.train_dataset_path = train_lineage.get(LineageMetadata.train_dataset_path)
  626. if train_lineage.get(LineageMetadata.train_dataset_size) is not None:
  627. lineage_message.train_dataset.train_dataset_size = train_lineage.get(LineageMetadata.train_dataset_size)
  628. # Construct model message
  629. lineage_message.model.path = train_lineage.get(LineageMetadata.model_path)
  630. lineage_message.model.size = train_lineage.get(LineageMetadata.model_size)
  631. # Construct hyper_parameters message.
  632. if train_lineage.get(LineageMetadata.learning_rate) is not None:
  633. lineage_message.hyper_parameters.learning_rate = train_lineage.get(LineageMetadata.learning_rate)
  634. if train_lineage.get(LineageMetadata.optimizer) is not None:
  635. lineage_message.hyper_parameters.optimizer = train_lineage.get(LineageMetadata.optimizer)
  636. if train_lineage.get(LineageMetadata.loss_function) is not None:
  637. lineage_message.hyper_parameters.loss_function = train_lineage.get(LineageMetadata.loss_function)
  638. if train_lineage.get(LineageMetadata.parallel_mode) is not None:
  639. lineage_message.hyper_parameters.parallel_mode = train_lineage.get(LineageMetadata.parallel_mode)
  640. lineage_message.hyper_parameters.epoch = train_lineage.get(LineageMetadata.epoch)
  641. lineage_message.hyper_parameters.device_num = train_lineage.get(LineageMetadata.device_num)
  642. lineage_message.hyper_parameters.batch_size = train_lineage.get(LineageMetadata.batch_size)
  643. return lineage_message
  644. def _parse_dataset(self, cb_params, lineage_dict):
  645. """
  646. Analyze Dataset to get the dataset path and dataset size.
  647. Args:
  648. cb_params (_InternalCallbackParam): Callback parameters.
  649. lineage_dict (dict): The lineage dict, refer to the attribute
  650. of `_collect_train_lineage` method or `_collect_eval_lineage`.
  651. Returns:
  652. dict, the lineage metadata.
  653. """
  654. dataset = cb_params.train_dataset if cb_params.mode == ModeEnum.TRAIN.value else cb_params.valid_dataset
  655. try:
  656. dataset_path = self._get_dataset_path(dataset)
  657. except IndexError:
  658. dataset_path = None
  659. if dataset_path and os.path.isfile(dataset_path):
  660. dataset_dir = os.path.dirname(dataset_path)
  661. else:
  662. dataset_dir = dataset_path
  663. batch_num = dataset.get_dataset_size()
  664. batch_size = dataset.get_batch_size()
  665. dataset_size = int(batch_num * batch_size)
  666. lineage_dict[LineageMetadata.batch_size] = batch_size
  667. if cb_params.mode == ModeEnum.TRAIN.value:
  668. lineage_dict[LineageMetadata.train_dataset_path] = dataset_dir
  669. lineage_dict[LineageMetadata.train_dataset_size] = dataset_size
  670. else:
  671. lineage_dict[LineageMetadata.valid_dataset_path] = dataset_dir
  672. lineage_dict[LineageMetadata.valid_dataset_size] = dataset_size
  673. return lineage_dict
  674. def _get_dataset_path(self, output_dataset):
  675. """
  676. Get dataset path of MindDataset object.
  677. Args:
  678. output_dataset (Union[Dataset, ImageFolderDataset, MnistDataset, Cifar10Dataset, Cifar100Dataset,
  679. VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]):
  680. Refer to mindspore.dataset.Dataset.
  681. Returns:
  682. str, dataset path.
  683. Raises:
  684. IndexError: it means get dataset path failed.
  685. """
  686. dataset_package = import_module('mindspore.dataset')
  687. dataset_dir_set = (dataset_package.ImageFolderDataset, dataset_package.MnistDataset,
  688. dataset_package.Cifar10Dataset, dataset_package.Cifar100Dataset,
  689. dataset_package.VOCDataset, dataset_package.CelebADataset)
  690. dataset_file_set = (dataset_package.MindDataset, dataset_package.ManifestDataset)
  691. dataset_files_set = (dataset_package.TFRecordDataset, dataset_package.TextFileDataset)
  692. if isinstance(output_dataset, dataset_file_set):
  693. return output_dataset.dataset_file
  694. if isinstance(output_dataset, dataset_dir_set):
  695. return output_dataset.dataset_dir
  696. if isinstance(output_dataset, dataset_files_set):
  697. return output_dataset.dataset_files[0]
  698. return self._get_dataset_path(output_dataset.children[0])
  699. @staticmethod
  700. def _get_ckpt_file_path(cb_params):
  701. """
  702. Get checkpoint file path from MindSpore callback list.
  703. Args:
  704. cb_params (_InternalCallbackParam): Callback parameters.
  705. Returns:
  706. Union[str, None], if parse success will checkpoint file absolute path, else return None.
  707. """
  708. callbacks = cb_params.list_callback
  709. ckpt_file_path = None
  710. for callback in callbacks:
  711. if isinstance(callback, ModelCheckpoint):
  712. ckpt_file_path = callback.latest_ckpt_file_name
  713. if ckpt_file_path:
  714. ckpt_file_path = os.path.realpath(ckpt_file_path)
  715. return ckpt_file_path
  716. @staticmethod
  717. def _get_loss_fn(cb_params):
  718. """
  719. Get loss function by cb_params and analyzing network.
  720. Args:
  721. cb_params (_InternalCallbackParam): Callback parameters.
  722. Returns:
  723. Union[Cell, None], a Cell object, if parse failed, will return None.
  724. """
  725. loss_fn = cb_params.loss_fn
  726. if loss_fn is not None:
  727. return loss_fn
  728. if cb_params.mode == ModeEnum.TRAIN.value:
  729. network = cb_params.train_network
  730. else:
  731. network = cb_params.eval_network
  732. for _, cell in network.cells_and_names():
  733. if isinstance(cell, _Loss):
  734. loss_fn = cell
  735. break
  736. return loss_fn
  737. def _collect_eval_lineage(self, cb_params):
  738. """Collect eval lineage data, the detail refer to lineage_pb2.EvaluationLineage."""
  739. if not self._collect_specified_data.get('collect_eval_lineage'):
  740. return
  741. eval_lineage = dict()
  742. eval_lineage[LineageMetadata.metrics] = json.dumps(cb_params.metrics)
  743. self._parse_dataset(cb_params, eval_lineage)
  744. eval_lineage_message = self._package_eval_lineage_message(eval_lineage)
  745. self._record.add_value(PluginEnum.EVAL_LINEAGE.value, 'eval_lineage', eval_lineage_message)
  746. @staticmethod
  747. def _package_eval_lineage_message(eval_lineage):
  748. """
  749. Package eval lineage data into binary data.
  750. Args:
  751. eval_lineage (dict): The eval lineage dict, refer to the attribute of `_collect_eval_lineage` method.
  752. Returns:
  753. EvaluationLineage, a object of lineage_pb2.EvaluationLineage.
  754. """
  755. lineage_message = lineage_pb2.EvaluationLineage()
  756. if eval_lineage.get(LineageMetadata.metrics) is not None:
  757. lineage_message.metric = eval_lineage.get(LineageMetadata.metrics)
  758. if eval_lineage.get(LineageMetadata.valid_dataset_path) is not None:
  759. lineage_message.valid_dataset.valid_dataset_path = eval_lineage.get(LineageMetadata.valid_dataset_path)
  760. if eval_lineage.get(LineageMetadata.valid_dataset_size) is not None:
  761. lineage_message.valid_dataset.valid_dataset_size = eval_lineage.get(LineageMetadata.valid_dataset_size)
  762. return lineage_message