Browse Source

Fix failure to collect the BERT network name in MindInsight lineage.

1. Collect the original network in Model and set it on cb_params
2. Collect the original network name in SummaryCollector
3. Update the SummaryCollector API doc
tags/v0.6.0-beta
ougongchang 5 years ago
parent
commit
336fca14bc
2 changed files with 5 additions and 26 deletions
  1. +3
    -26
      mindspore/train/callback/_summary_collector.py
  2. +2
    -0
      mindspore/train/model.py

+ 3
- 26
mindspore/train/callback/_summary_collector.py View File

@@ -73,7 +73,8 @@ class SummaryCollector(Callback):
summary_dir (str): The collected data will be persisted to this directory. summary_dir (str): The collected data will be persisted to this directory.
If the directory does not exist, it will be created automatically. If the directory does not exist, it will be created automatically.
collect_freq (int): Set the frequency of data collection, it should be greater than zero, collect_freq (int): Set the frequency of data collection, it should be greater than zero,
and the unit is `step`. Default: 10. The first step will be recorded at any time.
and the unit is `step`. Default: 10. If a frequency is set, we will collect data
at (current steps % freq) == 0, and the first step will be collected at any time.
It is important to note that if the data sink mode is used, the unit will become the `epoch`. It is important to note that if the data sink mode is used, the unit will become the `epoch`.
It is not recommended to collect data too frequently, which can affect performance. It is not recommended to collect data too frequently, which can affect performance.
collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None. collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
@@ -593,7 +594,7 @@ class SummaryCollector(Callback):
else: else:
train_lineage[LineageMetadata.learning_rate] = None train_lineage[LineageMetadata.learning_rate] = None
train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None
train_lineage[LineageMetadata.train_network] = self._get_backbone(cb_params.train_network)
train_lineage[LineageMetadata.train_network] = type(cb_params.network).__name__


loss_fn = self._get_loss_fn(cb_params) loss_fn = self._get_loss_fn(cb_params)
train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None
@@ -750,30 +751,6 @@ class SummaryCollector(Callback):


return ckpt_file_path return ckpt_file_path


@staticmethod
def _get_backbone(network):
"""
Get the name of backbone network.

Args:
network (Cell): The train network.

Returns:
Union[str, None], If parse success, will return the name of the backbone network, else return None.
"""
backbone_name = None
backbone_key = '_backbone'

for _, cell in network.cells_and_names():
if hasattr(cell, backbone_key):
backbone_network = getattr(cell, backbone_key)
backbone_name = type(backbone_network).__name__

if backbone_name is None and network is not None:
backbone_name = type(network).__name__

return backbone_name

@staticmethod @staticmethod
def _get_loss_fn(cb_params): def _get_loss_fn(cb_params):
""" """


+ 2
- 0
mindspore/train/model.py View File

@@ -355,6 +355,7 @@ class Model:
cb_params.train_dataset = train_dataset cb_params.train_dataset = train_dataset
cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.train_dataset_element = None cb_params.train_dataset_element = None
cb_params.network = self._network
ms_role = os.getenv("MS_ROLE") ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"): if ms_role in ("MS_PSERVER", "MS_SCHED"):
epoch = 1 epoch = 1
@@ -660,6 +661,7 @@ class Model:
cb_params.mode = "eval" cb_params.mode = "eval"
cb_params.cur_step_num = 0 cb_params.cur_step_num = 0
cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.network = self._network


self._eval_network.set_train(mode=False) self._eval_network.set_train(mode=False)
self._eval_network.phase = 'eval' self._eval_network.phase = 'eval'


Loading…
Cancel
Save