|
|
|
@@ -282,9 +282,15 @@ class Model: |
|
|
|
scaling_sens /= self._device_number |
|
|
|
return scaling_sens |
|
|
|
|
|
|
|
def _exec_preprocess(self, network, is_train, phase, dataset, |
|
|
|
dataset_sink_mode, sink_size=-1, epoch_num=1, dataset_helper=None): |
|
|
|
def _exec_preprocess(self, is_train, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1, dataset_helper=None): |
|
|
|
"""Initializes dataset.""" |
|
|
|
if is_train: |
|
|
|
network = self._train_network |
|
|
|
phase = 'train' |
|
|
|
else: |
|
|
|
network = self._eval_network |
|
|
|
phase = 'eval' |
|
|
|
|
|
|
|
if dataset_sink_mode and not is_train: |
|
|
|
dataset.__loop_size__ = 1 |
|
|
|
|
|
|
|
@@ -330,9 +336,7 @@ class Model: |
|
|
|
self._train_network.set_broadcast_flag() |
|
|
|
|
|
|
|
train_dataset.__no_send__ = True |
|
|
|
train_dataset_helper, train_network = self._exec_preprocess(self._train_network, |
|
|
|
is_train=True, |
|
|
|
phase='train', |
|
|
|
train_dataset_helper, train_network = self._exec_preprocess(is_train=True, |
|
|
|
dataset=train_dataset, |
|
|
|
dataset_sink_mode=True, |
|
|
|
sink_size=sink_size) |
|
|
|
@@ -346,9 +350,7 @@ class Model: |
|
|
|
raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.') |
|
|
|
|
|
|
|
valid_dataset.__no_send__ = True |
|
|
|
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, |
|
|
|
is_train=False, |
|
|
|
phase='eval', |
|
|
|
valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False, |
|
|
|
dataset=valid_dataset, |
|
|
|
dataset_sink_mode=True) |
|
|
|
self._eval_network = eval_network |
|
|
|
@@ -456,9 +458,7 @@ class Model: |
|
|
|
for i in range(epoch): |
|
|
|
cb_params.cur_epoch_num = i + 1 |
|
|
|
list_callback.epoch_begin(run_context) |
|
|
|
dataset_helper, train_network = self._exec_preprocess(self._train_network, |
|
|
|
is_train=True, |
|
|
|
phase='train', |
|
|
|
dataset_helper, train_network = self._exec_preprocess(is_train=True, |
|
|
|
dataset=train_dataset, |
|
|
|
dataset_sink_mode=True, |
|
|
|
sink_size=sink_size, |
|
|
|
@@ -506,9 +506,7 @@ class Model: |
|
|
|
list_callback (Callback): Executor of callback list. Default: None. |
|
|
|
cb_params (_InternalCallbackParam): Callback parameters. Default: None. |
|
|
|
""" |
|
|
|
dataset_helper, _ = self._exec_preprocess(self._train_network, |
|
|
|
is_train=True, |
|
|
|
phase='train', |
|
|
|
dataset_helper, _ = self._exec_preprocess(is_train=True, |
|
|
|
dataset=train_dataset, |
|
|
|
dataset_sink_mode=False, |
|
|
|
epoch_num=epoch) |
|
|
|
@@ -640,9 +638,7 @@ class Model: |
|
|
|
""" |
|
|
|
run_context = RunContext(cb_params) |
|
|
|
|
|
|
|
dataset_helper, eval_network = self._exec_preprocess(self._eval_network, |
|
|
|
is_train=False, |
|
|
|
phase='eval', |
|
|
|
dataset_helper, eval_network = self._exec_preprocess(is_train=False, |
|
|
|
dataset=valid_dataset, |
|
|
|
dataset_sink_mode=True) |
|
|
|
self._eval_network = eval_network |
|
|
|
@@ -680,9 +676,7 @@ class Model: |
|
|
|
run_context = RunContext(cb_params) |
|
|
|
cb_params.dataset_sink_mode = False |
|
|
|
list_callback.begin(run_context) |
|
|
|
dataset_helper, _ = self._exec_preprocess(self._eval_network, |
|
|
|
is_train=False, |
|
|
|
phase='eval', |
|
|
|
dataset_helper, _ = self._exec_preprocess(is_train=False, |
|
|
|
dataset=valid_dataset, |
|
|
|
dataset_sink_mode=False) |
|
|
|
for next_element in dataset_helper: |
|
|
|
|