
Move graph and dataset graph collection to the step_end stage

Previously, we collected the graph and the dataset graph in the begin stage.
If graph compilation failed on GPU, the graphs had already been collected
into the summary dir, which could confuse users.

We now collect the graph and the dataset graph in the step_end stage, so if
graph compilation fails, neither the graph nor the dataset graph is collected.
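
The step_end approach relies on a one-shot flag: nothing is written until the
first training step completes, which can only happen after graph compilation
succeeds. Below is a minimal sketch of that pattern; `GraphCollector`,
`PrintRecord`, and their methods are simplified stand-ins invented for
illustration, not the real MindSpore callback or summary-record APIs.

```python
class PrintRecord:
    """Stand-in recorder that just prints what would be written."""

    def add_value(self, plugin, tag, value):
        print(f'add {plugin}:{tag} -> {value!r}')

    def record(self, step):
        print(f'flush at step {step}')


class GraphCollector:
    """Defers graph collection until the first completed step.

    If graph compilation fails, step_end() is never reached, so nothing
    is written to the summary directory.
    """

    def __init__(self, record):
        self._record = record
        self._has_saved_graph = False

    def step_end(self, cur_step_num, graph_proto):
        # Collect exactly once, and only after a step has actually run,
        # which implies graph compilation succeeded.
        if not self._has_saved_graph and graph_proto is not None:
            self._record.add_value('graph', 'train_network/auto', graph_proto)
            self._has_saved_graph = True
            self._record.record(cur_step_num)


collector = GraphCollector(PrintRecord())
collector.step_end(cur_step_num=1, graph_proto='<serialized graph>')
collector.step_end(cur_step_num=2, graph_proto='<serialized graph>')  # no-op: already saved
```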
tags/v0.7.0-beta
ougongchang, 5 years ago · commit 1dafb2c6f5

1 changed file with 14 additions and 19 deletions:
mindspore/train/callback/_summary_collector.py (+14, −19)
@@ -182,7 +182,7 @@ class SummaryCollector(Callback):
         self._custom_lineage_data = custom_lineage_data

         self._temp_optimizer = None
-        self._has_saved_train_network = False
+        self._has_saved_graph = False
         self._has_saved_custom_data = False
         self._is_parse_loss_success = True
         self._first_step = True
@@ -287,32 +287,30 @@ class SummaryCollector(Callback):
                              'but got `{cb_params.mode}` mode.')

         self._record.set_mode(cb_params.mode)
-        if cb_params.mode == ModeEnum.TRAIN.value:
-            # Note: if model.init is not executed then the computed graph will not be obtained here
-            # The purpose of recording the graph here was to collect_freq if it was set to a large size,
-            # but also want to see the graph as soon after compilation.
-            self._collect_graphs(cb_params)
-
-            self._collect_dataset_graph(cb_params)
+        if cb_params.mode == ModeEnum.TRAIN.value:
             if self._collect_tensor_freq is None:
                 default_tensor_summary_limit = 20
                 total_step = cb_params.epoch_num * cb_params.batch_num
                 self._collect_tensor_freq = max(self._collect_freq, total_step // default_tensor_summary_limit)

-        if self._custom_lineage_data and not self._has_saved_custom_data:
-            packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
-            self._record.add_value('custom_lineage_data', 'custom_lineage_data', packaged_custom_data)
-            self._has_saved_custom_data = True
-
-        # There's nothing special about setting step to 0 here, just to satisfy the interface call
-        self._record.record(step=0)

     def step_end(self, run_context):
         cb_params = run_context.original_args()
         if cb_params.mode != ModeEnum.TRAIN.value:
             return
-        if not self._has_saved_train_network:
+
+        if not self._has_saved_graph:
             self._collect_graphs(cb_params)
+            self._collect_dataset_graph(cb_params)
+            self._has_saved_graph = True
+            self._record.record(cb_params.cur_step_num)
+
+        if self._custom_lineage_data and not self._has_saved_custom_data:
+            packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
+            self._record.add_value('custom_lineage_data', 'custom_lineage_data', packaged_custom_data)
+            self._has_saved_custom_data = True
+            self._record.record(cb_params.cur_step_num)

         if self._first_step:
             # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
             self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num
@@ -327,14 +325,12 @@ class SummaryCollector(Callback):
         elif current % self._collect_freq == 0:
             self._collect_at_step_end(cb_params, lambda plugin: plugin != PluginEnum.TENSOR.value)

-
     def _collect_at_step_end(self, cb_params, plugin_filter):
         self._collect_input_data(cb_params)
         self._collect_metric(cb_params)
         self._collect_histogram(cb_params)
         self._record.record(cb_params.cur_step_num, plugin_filter=plugin_filter)

-
     def end(self, run_context):
         cb_params = run_context.original_args()
         if cb_params.mode == ModeEnum.TRAIN.value:
@@ -428,7 +424,6 @@ class SummaryCollector(Callback):
         if graph_proto is None:
             return

-        self._has_saved_train_network = True
         self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)

     def _collect_metric(self, cb_params):
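
For context, `SummaryCollector` is attached to training as a callback. A
typical invocation is sketched below; `model`, `ds_train`, and the directory
path are placeholders for the user's own setup, not part of this commit.
After this change, the graph and dataset graph only appear under
`summary_dir` once the first step finishes, i.e. after compilation has
succeeded.

```python
from mindspore.train.callback import SummaryCollector

# './summary_dir', `model`, and `ds_train` are placeholders for the
# user's own setup.
summary_collector = SummaryCollector(summary_dir='./summary_dir')

# With this commit, the graph and dataset graph are written at the end of
# the first step instead of at the beginning of training.
model.train(10, ds_train, callbacks=[summary_collector])
```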

