| @@ -79,6 +79,7 @@ except: | |||||
| from ..io.model_io import ModelSaver, ModelLoader | from ..io.model_io import ModelSaver, ModelLoader | ||||
| from .dataset import DataSet | from .dataset import DataSet | ||||
| from .tester import Tester | from .tester import Tester | ||||
| import logging | |||||
| try: | try: | ||||
| import fitlog | import fitlog | ||||
| @@ -167,7 +168,11 @@ class Callback(object): | |||||
| @property | @property | ||||
| def disabled(self): | def disabled(self): | ||||
| return self._disabled | return self._disabled | ||||
| @property | |||||
| def logger(self): | |||||
| return getattr(self._trainer, 'logger', logging) | |||||
| def on_train_begin(self): | def on_train_begin(self): | ||||
| """ | """ | ||||
| 在Train过程开始之前调用。 | 在Train过程开始之前调用。 | ||||
| @@ -316,21 +321,27 @@ class CallbackManager(Callback): | |||||
| """ | """ | ||||
| super(CallbackManager, self).__init__() | super(CallbackManager, self).__init__() | ||||
| # set attribute of trainer environment | # set attribute of trainer environment | ||||
| self._env = env | |||||
| self.callbacks = [] | self.callbacks = [] | ||||
| if callbacks is not None: | |||||
| if isinstance(callbacks, list): | |||||
| if all([isinstance(cb, Callback) for cb in callbacks]) is True: | |||||
| self.callbacks.extend(callbacks) | |||||
| else: | |||||
| obj = [not isinstance(cb, Callback) for cb in callbacks][0] | |||||
| raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") | |||||
| if callbacks: | |||||
| self.callbacks += self.prepare_callbacks(callbacks) | |||||
| def prepare_callbacks(self, callbacks): | |||||
| if not callbacks: | |||||
| return [] | |||||
| if isinstance(callbacks, list): | |||||
| if all([isinstance(cb, Callback) for cb in callbacks]) is True: | |||||
| self.callbacks.extend(callbacks) | |||||
| else: | else: | ||||
| raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | |||||
| for env_name, env_val in env.items(): | |||||
| for callback in self.callbacks: | |||||
| obj = [cb for cb in callbacks if not isinstance(cb, Callback)][0] | |||||
| raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") | |||||
| else: | |||||
| raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | |||||
| for env_name, env_val in self._env.items(): | |||||
| for callback in callbacks: | |||||
| setattr(callback, '_' + env_name, env_val) # Callback.trainer | setattr(callback, '_' + env_name, env_val) # Callback.trainer | ||||
| return callbacks | |||||
| @_transfer | @_transfer | ||||
| def on_train_begin(self): | def on_train_begin(self): | ||||
| @@ -391,11 +402,12 @@ class CallbackManager(Callback): | |||||
| class DistCallbackManager(CallbackManager): | class DistCallbackManager(CallbackManager): | ||||
| def __init__(self, env, callbacks_all=None, callbacks_master=None): | def __init__(self, env, callbacks_all=None, callbacks_master=None): | ||||
| super(DistCallbackManager, self).__init__(env) | |||||
| assert 'trainer' in env | assert 'trainer' in env | ||||
| is_master = env['trainer'].is_master | is_master = env['trainer'].is_master | ||||
| self.patch_callback(callbacks_master, disabled=not is_master) | self.patch_callback(callbacks_master, disabled=not is_master) | ||||
| self.callbacks_all = CallbackManager(env, callbacks_all).callbacks | |||||
| self.callbacks_master = CallbackManager(env, callbacks_master).callbacks | |||||
| self.callbacks_all = self.prepare_callbacks(callbacks_all) | |||||
| self.callbacks_master = self.prepare_callbacks(callbacks_master) | |||||
| self.callbacks = self.callbacks_all + self.callbacks_master | self.callbacks = self.callbacks_all + self.callbacks_master | ||||
| def patch_callback(self, callbacks, disabled): | def patch_callback(self, callbacks, disabled): | ||||
| @@ -944,5 +956,21 @@ class EchoCallback(Callback): | |||||
| class TesterCallback(Callback): | class TesterCallback(Callback): | ||||
| def __init__(self, data, model, metrics, batch_size=16, num_workers=None): | |||||
| self.tester = Tester(data, model) | |||||
| def __init__(self, data, model, metrics, batch_size=16, num_workers=None): | |||||
| #TODO add compare & save best | |||||
| super(TesterCallback, self).__init__() | |||||
| self.tester = Tester(data, model, | |||||
| metrics=metrics, batch_size=batch_size, | |||||
| num_workers=num_workers, verbose=0) | |||||
| self.score = None | |||||
| def on_validation(self): | |||||
| cur_score = self.tester.test() | |||||
| eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format( | |||||
| self.epoch, self.n_epochs, self.step, self.n_steps, | |||||
| self.tester._format_eval_results(cur_score)) | |||||
| self.logger.info(eval_str) | |||||
| def on_train_end(self): | |||||
| self.logger.info('Evaluate on training ends.') | |||||
| self.on_validation() | |||||
| @@ -11,7 +11,7 @@ import time | |||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||
| from .batch import DataSetIter, BatchIter | from .batch import DataSetIter, BatchIter | ||||
| from .callback import DistCallbackManager, CallbackException | |||||
| from .callback import DistCallbackManager, CallbackException, TesterCallback | |||||
| from .dataset import DataSet | from .dataset import DataSet | ||||
| from .losses import _prepare_losser | from .losses import _prepare_losser | ||||
| from .optimizer import Optimizer | from .optimizer import Optimizer | ||||
| @@ -39,10 +39,13 @@ def get_local_rank(): | |||||
| class DistTrainer(): | class DistTrainer(): | ||||
| """Distributed Trainer that support distributed and mixed precision training | |||||
| """ | |||||
| def __init__(self, train_data, model, optimizer=None, loss=None, | def __init__(self, train_data, model, optimizer=None, loss=None, | ||||
| callbacks_all=None, callbacks_master=None, | callbacks_all=None, callbacks_master=None, | ||||
| batch_size_per_gpu=8, n_epochs=1, | batch_size_per_gpu=8, n_epochs=1, | ||||
| num_data_workers=1, drop_last=False, | num_data_workers=1, drop_last=False, | ||||
| dev_data=None, metrics=None, | |||||
| update_every=1, print_every=10, validate_every=-1, | update_every=1, print_every=10, validate_every=-1, | ||||
| save_every=-1, save_path=None, device='auto', | save_every=-1, save_path=None, device='auto', | ||||
| fp16='', backend=None, init_method=None): | fp16='', backend=None, init_method=None): | ||||
| @@ -107,6 +110,14 @@ class DistTrainer(): | |||||
| self.data_iterator = self._get_data_iter(self.train_data) | self.data_iterator = self._get_data_iter(self.train_data) | ||||
| self.n_steps = self._get_n_steps() | self.n_steps = self._get_n_steps() | ||||
| # for evaluation, only run eval on master proc | |||||
| if dev_data and metrics: | |||||
| cb = TesterCallback( | |||||
| dev_data, model, metrics, | |||||
| batch_size=batch_size_per_gpu, num_workers=num_data_workers) | |||||
| self.callback_manager.callbacks_master += \ | |||||
| self.callback_manager.prepare_callbacks([cb]) | |||||
| # Setup logging | # Setup logging | ||||
| dist.barrier() | dist.barrier() | ||||
| self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') | self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') | ||||
| @@ -261,9 +272,6 @@ class DistTrainer(): | |||||
| if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | ||||
| (self.validate_every < 0 and self.step % len(data_iterator) == 0)): | (self.validate_every < 0 and self.step % len(data_iterator) == 0)): | ||||
| eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||||
| self.n_steps) | |||||
| self.logger.info(eval_str) | |||||
| self.callback_manager.on_validation() | self.callback_manager.on_validation() | ||||
| dist.barrier() | dist.barrier() | ||||
| @@ -13,6 +13,7 @@ import os | |||||
| import subprocess | import subprocess | ||||
| from argparse import ArgumentParser | from argparse import ArgumentParser | ||||
| from fastNLP.core.callback import EchoCallback | from fastNLP.core.callback import EchoCallback | ||||
| from fastNLP import AccuracyMetric | |||||
| def prepare_fake_dataset(): | def prepare_fake_dataset(): | ||||
| mean = np.array([-3, -3]) | mean = np.array([-3, -3]) | ||||
| @@ -106,15 +107,36 @@ class TestDistTrainer(unittest.TestCase): | |||||
| shutil.rmtree(self.save_path) | shutil.rmtree(self.save_path) | ||||
| def run3(self): | def run3(self): | ||||
| set_rng_seed(100) | |||||
| data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
| trainer = DistTrainer( | trainer = DistTrainer( | ||||
| data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"), | |||||
| data_set, model, optimizer=None, | |||||
| loss=BCELoss(pred="predict", target="y"), | |||||
| n_epochs=3, print_every=50, | n_epochs=3, print_every=50, | ||||
| callbacks_all=[EchoCallback('callbacks_all')], | callbacks_all=[EchoCallback('callbacks_all')], | ||||
| callbacks_master=[EchoCallback('callbacks_master')] | callbacks_master=[EchoCallback('callbacks_master')] | ||||
| ) | ) | ||||
| trainer.train() | trainer.train() | ||||
| def run4(self): | |||||
| set_rng_seed(100) | |||||
| data_set, model = prepare_env() | |||||
| train_set, dev_set = data_set.split(0.3) | |||||
| model = NaiveClassifier(2, 1) | |||||
| trainer = DistTrainer( | |||||
| train_set, model, optimizer=SGD(lr=0.1), | |||||
| loss=BCELoss(pred="predict", target="y"), | |||||
| batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set, | |||||
| metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, | |||||
| ) | |||||
| trainer.train() | |||||
| """ | |||||
| # should run correctly | |||||
| """ | |||||
| def run_dist(self, run_id): | def run_dist(self, run_id): | ||||
| if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
| ngpu = min(2, torch.cuda.device_count()) | ngpu = min(2, torch.cuda.device_count()) | ||||
| @@ -133,6 +155,8 @@ class TestDistTrainer(unittest.TestCase): | |||||
| def test_callback(self): | def test_callback(self): | ||||
| self.run_dist(3) | self.run_dist(3) | ||||
| def test_dev_data(self): | |||||
| self.run_dist(4) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| runner = TestDistTrainer() | runner = TestDistTrainer() | ||||