
model_lineage.py 26 kB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""This module is used to collect lineage information of model training."""
import json
import os

import numpy as np

from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, LineageErrors
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageGetModelFileError, LineageLogError,
                                                                 LineageParamRunContextError)
from mindinsight.lineagemgr.common.log import logger as log
from mindinsight.lineagemgr.common.utils import make_directory, try_except
from mindinsight.lineagemgr.common.validator.model_parameter import EvalParameter
from mindinsight.lineagemgr.common.validator.validate import (validate_eval_run_context, validate_file_path,
                                                              validate_int_params,
                                                              validate_raise_exception,
                                                              validate_user_defined_info)
from mindinsight.utils.exceptions import MindInsightException
from ._summary_record import LineageSummary
from .base import Metadata

try:
    from mindspore.common.tensor import Tensor
    from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep
    from mindspore.nn import Cell, Optimizer
    from mindspore.nn.loss.loss import _Loss
    from mindspore.dataset.engine import Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset, \
        VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset
    import mindspore.dataset as ds
except (ImportError, ModuleNotFoundError):
    log.warning('MindSpore Not Found!')


class TrainLineage(Callback):
    """
    Collect lineage of a training job.

    Args:
        summary_record (Union[SummaryRecord, str]): The `SummaryRecord` object which
            is used to record the summary value (see mindspore.train.summary.SummaryRecord),
            or a log dir (as a `str`) to be passed to `LineageSummary` to create
            a lineage summary recorder. Note that instead of using summary_record
            to record lineage info directly, we obtain the log dir from it and then
            create a new summary file to write lineage info.
        raise_exception (bool): Whether to raise an exception when an error occurs
            in TrainLineage. If True, raise the exception. If False, catch it and
            continue. Default: False.
        user_defined_info (dict): User defined information. Only a flat dict with
            str keys and int/float/str values is supported. Default: None.

    Raises:
        MindInsightException: If validating a parameter fails.
        LineageLogError: If recording lineage information fails.

    Examples:
        >>> from mindinsight.lineagemgr import TrainLineage
        >>> from mindspore.train.callback import ModelCheckpoint, SummaryStep
        >>> from mindspore.train.summary import SummaryRecord
        >>> model = Model(train_network)
        >>> model_ckpt = ModelCheckpoint(directory='/dir/to/save/model/')
        >>> summary_writer = SummaryRecord(log_dir='./')
        >>> summary_callback = SummaryStep(summary_writer, flush_step=2)
        >>> lineagemgr = TrainLineage(summary_record=summary_writer)
        >>> model.train(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, lineagemgr])
    """

    def __init__(self,
                 summary_record,
                 raise_exception=False,
                 user_defined_info=None):
        super(TrainLineage, self).__init__()
        try:
            validate_raise_exception(raise_exception)
            self.raise_exception = raise_exception

            if isinstance(summary_record, str):
                # make directory if it does not exist
                self.lineage_log_dir = make_directory(summary_record)
            else:
                summary_log_path = summary_record.full_file_name
                validate_file_path(summary_log_path)
                self.lineage_log_dir = os.path.dirname(summary_log_path)
            self.lineage_summary = LineageSummary(self.lineage_log_dir)

            self.initial_learning_rate = None

            self.user_defined_info = user_defined_info
            if user_defined_info:
                validate_user_defined_info(user_defined_info)
        except MindInsightException as err:
            log.error(err)
            if raise_exception:
                raise

    @try_except(log)
    def begin(self, run_context):
        """
        Initialize the training progress when the training job begins.

        Args:
            run_context (RunContext): It contains all lineage information,
                see mindspore.train.callback.RunContext.

        Raises:
            MindInsightException: If validating parameter fails.
        """
        log.info('Initialize training lineage collection...')

        if self.user_defined_info:
            self.lineage_summary.record_user_defined_info(self.user_defined_info)

        if not isinstance(run_context, RunContext):
            error_msg = f'Invalid TrainLineage run_context.'
            log.error(error_msg)
            raise LineageParamRunContextError(error_msg)

        run_context_args = run_context.original_args()
        if not self.initial_learning_rate:
            optimizer = run_context_args.get('optimizer')
            if optimizer and not isinstance(optimizer, Optimizer):
                log.error("The parameter optimizer is invalid. It should be an instance of "
                          "mindspore.nn.optim.optimizer.Optimizer.")
                raise MindInsightException(error=LineageErrors.PARAM_OPTIMIZER_ERROR,
                                           message=LineageErrorMsg.PARAM_OPTIMIZER_ERROR.value)
            if optimizer:
                log.info('Obtaining initial learning rate...')
                self.initial_learning_rate = AnalyzeObject.analyze_optimizer(optimizer)
                log.debug('initial_learning_rate: %s', self.initial_learning_rate)
            else:
                network = run_context_args.get('train_network')
                optimizer = AnalyzeObject.get_optimizer_by_network(network)
                self.initial_learning_rate = AnalyzeObject.analyze_optimizer(optimizer)
                log.debug('initial_learning_rate: %s', self.initial_learning_rate)

        # get train dataset graph
        train_dataset = run_context_args.get('train_dataset')
        dataset_graph_dict = ds.serialize(train_dataset)
        dataset_graph_json_str = json.dumps(dataset_graph_dict, indent=2)
        dataset_graph_dict = json.loads(dataset_graph_json_str)

        log.info('Logging dataset graph...')
        try:
            self.lineage_summary.record_dataset_graph(dataset_graph=dataset_graph_dict)
        except Exception as error:
            error_msg = f'Dataset graph log error in TrainLineage begin: {error}'
            log.error(error_msg)
            raise LineageLogError(error_msg)
        log.info('Dataset graph logged successfully.')

    @try_except(log)
    def end(self, run_context):
        """
        Collect lineage information when the training job ends.

        Args:
            run_context (RunContext): It contains all lineage information,
                see mindspore.train.callback.RunContext.

        Raises:
            LineageLogError: If recording lineage information fails.
        """
        log.info('Start to collect training lineage...')
        if not isinstance(run_context, RunContext):
            error_msg = f'Invalid TrainLineage run_context.'
            log.error(error_msg)
            raise LineageParamRunContextError(error_msg)

        run_context_args = run_context.original_args()

        train_lineage = dict()
        train_lineage = AnalyzeObject.get_network_args(
            run_context_args, train_lineage
        )

        train_dataset = run_context_args.get('train_dataset')
        callbacks = run_context_args.get('list_callback')
        list_callback = getattr(callbacks, '_callbacks', [])

        log.info('Obtaining model files...')
        ckpt_file_path, _ = AnalyzeObject.get_file_path(list_callback)

        train_lineage[Metadata.learning_rate] = self.initial_learning_rate
        train_lineage[Metadata.epoch] = run_context_args.get('epoch_num')
        train_lineage[Metadata.step_num] = run_context_args.get('cur_step_num')
        train_lineage[Metadata.parallel_mode] = run_context_args.get('parallel_mode')
        train_lineage[Metadata.device_num] = run_context_args.get('device_number')
        train_lineage[Metadata.batch_size] = run_context_args.get('batch_num')

        model_path_dict = {
            'ckpt': ckpt_file_path
        }
        train_lineage[Metadata.model_path] = json.dumps(model_path_dict)

        log.info('Calculating model size...')
        train_lineage[Metadata.model_size] = AnalyzeObject.get_model_size(
            ckpt_file_path
        )
        log.debug('model_size: %s', train_lineage[Metadata.model_size])

        log.info('Analyzing dataset object...')
        train_lineage = AnalyzeObject.analyze_dataset(train_dataset, train_lineage, 'train')

        log.info('Logging lineage information...')
        try:
            self.lineage_summary.record_train_lineage(train_lineage)
        except IOError as error:
            error_msg = f'End error in TrainLineage: {error}'
            log.error(error_msg)
            raise LineageLogError(error_msg)
        except Exception as error:
            error_msg = f'End error in TrainLineage: {error}'
            log.error(error_msg)
            log.error('Fail to log the lineage of the training job.')
            raise LineageLogError(error_msg)
        log.info('The lineage of the training job has logged successfully.')


class EvalLineage(Callback):
    """
    Collect lineage of an evaluation job.

    Args:
        summary_record (Union[SummaryRecord, str]): The `SummaryRecord` object which
            is used to record the summary value (see mindspore.train.summary.SummaryRecord),
            or a log dir (as a `str`) to be passed to `LineageSummary` to create
            a lineage summary recorder. Note that instead of using summary_record
            to record lineage info directly, we obtain the log dir from it and then
            create a new summary file to write lineage info.
        raise_exception (bool): Whether to raise an exception when an error occurs
            in EvalLineage. If True, raise the exception. If False, catch it and
            continue. Default: False.
        user_defined_info (dict): User defined information. Only a flat dict with
            str keys and int/float/str values is supported. Default: None.

    Raises:
        MindInsightException: If validating a parameter fails.
        LineageLogError: If recording lineage information fails.

    Examples:
        >>> from mindinsight.lineagemgr import EvalLineage
        >>> from mindspore.train.callback import ModelCheckpoint, SummaryStep
        >>> from mindspore.train.summary import SummaryRecord
        >>> model = Model(train_network)
        >>> model_ckpt = ModelCheckpoint(directory='/dir/to/save/model/')
        >>> summary_writer = SummaryRecord(log_dir='./')
        >>> summary_callback = SummaryStep(summary_writer, flush_step=2)
        >>> lineagemgr = EvalLineage(summary_record=summary_writer)
        >>> model.eval(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, lineagemgr])
    """

    def __init__(self,
                 summary_record,
                 raise_exception=False,
                 user_defined_info=None):
        super(EvalLineage, self).__init__()
        try:
            validate_raise_exception(raise_exception)
            self.raise_exception = raise_exception

            if isinstance(summary_record, str):
                # make directory if it does not exist
                self.lineage_log_dir = make_directory(summary_record)
            else:
                summary_log_path = summary_record.full_file_name
                validate_file_path(summary_log_path)
                self.lineage_log_dir = os.path.dirname(summary_log_path)
            self.lineage_summary = LineageSummary(self.lineage_log_dir)

            self.user_defined_info = user_defined_info
            if self.user_defined_info:
                validate_user_defined_info(self.user_defined_info)
        except MindInsightException as err:
            log.error(err)
            if raise_exception:
                raise

    @try_except(log)
    def end(self, run_context):
        """
        Collect lineage information when the evaluation job ends.

        Args:
            run_context (RunContext): It contains all lineage information,
                see mindspore.train.callback.RunContext.

        Raises:
            MindInsightException: If validating parameter fails.
            LineageLogError: If recording lineage information fails.
        """
        if self.user_defined_info:
            self.lineage_summary.record_user_defined_info(self.user_defined_info)

        if not isinstance(run_context, RunContext):
            error_msg = f'Invalid EvalLineage run_context.'
            log.error(error_msg)
            raise LineageParamRunContextError(error_msg)

        run_context_args = run_context.original_args()
        validate_eval_run_context(EvalParameter, run_context_args)

        valid_dataset = run_context_args.get('valid_dataset')

        eval_lineage = dict()
        metrics = run_context_args.get('metrics')
        eval_lineage[Metadata.metrics] = json.dumps(metrics)
        eval_lineage[Metadata.step_num] = run_context_args.get('cur_step_num')

        log.info('Analyzing dataset object...')
        eval_lineage = AnalyzeObject.analyze_dataset(valid_dataset, eval_lineage, 'valid')

        log.info('Logging evaluation job lineage...')
        try:
            self.lineage_summary.record_evaluation_lineage(eval_lineage)
        except IOError as error:
            error_msg = f'End error in EvalLineage: {error}'
            log.error(error_msg)
            log.error('Fail to log the lineage of the evaluation job.')
            raise LineageLogError(error_msg)
        except Exception as error:
            error_msg = f'End error in EvalLineage: {error}'
            log.error(error_msg)
            log.error('Fail to log the lineage of the evaluation job.')
            raise LineageLogError(error_msg)
        log.info('The lineage of the evaluation job has logged successfully.')


class AnalyzeObject:
    """Analyze class object in MindSpore."""

    @staticmethod
    def get_optimizer_by_network(network):
        """
        Get optimizer by analyzing network.

        Args:
            network (Cell): See mindspore.nn.Cell.

        Returns:
            Optimizer, an Optimizer object.
        """
        optimizer = None
        net_args = vars(network) if network else {}
        net_cell = net_args.get('_cells') if net_args else {}
        for _, value in net_cell.items():
            if isinstance(value, Optimizer):
                optimizer = value
                break
        return optimizer

    @staticmethod
    def get_loss_fn_by_network(network):
        """
        Get loss function by analyzing network.

        Args:
            network (Cell): See mindspore.nn.Cell.

        Returns:
            Loss_fn, a Cell object.
        """
        loss_fn = None
        inner_cell_list = []
        net_args = vars(network) if network else {}
        net_cell = net_args.get('_cells') if net_args else {}
        for _, value in net_cell.items():
            if isinstance(value, Cell) and \
                    not isinstance(value, Optimizer):
                inner_cell_list.append(value)

        while inner_cell_list:
            inner_net_args = vars(inner_cell_list[0])
            inner_net_cell = inner_net_args.get('_cells')

            for value in inner_net_cell.values():
                if isinstance(value, _Loss):
                    loss_fn = value
                    break
                if isinstance(value, Cell):
                    inner_cell_list.append(value)

            if loss_fn:
                break

            inner_cell_list.pop(0)

        return loss_fn

    @staticmethod
    def get_backbone_network(network):
        """
        Get the name of the backbone network.

        Args:
            network (Cell): The train network.

        Returns:
            str, the name of the backbone network.
        """
        backbone_name = None
        has_network = False
        network_key = 'network'
        backbone_key = '_backbone'
        net_args = vars(network) if network else {}
        net_cell = net_args.get('_cells') if net_args else {}

        for key, value in net_cell.items():
            if key == network_key:
                network = value
                has_network = True
                break

        if has_network:
            while hasattr(network, network_key):
                network = getattr(network, network_key)
            if hasattr(network, backbone_key):
                backbone = getattr(network, backbone_key)
                backbone_name = type(backbone).__name__
        if backbone_name is None and network is not None:
            backbone_name = type(network).__name__
        return backbone_name

    @staticmethod
    def analyze_optimizer(optimizer):
        """
        Analyze Optimizer, a Cell object of MindSpore.

        In this way, we can obtain the following attributes:
        learning_rate (float),
        weight_decay (float),
        momentum (float),
        weights (float).

        Args:
            optimizer (Optimizer): See mindspore.nn.optim.Optimizer.

        Returns:
            float, the learning rate that the optimizer adopted.
        """
        learning_rate = None
        if isinstance(optimizer, Optimizer):
            learning_rate = getattr(optimizer, 'learning_rate', None)
            if learning_rate:
                learning_rate = learning_rate.default_input

        # Get the real learning rate value
        if isinstance(learning_rate, Tensor):
            learning_rate = learning_rate.asnumpy()
            if learning_rate.ndim == 0:
                learning_rate = np.atleast_1d(learning_rate)
            learning_rate = list(learning_rate)
        elif isinstance(learning_rate, float):
            learning_rate = [learning_rate]

        return learning_rate[0] if learning_rate else None

    @staticmethod
    def analyze_dataset(dataset, lineage_dict, dataset_type):
        """
        Analyze Dataset, a Dataset object of MindSpore.

        In this way, we can obtain the following attributes:
        dataset_path (str),
        train_dataset_size (int),
        valid_dataset_size (int),
        batch_size (int)

        Args:
            dataset (Dataset): See mindspore.dataengine.datasets.Dataset.
            lineage_dict (dict): A dict containing lineage metadata.
            dataset_type (str): Dataset type, train or valid.

        Returns:
            dict, the lineage metadata.
        """
        batch_num = dataset.get_dataset_size()
        batch_size = dataset.get_batch_size()
        if batch_num is not None:
            validate_int_params(batch_num, 'dataset_batch_num')
            validate_int_params(batch_size, 'dataset_batch_size')
        log.debug('dataset_batch_num: %d', batch_num)
        log.debug('dataset_batch_size: %d', batch_size)

        dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset)
        if dataset_path and os.path.isfile(dataset_path):
            dataset_path, _ = os.path.split(dataset_path)
        dataset_size = int(batch_num * batch_size)
        if dataset_type == 'train':
            lineage_dict[Metadata.train_dataset_path] = dataset_path
            lineage_dict[Metadata.train_dataset_size] = dataset_size
        elif dataset_type == 'valid':
            lineage_dict[Metadata.valid_dataset_path] = dataset_path
            lineage_dict[Metadata.valid_dataset_size] = dataset_size
        return lineage_dict

    def get_dataset_path(self, output_dataset):
        """
        Get dataset path of MindDataset object.

        Args:
            output_dataset (Union[Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset,
                VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]):
                See mindspore.dataengine.datasets.Dataset.

        Returns:
            str, dataset path.
        """
        dataset_dir_set = (ImageFolderDatasetV2, MnistDataset, Cifar10Dataset,
                           Cifar100Dataset, VOCDataset, CelebADataset)
        dataset_file_set = (MindDataset, ManifestDataset)
        dataset_files_set = (TFRecordDataset, TextFileDataset)

        if isinstance(output_dataset, dataset_file_set):
            return output_dataset.dataset_file
        if isinstance(output_dataset, dataset_dir_set):
            return output_dataset.dataset_dir
        if isinstance(output_dataset, dataset_files_set):
            return output_dataset.dataset_files[0]
        return self.get_dataset_path(output_dataset.input[0])

    @staticmethod
    def get_dataset_path_wrapped(dataset):
        """
        A wrapper for obtaining dataset path.

        Args:
            dataset (Union[MindDataset, Dataset]): See
                mindspore.dataengine.datasets.Dataset.

        Returns:
            str, dataset path.
        """
        dataset_path = None
        if isinstance(dataset, Dataset):
            try:
                dataset_path = AnalyzeObject().get_dataset_path(dataset)
            except IndexError:
                dataset_path = None
        dataset_path = validate_file_path(dataset_path, allow_empty=True)
        return dataset_path

    @staticmethod
    def get_file_path(list_callback):
        """
        Get ckpt_file_name and summary_log_path from the MindSpore callback list.

        Args:
            list_callback (list[Callback]): The MindSpore training Callback list.

        Returns:
            tuple, contains ckpt_file_name and summary_log_path.
        """
        ckpt_file_path = None
        summary_log_path = None
        for callback in list_callback:
            if isinstance(callback, ModelCheckpoint):
                ckpt_file_path = callback.latest_ckpt_file_name
            if isinstance(callback, SummaryStep):
                summary_log_path = callback.summary_file_name

        if ckpt_file_path:
            validate_file_path(ckpt_file_path)
            ckpt_file_path = os.path.realpath(ckpt_file_path)

        if summary_log_path:
            validate_file_path(summary_log_path)
            summary_log_path = os.path.realpath(summary_log_path)

        return ckpt_file_path, summary_log_path

    @staticmethod
    def get_file_size(file_path):
        """
        Get the file size.

        Args:
            file_path (str): The file path.

        Returns:
            int, the file size.
        """
        try:
            return os.path.getsize(file_path)
        except (OSError, IOError) as error:
            error_msg = f"Error when getting model file size: {error}"
            log.error(error_msg)
            raise LineageGetModelFileError(error_msg)

    @staticmethod
    def get_model_size(ckpt_file_path):
        """
        Get the total size of the model checkpoint file.

        Args:
            ckpt_file_path (str): The checkpoint file path.

        Returns:
            int, the total file size.
        """
        if ckpt_file_path:
            ckpt_file_path = os.path.realpath(ckpt_file_path)
            ckpt_file_size = AnalyzeObject.get_file_size(ckpt_file_path)
        else:
            ckpt_file_size = 0
        return ckpt_file_size

    @staticmethod
    def get_network_args(run_context_args, train_lineage):
        """
        Get the parameters related to the network,
        such as the optimizer and loss function.

        Args:
            run_context_args (dict): It contains all information of the training job.
            train_lineage (dict): A dict containing lineage metadata.

        Returns:
            dict, the lineage metadata.
        """
        network = run_context_args.get('train_network')
        optimizer = run_context_args.get('optimizer')
        if not optimizer:
            optimizer = AnalyzeObject.get_optimizer_by_network(network)
        loss_fn = run_context_args.get('loss_fn')
        if not loss_fn:
            loss_fn = AnalyzeObject.get_loss_fn_by_network(network)
            loss = None
        else:
            loss = run_context_args.get('net_outputs')

        if loss:
            log.info('Calculating loss...')
            loss_numpy = loss.asnumpy()
            loss = float(np.atleast_1d(loss_numpy)[0])
            log.debug('loss: %s', loss)
            train_lineage[Metadata.loss] = loss
        else:
            train_lineage[Metadata.loss] = None

        # Analyze classname of optimizer, loss function and training network.
        train_lineage[Metadata.optimizer] = type(optimizer).__name__ \
            if optimizer else None
        train_lineage[Metadata.train_network] = AnalyzeObject.get_backbone_network(network)
        train_lineage[Metadata.loss_function] = type(loss_fn).__name__ \
            if loss_fn else None
        return train_lineage
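

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): wiring TrainLineage
# and EvalLineage into a single run, following the docstring examples above.
# It assumes MindSpore is installed and that `net`, `loss`, `opt`, `metrics`,
# `train_ds` and `valid_ds` are defined by the caller.
#
#   from mindspore.train import Model
#   from mindspore.train.callback import ModelCheckpoint, SummaryStep
#   from mindspore.train.summary import SummaryRecord
#   from mindinsight.lineagemgr import TrainLineage, EvalLineage
#
#   model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics)
#   summary_writer = SummaryRecord(log_dir='./summary')
#   train_callbacks = [ModelCheckpoint(directory='./ckpt'),
#                      SummaryStep(summary_writer, flush_step=2),
#                      TrainLineage(summary_record=summary_writer)]
#   model.train(10, train_ds, callbacks=train_callbacks)
#   model.eval(valid_ds, callbacks=[EvalLineage(summary_record=summary_writer)])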