|
|
|
@@ -79,6 +79,7 @@ except: |
|
|
|
from ..io.model_io import ModelSaver, ModelLoader |
|
|
|
from .dataset import DataSet |
|
|
|
from .tester import Tester |
|
|
|
import logging |
|
|
|
|
|
|
|
try: |
|
|
|
import fitlog |
|
|
|
@@ -167,7 +168,11 @@ class Callback(object): |
|
|
|
@property |
|
|
|
def disabled(self): |
|
|
|
return self._disabled |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def logger(self): |
|
|
|
return getattr(self._trainer, 'logger', logging) |
|
|
|
|
|
|
|
def on_train_begin(self): |
|
|
|
""" |
|
|
|
在Train过程开始之前调用。 |
|
|
|
@@ -316,21 +321,27 @@ class CallbackManager(Callback): |
|
|
|
""" |
|
|
|
super(CallbackManager, self).__init__() |
|
|
|
# set attribute of trainer environment |
|
|
|
|
|
|
|
self._env = env |
|
|
|
self.callbacks = [] |
|
|
|
if callbacks is not None: |
|
|
|
if isinstance(callbacks, list): |
|
|
|
if all([isinstance(cb, Callback) for cb in callbacks]) is True: |
|
|
|
self.callbacks.extend(callbacks) |
|
|
|
else: |
|
|
|
obj = [not isinstance(cb, Callback) for cb in callbacks][0] |
|
|
|
raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") |
|
|
|
if callbacks: |
|
|
|
self.callbacks += self.prepare_callbacks(callbacks) |
|
|
|
|
|
|
|
def prepare_callbacks(self, callbacks): |
|
|
|
if not callbacks: |
|
|
|
return [] |
|
|
|
if isinstance(callbacks, list): |
|
|
|
if all([isinstance(cb, Callback) for cb in callbacks]) is True: |
|
|
|
self.callbacks.extend(callbacks) |
|
|
|
else: |
|
|
|
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") |
|
|
|
|
|
|
|
for env_name, env_val in env.items(): |
|
|
|
for callback in self.callbacks: |
|
|
|
obj = [not isinstance(cb, Callback) for cb in callbacks][0] |
|
|
|
raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") |
|
|
|
else: |
|
|
|
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") |
|
|
|
|
|
|
|
for env_name, env_val in self._env.items(): |
|
|
|
for callback in callbacks: |
|
|
|
setattr(callback, '_' + env_name, env_val) # Callback.trainer |
|
|
|
return callbacks |
|
|
|
|
|
|
|
@_transfer |
|
|
|
def on_train_begin(self): |
|
|
|
@@ -391,11 +402,12 @@ class CallbackManager(Callback): |
|
|
|
|
|
|
|
class DistCallbackManager(CallbackManager): |
|
|
|
def __init__(self, env, callbacks_all=None, callbacks_master=None): |
|
|
|
super(DistCallbackManager, self).__init__(env) |
|
|
|
assert 'trainer' in env |
|
|
|
is_master = env['trainer'].is_master |
|
|
|
self.patch_callback(callbacks_master, disabled=not is_master) |
|
|
|
self.callbacks_all = CallbackManager(env, callbacks_all).callbacks |
|
|
|
self.callbacks_master = CallbackManager(env, callbacks_master).callbacks |
|
|
|
self.callbacks_all = self.prepare_callbacks(callbacks_all) |
|
|
|
self.callbacks_master = self.prepare_callbacks(callbacks_master) |
|
|
|
self.callbacks = self.callbacks_all + self.callbacks_master |
|
|
|
|
|
|
|
def patch_callback(self, callbacks, disabled): |
|
|
|
@@ -944,5 +956,21 @@ class EchoCallback(Callback): |
|
|
|
|
|
|
|
|
|
|
|
class TesterCallback(Callback): |
|
|
|
def __init__(self, data, model, metrics, batch_size=16, num_workers=None): |
|
|
|
self.tester = Tester(data, model) |
|
|
|
def __init__(self, data, model, metrics, batch_size=16, num_workers=None):\ |
|
|
|
#TODO add compare & save best |
|
|
|
super(TesterCallback, self).__init__() |
|
|
|
self.tester = Tester(data, model, |
|
|
|
metrics=metrics, batch_size=batch_size, |
|
|
|
num_workers=num_workers, verbose=0) |
|
|
|
self.score = None |
|
|
|
|
|
|
|
def on_validation(self): |
|
|
|
cur_socre = self.tester.test() |
|
|
|
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format( |
|
|
|
self.epoch, self.n_epochs, self.step, self.n_steps, |
|
|
|
self.tester._format_eval_results(cur_socre)) |
|
|
|
self.logger.info(eval_str) |
|
|
|
|
|
|
|
def on_train_end(self): |
|
|
|
self.logger.info('Evaluate on training ends.') |
|
|
|
self.on_validation() |