diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py
index c76e27d699..4e18c7f2ca 100644
--- a/mindspore/train/callback/_summary_collector.py
+++ b/mindspore/train/callback/_summary_collector.py
@@ -182,6 +182,7 @@ class SummaryCollector(Callback):
         self._custom_lineage_data = custom_lineage_data

         self._temp_optimizer = None
+        self._has_saved_train_network = False
         self._has_saved_custom_data = False
         self._is_parse_loss_success = True
         self._first_step = True
@@ -215,7 +216,7 @@ class SummaryCollector(Callback):
     @staticmethod
     def _check_positive(name, value, allow_none=False):
         """Check if the value to be int type and positive."""
-        if allow_none:
+        if allow_none and value is None:
             return
         check_value_type(name, value, int)
         if value <= 0:
@@ -294,8 +295,9 @@ class SummaryCollector(Callback):
             self._collect_dataset_graph(cb_params)

             if self._collect_tensor_freq is None:
+                default_tensor_summary_limit = 50
                 total_step = cb_params.epoch_num * cb_params.batch_num
-                self._collect_tensor_freq = max(self._collect_freq, total_step // 50)
+                self._collect_tensor_freq = max(self._collect_freq, total_step // default_tensor_summary_limit)

         if self._custom_lineage_data and not self._has_saved_custom_data:
             packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
@@ -309,6 +311,8 @@ class SummaryCollector(Callback):
         cb_params = run_context.original_args()
         if cb_params.mode != ModeEnum.TRAIN.value:
             return
+        if not self._has_saved_train_network:
+            self._collect_graphs(cb_params)
         if self._first_step:
             # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
             self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num
@@ -424,6 +428,7 @@ class SummaryCollector(Callback):
         if graph_proto is None:
             return

+        self._has_saved_train_network = True
         self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)

     def _collect_metric(self, cb_params):
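
Notes on the change (editor's summary, not part of the patch):

- `_check_positive` bug fix: previously, `allow_none=True` skipped type and positivity validation entirely, even when a non-None value was supplied. The fix short-circuits only when the value is actually `None`.
- Graph collection: `begin` can run before the train network is compiled, so `_collect_graphs` may find no graph proto there. The new `_has_saved_train_network` flag lets `step_end` retry collection on each training step until the first successful save, and guarantees the graph is recorded at most once.
- The magic number `50` is lifted into `default_tensor_summary_limit`, documenting that the tensor collection frequency is chosen so roughly 50 tensor summaries at most are written over a full training run (`total_step // freq <= 50`).

Below is a minimal, self-contained sketch of the two behaviors. The names (`Recorder`, `check_positive`) are illustrative stand-ins, not MindSpore's internal APIs:

```python
class Recorder:
    """Illustrates the save-once retry pattern introduced by the patch."""

    def __init__(self):
        self._has_saved_train_network = False
        self.saved = []

    def _collect_graphs(self, graph_proto):
        # The graph may not be compiled yet; mark as saved only on success.
        if graph_proto is None:
            return
        self._has_saved_train_network = True
        self.saved.append(graph_proto)

    def step_end(self, graph_proto):
        # Retry every step until the first successful save.
        if not self._has_saved_train_network:
            self._collect_graphs(graph_proto)


def check_positive(name, value, allow_none=False):
    """Validate that `value` is a positive int; `None` is allowed only on request."""
    if allow_none and value is None:  # the fixed guard: skip only for an actual None
        return
    if not isinstance(value, int) or isinstance(value, bool):
        raise TypeError(f'`{name}` should be int, but got {type(value).__name__}.')
    if value <= 0:
        raise ValueError(f'`{name}` should be positive, but got {value}.')


recorder = Recorder()
for proto in [None, None, 'graph_proto', 'graph_proto']:
    recorder.step_end(proto)
assert recorder.saved == ['graph_proto']  # recorded exactly once

check_positive('freq', 10, allow_none=True)    # passes: positive int
check_positive('freq', None, allow_none=True)  # passes: None explicitly allowed
try:
    check_positive('freq', None)               # fails: None not allowed by default
except TypeError:
    pass
```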