|
|
@@ -166,8 +166,11 @@ class SummaryCollector(Callback): |
|
|
self._has_saved_custom_data = False |
|
|
self._has_saved_custom_data = False |
|
|
self._is_parse_loss_success = True |
|
|
self._is_parse_loss_success = True |
|
|
self._first_step = True |
|
|
self._first_step = True |
|
|
|
|
|
self._dataset_sink_mode = True |
|
|
|
|
|
|
|
|
def __enter__(self): |
|
|
def __enter__(self): |
|
|
|
|
|
self._first_step = True |
|
|
|
|
|
self._dataset_sink_mode = True |
|
|
self._record = SummaryRecord(log_dir=self._summary_dir) |
|
|
self._record = SummaryRecord(log_dir=self._summary_dir) |
|
|
return self |
|
|
return self |
|
|
|
|
|
|
|
|
@@ -279,15 +282,15 @@ class SummaryCollector(Callback): |
|
|
|
|
|
|
|
|
def step_end(self, run_context): |
|
|
def step_end(self, run_context): |
|
|
cb_params = run_context.original_args() |
|
|
cb_params = run_context.original_args() |
|
|
|
|
|
if self._first_step: |
|
|
|
|
|
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario |
|
|
|
|
|
self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num) |
|
|
|
|
|
|
|
|
if cb_params.mode == ModeEnum.TRAIN.value: |
|
|
if cb_params.mode == ModeEnum.TRAIN.value: |
|
|
|
|
|
|
|
|
# Make sure the first step data is recorded |
|
|
|
|
|
if not self._first_step and cb_params.cur_step_num % self._collect_freq: |
|
|
|
|
|
|
|
|
if not self._is_collect_this_step(cb_params): |
|
|
return |
|
|
return |
|
|
|
|
|
|
|
|
self._first_step = False |
|
|
|
|
|
|
|
|
|
|
|
if not self._has_saved_train_network: |
|
|
if not self._has_saved_train_network: |
|
|
self._collect_graphs(cb_params) |
|
|
self._collect_graphs(cb_params) |
|
|
|
|
|
|
|
|
@@ -295,6 +298,7 @@ class SummaryCollector(Callback): |
|
|
self._collect_metric(cb_params) |
|
|
self._collect_metric(cb_params) |
|
|
self._collect_histogram(cb_params) |
|
|
self._collect_histogram(cb_params) |
|
|
|
|
|
|
|
|
|
|
|
self._first_step = False |
|
|
self._record.record(cb_params.cur_step_num) |
|
|
self._record.record(cb_params.cur_step_num) |
|
|
|
|
|
|
|
|
def end(self, run_context): |
|
|
def end(self, run_context): |
|
|
@@ -320,6 +324,18 @@ class SummaryCollector(Callback): |
|
|
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," |
|
|
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," |
|
|
f"but expected only one {self.__class__.__name__} instance.") |
|
|
f"but expected only one {self.__class__.__name__} instance.") |
|
|
|
|
|
|
|
|
|
|
|
def _is_collect_this_step(self, cb_params): |
|
|
|
|
|
"""Decide whether to collect data for the current step.""" |
|
|
|
|
|
# Make sure the first step data is recorded |
|
|
|
|
|
if not self._first_step: |
|
|
|
|
|
if self._dataset_sink_mode: |
|
|
|
|
|
if cb_params.cur_epoch_num % self._collect_freq: |
|
|
|
|
|
return False |
|
|
|
|
|
else: |
|
|
|
|
|
if cb_params.cur_step_num % self._collect_freq: |
|
|
|
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
@staticmethod |
|
|
def _package_custom_lineage_data(custom_lineage_data): |
|
|
def _package_custom_lineage_data(custom_lineage_data): |
|
|
""" |
|
|
""" |
|
|
|