@@ -100,7 +100,8 @@ class Callback(object):
    def __init__(self):
        super(Callback, self).__init__()
        self._trainer = None  # reassigned inside the Trainer
        self._disabled = False

    @property
    def trainer(self):
        """
@@ -158,6 +159,14 @@ class Callback(object):
    def batch_per_epoch(self):
| """每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。""" | """每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。""" | ||||
        return self._trainer.batch_per_epoch

    @property
    def is_master(self):
        return self._trainer.is_master
    @property
    def disabled(self):
        return self._disabled

    def on_train_begin(self):
        """
@@ -289,6 +298,8 @@ def _transfer(func):
    def wrapper(manager, *arg):
        returns = []
        for callback in manager.callbacks:
            if callback.disabled:
                continue
            returns.append(getattr(callback, func.__name__)(*arg))
        return returns
@@ -320,7 +331,7 @@ class CallbackManager(Callback):
        for env_name, env_val in env.items():
            for callback in self.callbacks:
                setattr(callback, '_' + env_name, env_val)  # Callback.trainer

    @_transfer
    def on_train_begin(self):
        pass
@@ -378,6 +389,24 @@ class CallbackManager(Callback):
        pass
class DistCallbackManager(CallbackManager):
    def __init__(self, env, callbacks_all=None, callbacks_master=None):
        assert 'trainer' in env
        is_master = env['trainer'].is_master
        self.patch_callback(callbacks_master, disabled=not is_master)
        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
        self.callbacks = self.callbacks_all + self.callbacks_master

    def patch_callback(self, callbacks, disabled):
        if not callbacks:
            return
        if not isinstance(callbacks, (list, tuple)):
            callbacks = [callbacks]
        for cb in callbacks:
            cb._disabled = disabled
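As a rough illustration (not part of this patch; FakeTrainer is a made-up stand-in for the trainer object), callbacks passed as callbacks_master get _disabled=True on non-master ranks, so the _transfer wrapper above skips them:

    # Sketch under the assumption that EchoCallback/DistCallbackManager live in fastNLP.core.callback.
    from fastNLP.core.callback import EchoCallback, DistCallbackManager

    class FakeTrainer:       # hypothetical object exposing the attribute the manager reads
        is_master = False    # pretend this process is not rank 0

    manager = DistCallbackManager(env={'trainer': FakeTrainer()},
                                  callbacks_all=[EchoCallback('all')],
                                  callbacks_master=[EchoCallback('master')])
    manager.on_train_begin()  # only the 'all' callback is dispatched; the 'master' one is skipped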
class GradientClipCallback(Callback):
    """
    Alias: :class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback`
@@ -415,6 +444,9 @@ class GradientClipCallback(Callback):
    def on_backward_end(self):
        if self.step % self.update_every == 0:
            if self.parameters is None:
                if getattr(self.trainer, 'fp16', ''):
                    # with apex fp16, gradients are accumulated in amp's fp32 master copies
                    from apex import amp
                    self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
                self.clip_fun(self.model.parameters(), self.clip_value)
            else:
                self.clip_fun(self.parameters, self.clip_value)
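For orientation, a hedged usage sketch (train_data, model, optimizer and loss are assumed to exist; the 'O1' opt_level string is only an example value for fp16) showing how this callback would be registered so that clipping runs on every worker:

    # Sketch: per-worker gradient clipping combined with fp16 training; values are illustrative.
    from fastNLP import GradientClipCallback

    clip_cb = GradientClipCallback(clip_value=5, clip_type='value')
    trainer = DistTrainer(train_data, model, optimizer, loss,
                          callbacks_all=[clip_cb],  # run on every worker, not just the master
                          fp16='O1')                # assumed apex opt_level string
    trainer.train()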
@@ -896,3 +928,21 @@ class EarlyStopError(CallbackException):
    def __init__(self, msg):
        super(EarlyStopError, self).__init__(msg)


class EchoCallback(Callback):
    def __init__(self, name, out=sys.stdout):
        super(EchoCallback, self).__init__()
        self.name = name
        self.out = out

    def __getattribute__(self, item):
        if item.startswith('on_'):
            print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
                  file=self.out)
        return super(EchoCallback, self).__getattribute__(item)


class TesterCallback(Callback):
    def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
        self.tester = Tester(data, model)
@@ -11,7 +11,7 @@ import time
from datetime import datetime, timedelta

from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException
from .callback import DistCallbackManager, CallbackException
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
@@ -39,18 +39,36 @@ def get_local_rank():

class DistTrainer():
    def __init__(self, model, train_data, optimizer, loss, callbacks=None,
    def __init__(self, train_data, model, optimizer=None, loss=None,
                 callbacks_all=None, callbacks_master=None,
                 batch_size_per_gpu=8, n_epochs=1,
                 num_workers=1, drop_last=False,
                 num_data_workers=1, drop_last=False,
                 update_every=1, print_every=10, validate_every=-1,
                 save_every=-1, save_path=None,
                 logging_level=logging.INFO,
                 fp16='', backend='nccl', init_method=None):
                 save_every=-1, save_path=None, device='auto',
                 fp16='', backend=None, init_method=None):
        assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if backend is None:
            backend = 'nccl' if device == 'cuda' else 'gloo'

        # init distributed
        if device == 'cuda':
            torch.cuda.set_device(get_local_rank())
            self.device = torch.device("cuda", get_local_rank())
        else:
            self.device = torch.device(device)

        dist.init_process_group(backend=backend, init_method=init_method)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()  # unique id for each process

        self.model = model
        self.train_data = train_data
        self.batch_size_per_gpu = int(batch_size_per_gpu)
        self.n_epochs = int(n_epochs)
        self.num_workers = int(num_workers)
        self.num_data_workers = int(num_data_workers)
        self.drop_last = drop_last
        self.update_every = int(update_every)
        self.print_every = int(print_every)
@@ -62,16 +80,13 @@ class DistTrainer():
        self.init_method = init_method
        self.backend = backend
        self.local_rank = get_local_rank()
        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
        self._forward_func = model.forward
        self.callback_manager = DistCallbackManager(
            env={"trainer": self}, callbacks_all=callbacks_all,
            callbacks_master=callbacks_master)

        assert torch.cuda.is_available(), "Distributed Trainer requires cuda to be enabled."
        # init distributed
        torch.cuda.set_device(self.local_rank)
        self.device = torch.device("cuda", self.local_rank)
        dist.init_process_group(backend=self.backend, init_method=self.init_method)

        model.to(self.device)
        optimizer = self.get_optimizer(optimizer)
        optimizer = self._get_optimizer(optimizer)

        # init fp16, must before DataParallel init
        if len(self.fp16):
@@ -81,51 +96,48 @@ class DistTrainer():
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
            assert device == 'cuda', "Amp requires cuda device"
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)

        # init DataParallel
        self.model = DDP(model, device_ids=[self.local_rank],
                         output_device=self.local_rank)
        self.optimizer = optimizer
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()  # unique id for each process
        self.sampler = DistributedSampler(self.train_data)
        self.data_iterator = self.get_data_iter(self.train_data)
        self.n_steps = self.get_n_steps()
        self.data_iterator = self._get_data_iter(self.train_data)
        self.n_steps = self._get_n_steps()
        # Setup logging
        dist.barrier()
        self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
        if self.save_path:
            self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
        else:
            self.cp_save_path = None
        # use INFO in the master, WARN for others
        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S',
                            level=logging_level)
                            level=logging.INFO if self.is_master else logging.WARN)
        self.logger = logging.getLogger(__name__)
        self.logger.info("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
        self.logger.info("Setup Distributed Trainer")
        self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
            os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
        if self.is_master:
            self.logger.info('Total epochs: %d' % self.n_epochs)
            self.logger.info('Total steps: %d' % self.n_steps)
            self.logger.info('Num instances per GPU %d' % self.batch_size_per_gpu)
            self.logger.info('Total batch_size: %d' % (self.batch_size_per_gpu * dist.get_world_size()))
            self.logger.info('Total num of samples: %d' % len(self.train_data))
            self.logger.info("Num of callbacks: {}".format(len(self.callback_manager.callbacks)))
            self.logger.info(
                "Use callbacks: {}".format([repr(cb) for cb in self.callback_manager.callbacks]))
        # only master process save model
        if self.save_path:
            self.save_path = os.path.join(
                self.save_path,
                datetime.now().strftime('%m_%d_%y-%H_%M_%S') + '-' + str(os.getpid()))

        self.logger.info("Num of processes: {}".format(self.world_size))
        self.logger.info("Use device: {}".format(device))
        self.logger.info("Training with fp16: {}, optimization level: {}".format(
            len(self.fp16) > 0, self.fp16 if self.fp16 else None))
    def get_n_steps(self):
    def _get_n_steps(self):
        batch_size = self.world_size * self.batch_size_per_gpu
        return (len(self.train_data) // batch_size + int(
            len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs

    def get_data_iter(self, dataset):
    def _get_data_iter(self, dataset):
        if isinstance(dataset, DataSet):
            return DataSetIter(
                dataset=dataset, batch_size=self.batch_size_per_gpu,
                num_workers=self.num_workers, sampler=self.sampler,
                num_workers=self.num_data_workers, sampler=self.sampler,
                drop_last=self.drop_last
            )
        elif isinstance(dataset, BatchIter):
@@ -133,7 +145,7 @@ class DistTrainer():
        else:
            raise TypeError("train_data type {} not support".format(type(dataset)))

    def get_optimizer(self, optimizer):
    def _get_optimizer(self, optimizer):
        if isinstance(optimizer, torch.optim.Optimizer):
            return optimizer
        elif isinstance(optimizer, Optimizer):
@@ -148,37 +160,50 @@ class DistTrainer():
        return self.rank == 0

    def train(self, on_exception='auto'):
        start_time = time.time()
        results = {}
        if self.n_epochs <= 0:
            if self.is_master:
                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
            results['seconds'] = 0.
            return results
        if self.is_master:
        try:
            self.logger.info("###### Training epochs started ######")
            self.logger.info('Total epochs: %d' % self.n_epochs)
            self.logger.info('Total steps: %d' % self.n_steps)
            self.logger.info('Num instances per GPU %d' % self.batch_size_per_gpu)
            self.logger.info('Total batch_size: %d' % (self.batch_size_per_gpu * dist.get_world_size()))
            self.logger.info('Total num of samples: %d' % len(self.train_data))
            self.logger.info("Num of callbacks for all workers: {}".format(
                len(self.callback_manager.callbacks_all)))
            self.logger.info("Num of callbacks for master workers: {}".format(
                len(self.callback_manager.callbacks_master)))
            self.logger.info("Callbacks for all workers: {}".format(
                [repr(cb) for cb in self.callback_manager.callbacks_all]))
            self.logger.info("Callbacks for master workers: {}".format(
                [repr(cb) for cb in self.callback_manager.callbacks_master]))

            start_time = time.time()
            results = {}
            if self.n_epochs <= 0:
                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
                results['seconds'] = 0.
                return results
        try:
            self.callback_manager.on_train_begin()
            self._train()
            self.callback_manager.on_train_end()
        except BaseException as e:
            self.callback_manager.on_exception(e)
            if on_exception == 'auto':
                if not isinstance(e, (CallbackException, KeyboardInterrupt)):
            try:
                self.callback_manager.on_train_begin()
                self._train()
                self.callback_manager.on_train_end()
            except BaseException as e:
                self.callback_manager.on_exception(e)
                if on_exception == 'auto':
                    if not isinstance(e, (CallbackException, KeyboardInterrupt)):
                        raise e
                    else:
                        self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
                elif on_exception == 'raise':
                    raise e
                else:
                    self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
            elif on_exception == 'raise':
                raise e

        results['seconds'] = round(time.time() - start_time, 2)
        if self.is_master:
            results['seconds'] = round(time.time() - start_time, 2)
            self.logger.info("###### Train finished ######")
            self.logger.info('Total train time: {} seconds.'.format(results['seconds']))
            return results
        return results
        finally:
            self.close()
    def _train(self):
        if self.fp16:
@@ -187,7 +212,7 @@ class DistTrainer():
        self.step = 0
        self.epoch = 0
        self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
                         leave=False, dynamic_ncols=True, disable=not self.is_master)
        pbar = self.pbar
        avg_loss = 0
        data_iterator = self.data_iterator
@@ -238,18 +263,17 @@ class DistTrainer():
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                self.n_steps)
                    if self.is_master:
                        self.logger.info(eval_str)
                    self.logger.info(eval_str)
                    self.callback_manager.on_validation()
                    dist.barrier()

                if self.save_path and \
                if self.cp_save_path and \
                        self.save_every > 0 and \
                        self.step % self.save_every == 0:
                    self.save_check_point()
                # ================= mini-batch end ==================== #

            if self.save_path and self.save_every < 0:
            if self.save_every < 0 and self.cp_save_path:
                self.save_check_point()
            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
@@ -287,16 +311,15 @@ class DistTrainer():
        return loss.mean()

    def save_check_point(self, only_params=False):
        # only master save models
        if self.is_master:
            if not os.path.exists(self.save_path):
                os.makedirs(self.save_path)
            path = os.path.join(self.save_path, 'checkpoint-{}.bin'.format(self.step))
            os.makedirs(self.cp_save_path, exist_ok=True)
            path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step))
            self.logger.info("Save checkpoint to {}".format(path))
            model_to_save = self.model.module
            if only_params:
                model_to_save = model_to_save.state_dict()
            torch.save(model_to_save, path)
        dist.barrier()

    def close(self):
        dist.destroy_process_group()
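For completeness, the revised DistTrainer in dist_trainer.py is meant to be started with one process per GPU via torch.distributed.launch, the same way the test at the end of this diff invokes it; a minimal launch sketch (the script name, process count, and the pre-built train_data/model/optimizer/loss objects are illustrative assumptions):

    # Assumed worker script, e.g. train_worker.py, started as:
    #   python -m torch.distributed.launch --nproc_per_node 2 train_worker.py
    trainer = DistTrainer(train_data, model, optimizer, loss,
                          batch_size_per_gpu=8, n_epochs=3,
                          save_path='./save_cp')  # checkpoints land in save_path/checkpoints/<start_time>
    trainer.train()                               # close() runs in the finally block shown above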
@@ -431,13 +431,13 @@ class Trainer(object):
        super(Trainer, self).__init__()
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

        # check metrics and dev_data
        if (not metrics) and dev_data is not None:
            raise ValueError("No metric for dev_data evaluation.")
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

        # check update every
        assert update_every >= 1, "update_every must be no less than 1."
        self.update_every = int(update_every)
@@ -447,7 +447,7 @@ class Trainer(object):
            raise ValueError("save_path can only be None or `str`.")
        # prepare evaluate
        metrics = _prepare_metrics(metrics)

        # parse metric_key
        # increase_better is True. It means the exp result gets better if the indicator increases.
        # It is true by default.
@@ -546,7 +546,7 @@ class Trainer(object):
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
        else:
            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

        self.use_tqdm = use_tqdm
        self.pbar = None
        self.print_every = abs(self.print_every)
@@ -558,10 +558,10 @@ class Trainer(object):
                                 batch_size=self.batch_size,
                                 device=None,  # device is handled by the code above
                                 verbose=0)

        self.step = 0
        self.start_time = None  # start timestamp

        self.callback_manager = CallbackManager(env={"trainer": self},
                                                callbacks=callbacks)
@@ -597,7 +597,7 @@ class Trainer(object):
        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
        start_time = time.time()
        print("training epochs started " + self.start_time, flush=True)

        try:
            self.callback_manager.on_train_begin()
            self._train()
@@ -610,7 +610,7 @@ class Trainer(object):
                    raise e
            elif on_exception == 'raise':
                raise e

        if self.dev_data is not None and self.best_dev_perf is not None:
            print(
                "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -628,9 +628,9 @@ class Trainer(object):
        finally:
            pass
        results['seconds'] = round(time.time() - start_time, 2)

        return results

    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
@@ -656,21 +656,21 @@ class Trainer(object):
                    # negative sampling; replace unknown; re-weight batch_y
                    self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                    prediction = self._data_forward(self.model, batch_x)

                    # edit prediction
                    self.callback_manager.on_loss_begin(batch_y, prediction)
                    loss = self._compute_loss(prediction, batch_y).mean()
                    avg_loss += loss.item()
                    loss = loss / self.update_every

                    # Is loss NaN or inf? requires_grad = False
                    self.callback_manager.on_backward_begin(loss)
                    self._grad_backward(loss)
                    self.callback_manager.on_backward_end()

                    self._update()
                    self.callback_manager.on_step_end()

                    if self.step % self.print_every == 0:
                        avg_loss = float(avg_loss) / self.print_every
                        if self.use_tqdm:
@@ -684,7 +684,7 @@ class Trainer(object):
                            pbar.set_postfix_str(print_output)
                        avg_loss = 0
                    self.callback_manager.on_batch_end()

                    if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                            (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                            and self.dev_data is not None:
@@ -693,20 +693,20 @@ class Trainer(object):
                                                         self.n_steps) + \
                                   self.tester._format_eval_results(eval_res)
                        pbar.write(eval_str + '\n')

                # ================= mini-batch end ==================== #

                # lr decay; early stopping
                self.callback_manager.on_epoch_end()
            # =============== epochs end =================== #
            pbar.close()
            self.pbar = None
        # ============ tqdm end ============== #

    def _do_validation(self, epoch, step):
        self.callback_manager.on_valid_begin()
        res = self.tester.test()

        is_better_eval = False
        if self._better_eval_result(res):
            if self.save_path is not None:
@@ -721,7 +721,7 @@ class Trainer(object):
        # get validation results; adjust optimizer
        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
        return res

    def _mode(self, model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.
@@ -733,14 +733,14 @@ class Trainer(object):
            model.eval()
        else:
            model.train()

    def _update(self):
        """Perform weight update on a model.
        """
        if self.step % self.update_every == 0:
            self.optimizer.step()

    def _data_forward(self, network, x):
        x = _build_args(self._forward_func, **x)
        y = network(**x)
@@ -748,7 +748,7 @@ class Trainer(object):
            raise TypeError(
                f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
        return y

    def _grad_backward(self, loss):
        """Compute gradient with link rules.
@@ -759,7 +759,7 @@ class Trainer(object):
        if (self.step - 1) % self.update_every == 0:
            self.model.zero_grad()
        loss.backward()

    def _compute_loss(self, predict, truth):
        """Compute loss given prediction and ground truth.
@@ -768,7 +768,7 @@ class Trainer(object):
        :return: a scalar
        """
        return self.losser(predict, truth)
    def _save_model(self, model, model_name, only_param=False):
        """Save the state_dict or the model without any GPU/device information.

        :param model:
@@ -791,7 +791,7 @@ class Trainer(object):
            model.cpu()
            torch.save(model, model_path)
            model.to(self._model_device)

    def _load_model(self, model, model_name, only_param=False):
        # returns a bool indicating whether the model was successfully reloaded
        if self.save_path is not None:
@@ -809,7 +809,7 @@ class Trainer(object):
        else:
            return False
        return True

    def _better_eval_result(self, metrics):
        """Check if the current epoch yields better validation results.
@@ -835,6 +835,9 @@ class Trainer(object):
                    is_better = False
        return is_better

    @property
    def is_master(self):
        return True


DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2
@@ -4,7 +4,7 @@ import numpy as np
import torch.cuda

from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import CrossEntropyLoss
from fastNLP import CrossEntropyLoss, BCELoss
from fastNLP import SGD
from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
from fastNLP.models.base_model import NaiveClassifier
@@ -12,6 +12,7 @@ import shutil
import os
import subprocess
from argparse import ArgumentParser
from fastNLP.core.callback import EchoCallback


def prepare_fake_dataset():
    mean = np.array([-3, -3])
@@ -36,6 +37,26 @@ def prepare_fake_dataset2(*args, size=100):

def set_rng_seed(seed):
    np.random.seed(seed)
def prepare_env():
    def prepare_fake_dataset():
        mean = np.array([-3, -3])
        cov = np.array([[1, 0], [0, 1]])
        class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

        mean = np.array([3, 3])
        cov = np.array([[1, 0], [0, 1]])
        class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

        data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
                           [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
        return data_set

    data_set = prepare_fake_dataset()
    data_set.set_input("x")
    data_set.set_target("y")
    model = NaiveClassifier(2, 1)
    return data_set, model


class TestDistTrainer(unittest.TestCase):
    save_path = './save_cp'
@@ -84,23 +105,35 @@ class TestDistTrainer(unittest.TestCase):
        if trainer.is_master and os.path.exists(self.save_path):
            shutil.rmtree(self.save_path)

    def run3(self):
        data_set, model = prepare_env()
        trainer = DistTrainer(
            data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
            n_epochs=3, print_every=50,
            callbacks_all=[EchoCallback('callbacks_all')],
            callbacks_master=[EchoCallback('callbacks_master')]
        )
        trainer.train()
    def run_dist(self, run_id):
        if torch.cuda.is_available():
            ngpu = min(4, torch.cuda.device_count())
            ngpu = min(2, torch.cuda.device_count())
            path = __file__
            cmd = ['python', '-m', 'torch.distributed.launch',
                   '--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
            print(' '.join(cmd))
            retcode = subprocess.call(cmd)
            if retcode:
                raise RuntimeError('subprocess got non-zero exit status %d' % retcode)
            subprocess.check_call(cmd, timeout=60.0)

    def test1(self):
    def test_normal_run(self):
        self.run_dist(1)

    def test2(self):
    def test_fp16(self):
        self.run_dist(2)

    def test_callback(self):
        self.run_dist(3)


if __name__ == '__main__':
    runner = TestDistTrainer()
    parser = ArgumentParser()