Browse Source

!2616 Decide whether to collect data by dataset sink mode and current step in SummaryCollector

Merge pull request !2616 from ougongchang/fix_collect_freq
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
19f79cd744
1 changed file with 20 additions and 4 deletions
  1. +20
    -4
      mindspore/train/callback/_summary_collector.py

+ 20
- 4
mindspore/train/callback/_summary_collector.py View File

@@ -166,8 +166,11 @@ class SummaryCollector(Callback):
self._has_saved_custom_data = False self._has_saved_custom_data = False
self._is_parse_loss_success = True self._is_parse_loss_success = True
self._first_step = True self._first_step = True
self._dataset_sink_mode = True


def __enter__(self): def __enter__(self):
self._first_step = True
self._dataset_sink_mode = True
self._record = SummaryRecord(log_dir=self._summary_dir) self._record = SummaryRecord(log_dir=self._summary_dir)
return self return self


@@ -279,15 +282,15 @@ class SummaryCollector(Callback):


def step_end(self, run_context): def step_end(self, run_context):
cb_params = run_context.original_args() cb_params = run_context.original_args()
if self._first_step:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num)


if cb_params.mode == ModeEnum.TRAIN.value: if cb_params.mode == ModeEnum.TRAIN.value:


# Make sure the first step data is recorded
if not self._first_step and cb_params.cur_step_num % self._collect_freq:
if not self._is_collect_this_step(cb_params):
return return


self._first_step = False

if not self._has_saved_train_network: if not self._has_saved_train_network:
self._collect_graphs(cb_params) self._collect_graphs(cb_params)


@@ -295,6 +298,7 @@ class SummaryCollector(Callback):
self._collect_metric(cb_params) self._collect_metric(cb_params)
self._collect_histogram(cb_params) self._collect_histogram(cb_params)


self._first_step = False
self._record.record(cb_params.cur_step_num) self._record.record(cb_params.cur_step_num)


def end(self, run_context): def end(self, run_context):
@@ -320,6 +324,18 @@ class SummaryCollector(Callback):
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
f"but expected only one {self.__class__.__name__} instance.") f"but expected only one {self.__class__.__name__} instance.")


def _is_collect_this_step(self, cb_params):
"""Decide whether to collect data for the current step."""
# Make sure the first step data is recorded
if not self._first_step:
if self._dataset_sink_mode:
if cb_params.cur_epoch_num % self._collect_freq:
return False
else:
if cb_params.cur_step_num % self._collect_freq:
return False
return True

@staticmethod @staticmethod
def _package_custom_lineage_data(custom_lineage_data): def _package_custom_lineage_data(custom_lineage_data):
""" """


Loading…
Cancel
Save