|
|
|
@@ -67,7 +67,7 @@ class SummaryCollector(Callback): |
|
|
|
SummaryCollector can help you to collect some common information. |
|
|
|
|
|
|
|
It can help you to collect loss, learning late, computational graph and so on. |
|
|
|
SummaryCollector also enables the summary operator to collect data from a summary file. |
|
|
|
SummaryCollector also enables the summary operator to collect data to summary files. |
|
|
|
|
|
|
|
Note: |
|
|
|
1. Multiple SummaryCollector instances in callback list are not allowed. |
|
|
|
@@ -367,6 +367,7 @@ class SummaryCollector(Callback): |
|
|
|
'but got `{cb_params.mode}` mode.') |
|
|
|
|
|
|
|
self._record.set_mode(cb_params.mode) |
|
|
|
self._dataset_sink_mode = cb_params.dataset_sink_mode |
|
|
|
|
|
|
|
def step_end(self, run_context): |
|
|
|
cb_params = run_context.original_args() |
|
|
|
@@ -386,8 +387,6 @@ class SummaryCollector(Callback): |
|
|
|
self._record.record(cb_params.cur_step_num) |
|
|
|
|
|
|
|
if self._first_step: |
|
|
|
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario |
|
|
|
self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num |
|
|
|
self._tensor_collect_range = self._get_tensor_collect_range(cb_params, self._dataset_sink_mode) |
|
|
|
self._collect_at_step_end(cb_params, plugin_filter=None) |
|
|
|
self._first_step = False |
|
|
|
@@ -480,34 +479,44 @@ class SummaryCollector(Callback): |
|
|
|
|
|
|
|
def _collect_input_data(self, cb_params): |
|
|
|
"""Only support to collect image data.""" |
|
|
|
if not self._collect_specified_data.get('collect_input_data'): |
|
|
|
if not self._is_allowed_to_collect_input_data(cb_params): |
|
|
|
return |
|
|
|
|
|
|
|
input_data = getattr(cb_params, 'train_dataset_element', None) |
|
|
|
if isinstance(input_data, (list, tuple)) and input_data: |
|
|
|
input_data = input_data[0] |
|
|
|
try: |
|
|
|
self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data) |
|
|
|
except (TypeError, ValueError): |
|
|
|
logger.warning('The input data of network are not image, so will not collect by SummaryCollector.') |
|
|
|
self._collect_specified_data['collect_input_data'] = False |
|
|
|
return |
|
|
|
|
|
|
|
def _is_allowed_to_collect_input_data(self, cb_params): |
|
|
|
"""Check if the input data is allowed to be collected.""" |
|
|
|
if not self._collect_specified_data.get('collect_input_data'): |
|
|
|
return False |
|
|
|
|
|
|
|
if self._dataset_sink_mode and (context.get_context('device_target') in ('Ascend', 'GPU')): |
|
|
|
logger.warning("On Ascend or GPU device, SummaryCollector is not supported to " |
|
|
|
"record input data in dataset sink mode.") |
|
|
|
self._collect_specified_data['collect_input_data'] = False |
|
|
|
return False |
|
|
|
|
|
|
|
input_data = getattr(cb_params, 'train_dataset_element', None) |
|
|
|
if not isinstance(input_data, (Tensor, list, tuple)): |
|
|
|
self._collect_specified_data['collect_input_data'] = False |
|
|
|
logger.warning("The type of input data is not Tensor/list/tuple, " |
|
|
|
"so SummaryCollector will not collect input data.") |
|
|
|
return |
|
|
|
return False |
|
|
|
|
|
|
|
if not isinstance(input_data, Tensor) and not input_data: |
|
|
|
self._collect_specified_data['collect_input_data'] = False |
|
|
|
logger.warning("The 'train_dataset_element' in cb_params is empty, " |
|
|
|
"so SummaryCollector will not record the input data.") |
|
|
|
"so SummaryCollector will not record the input data. ") |
|
|
|
return False |
|
|
|
|
|
|
|
if self._dataset_sink_mode and context.get_context('device_target') == 'Ascend': |
|
|
|
logger.warning('On Ascend device, SummaryCollector is not supported to record input data ' |
|
|
|
'in dataset sink mode.') |
|
|
|
return |
|
|
|
|
|
|
|
if isinstance(input_data, (list, tuple)) and input_data: |
|
|
|
input_data = input_data[0] |
|
|
|
try: |
|
|
|
self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data) |
|
|
|
except (TypeError, ValueError): |
|
|
|
logger.warning('The input data of network are not image, so will not collect by SummaryCollector.') |
|
|
|
self._collect_specified_data['collect_input_data'] = False |
|
|
|
return |
|
|
|
return True |
|
|
|
|
|
|
|
def _collect_dataset_graph(self, cb_params): |
|
|
|
"""Only collect train dataset graph.""" |
|
|
|
|