* 给callback添加给定几个只读属性 * 通过manager设置这些属性 * 代码优化,减轻@transfer的负担tags/v0.4.10
| @@ -17,6 +17,38 @@ class Callback(object): | |||
| super(Callback, self).__init__() | |||
| self.trainer = None # 在Trainer内部被重新赋值 | |||
| # callback只读属性 | |||
| self._n_epochs = None | |||
| self._n_steps = None | |||
| self._batch_size = None | |||
| self._model = None | |||
| self._pbar = None | |||
| self._optimizer = None | |||
| @property | |||
| def n_epochs(self): | |||
| return self._n_epochs | |||
| @property | |||
| def n_steps(self): | |||
| return self._n_steps | |||
| @property | |||
| def batch_size(self): | |||
| return self._batch_size | |||
| @property | |||
| def model(self): | |||
| return self._model | |||
| @property | |||
| def pbar(self): | |||
| return self._pbar | |||
| @property | |||
| def optimizer(self): | |||
| return self._optimizer | |||
| def on_train_begin(self): | |||
| # before the main training loop | |||
| pass | |||
| @@ -101,8 +133,6 @@ def transfer(func): | |||
| def wrapper(manager, *arg): | |||
| returns = [] | |||
| for callback in manager.callbacks: | |||
| for env_name, env_value in manager.env.items(): | |||
| setattr(callback, env_name, env_value) | |||
| returns.append(getattr(callback, func.__name__)(*arg)) | |||
| return returns | |||
| @@ -115,15 +145,15 @@ class CallbackManager(Callback): | |||
| """ | |||
| def __init__(self, env, callbacks=None): | |||
| def __init__(self, env, attr, callbacks=None): | |||
| """ | |||
| :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | |||
| :param dict attr: read-only attributes for all callbacks | |||
| :param Callback callbacks: | |||
| """ | |||
| super(CallbackManager, self).__init__() | |||
| # set attribute of trainer environment | |||
| self.env = env | |||
| self.callbacks = [] | |||
| if callbacks is not None: | |||
| @@ -136,6 +166,23 @@ class CallbackManager(Callback): | |||
| else: | |||
| raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | |||
| for env_name, env_val in env.items(): | |||
| for callback in self.callbacks: | |||
| setattr(callback, env_name, env_val) # Callback.trainer | |||
| self.set_property(**attr) | |||
| def set_property(self, **kwargs): | |||
| """设置所有callback的只读属性 | |||
| :param kwargs: | |||
| :return: | |||
| """ | |||
| for callback in self.callbacks: | |||
| for k, v in kwargs.items(): | |||
| setattr(callback, "_" + k, v) | |||
| @transfer | |||
| def on_train_begin(self): | |||
| pass | |||
| @@ -121,7 +121,6 @@ class Trainer(object): | |||
| self.best_dev_perf = None | |||
| self.sampler = sampler if sampler is not None else RandomSampler() | |||
| self.prefetch = prefetch | |||
| self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | |||
| if isinstance(optimizer, torch.optim.Optimizer): | |||
| self.optimizer = optimizer | |||
| @@ -144,6 +143,12 @@ class Trainer(object): | |||
| self.step = 0 | |||
| self.start_time = None # start timestamp | |||
| self.callback_manager = CallbackManager(env={"trainer": self}, | |||
| attr={"n_epochs": self.n_epochs, "n_steps": self.step, | |||
| "batch_size": self.batch_size, "model": self.model, | |||
| "optimizer": self.optimizer}, | |||
| callbacks=callbacks) | |||
| def train(self, load_best_model=True): | |||
| """ | |||
| @@ -236,6 +241,7 @@ class Trainer(object): | |||
| avg_loss = 0 | |||
| data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||
| prefetch=self.prefetch) | |||
| self.callback_manager.set_property(pbar=pbar) | |||
| for epoch in range(1, self.n_epochs+1): | |||
| pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||
| # early stopping | |||
| @@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase): | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[TensorboardCallback("loss", "metric")]) | |||
| trainer.train() | |||
| def test_readonly_property(self): | |||
| from fastNLP.core.callback import Callback | |||
| class MyCallback(Callback): | |||
| def __init__(self): | |||
| super(MyCallback, self).__init__() | |||
| def on_epoch_begin(self, cur_epoch, total_epoch): | |||
| print(self.n_epochs, self.n_steps, self.batch_size) | |||
| print(self.model) | |||
| print(self.optimizer) | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| n_epochs=5, | |||
| batch_size=32, | |||
| print_every=50, | |||
| optimizer=SGD(lr=0.1), | |||
| check_code_level=2, | |||
| use_tqdm=False, | |||
| dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[MyCallback()]) | |||
| trainer.train() | |||