From 33b5cda1da36ba3df3e5f06c8a271c274cc94c19 Mon Sep 17 00:00:00 2001 From: ougongchang Date: Sat, 27 Jun 2020 16:07:45 +0800 Subject: [PATCH] Decide whether to collect data by dataset sink mode and current step in SummaryCollector. Previously, we decided whether to collect data based only on the current step, which does not work well in dataset sink mode. Now we first check whether training is running in dataset sink mode and then decide whether to collect data: by the current epoch in dataset sink mode, and by the current step otherwise. --- .../train/callback/_summary_collector.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index 6d5ec45d5b..cff03ca398 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -166,8 +166,11 @@ class SummaryCollector(Callback): self._has_saved_custom_data = False self._is_parse_loss_success = True self._first_step = True + self._dataset_sink_mode = True def __enter__(self): + self._first_step = True + self._dataset_sink_mode = True self._record = SummaryRecord(log_dir=self._summary_dir) return self @@ -279,15 +282,15 @@ class SummaryCollector(Callback): def step_end(self, run_context): cb_params = run_context.original_args() + if self._first_step: + # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario + self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num) if cb_params.mode == ModeEnum.TRAIN.value: - # Make sure the first step data is recorded - if not self._first_step and cb_params.cur_step_num % self._collect_freq: + if not self._is_collect_this_step(cb_params): return - self._first_step = False - if not self._has_saved_train_network: self._collect_graphs(cb_params) @@ -295,6 +298,7 @@ class SummaryCollector(Callback): self._collect_metric(cb_params) self._collect_histogram(cb_params) + self._first_step = False self._record.record(cb_params.cur_step_num) def end(self, 
run_context): @@ -320,6 +324,18 @@ class SummaryCollector(Callback): raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," f"but expected only one {self.__class__.__name__} instance.") + def _is_collect_this_step(self, cb_params): + """Decide whether to collect data for the current step.""" + # Make sure the first step data is recorded + if not self._first_step: + if self._dataset_sink_mode: + if cb_params.cur_epoch_num % self._collect_freq: + return False + else: + if cb_params.cur_step_num % self._collect_freq: + return False + return True + @staticmethod def _package_custom_lineage_data(custom_lineage_data): """