@@ -100,7 +100,8 @@ class Callback(object):
    def __init__(self):
        super(Callback, self).__init__()
        self._trainer = None  # reassigned inside Trainer
        self._disabled = False

    @property
    def trainer(self):
        """
@@ -158,6 +159,14 @@ class Callback(object):
    def batch_per_epoch(self):
        """Number of batches in each epoch; only accessible after on_epoch_begin."""
        return self._trainer.batch_per_epoch

    @property
    def is_master(self):
        return self._trainer.is_master

    @property
    def disabled(self):
        return self._disabled
    def on_train_begin(self):
        """
@@ -289,6 +298,8 @@ def _transfer(func):
    def wrapper(manager, *arg):
        returns = []
        for callback in manager.callbacks:
            if callback.disabled:
                continue
            returns.append(getattr(callback, func.__name__)(*arg))
        return returns
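To see how the new `disabled` flag interacts with `_transfer`, here is a self-contained toy version of the dispatch pattern (the `Toy*` names are illustrative, not fastNLP API):

```python
def _transfer(func):
    """Forward a manager-level hook call to every enabled callback."""
    def wrapper(manager, *args):
        returns = []
        for callback in manager.callbacks:
            if callback.disabled:  # master-only callbacks stay silent on workers
                continue
            returns.append(getattr(callback, func.__name__)(*args))
        return returns
    return wrapper

class ToyCallback:
    def __init__(self, name, disabled=False):
        self.name, self.disabled = name, disabled
    def on_train_begin(self):
        print(self.name, 'on_train_begin')

class ToyManager:
    def __init__(self, callbacks):
        self.callbacks = callbacks
    @_transfer
    def on_train_begin(self):
        pass

ToyManager([ToyCallback('a'), ToyCallback('b', disabled=True)]).on_train_begin()
# prints only: a on_train_begin
```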
@@ -320,7 +331,7 @@ class CallbackManager(Callback):
        for env_name, env_val in env.items():
            for callback in self.callbacks:
                setattr(callback, '_' + env_name, env_val)  # Callback.trainer

    @_transfer
    def on_train_begin(self):
        pass
@@ -378,6 +389,24 @@ class CallbackManager(Callback):
        pass


class DistCallbackManager(CallbackManager):
    def __init__(self, env, callbacks_all=None, callbacks_master=None):
        assert 'trainer' in env
        is_master = env['trainer'].is_master
        self.patch_callback(callbacks_master, disabled=not is_master)
        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
        self.callbacks = self.callbacks_all + self.callbacks_master

    def patch_callback(self, callbacks, disabled):
        if not callbacks:
            return
        if not isinstance(callbacks, (list, tuple)):
            callbacks = [callbacks]
        for cb in callbacks:
            cb._disabled = disabled
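A hedged usage sketch of the all/master split; `FakeTrainer` is a hypothetical stand-in exposing the `is_master` property the manager reads, whereas real code passes a `DistTrainer`:

```python
from fastNLP.core.callback import Callback, DistCallbackManager

class PrintCallback(Callback):
    def __init__(self, name):
        super(PrintCallback, self).__init__()
        self.name = name
    def on_train_begin(self):
        print(self.name, 'started')

class FakeTrainer:     # hypothetical stand-in, not fastNLP API
    is_master = False  # pretend this process is a worker, not rank 0

manager = DistCallbackManager(
    env={'trainer': FakeTrainer()},
    callbacks_all=[PrintCallback('all-workers')],     # dispatched on every process
    callbacks_master=[PrintCallback('master-only')],  # patched to disabled here
)
manager.on_train_begin()  # prints only: all-workers started
```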
class GradientClipCallback(Callback):
    """
    Alias: :class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback`
@@ -415,6 +444,9 @@ class GradientClipCallback(Callback):
    def on_backward_end(self):
        if self.step % self.update_every == 0:
            if self.parameters is None:
                if getattr(self.trainer, 'fp16', ''):
                    from apex import amp
                    self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
                else:
                    self.clip_fun(self.model.parameters(), self.clip_value)
            else:
                self.clip_fun(self.parameters, self.clip_value)
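For reference, `clip_fun` resolves to one of two torch utilities, applied between `backward()` and `optimizer.step()`; a standalone illustration:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# clip_type='norm' -> rescale so the global gradient norm is at most max_norm
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# clip_type='value' -> clamp each gradient element into [-clip_value, clip_value]
nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
```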
@@ -896,3 +928,21 @@ class EarlyStopError(CallbackException):
    def __init__(self, msg):
        super(EarlyStopError, self).__init__(msg)


class EchoCallback(Callback):
    def __init__(self, name, out=sys.stdout):
        super(EchoCallback, self).__init__()
        self.name = name
        self.out = out

    def __getattribute__(self, item):
        if item.startswith('on_'):
            print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
                  file=self.out)
        return super(EchoCallback, self).__getattribute__(item)
class TesterCallback(Callback):
    def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
        super(TesterCallback, self).__init__()
        self.tester = Tester(data, model, metrics,
                             batch_size=batch_size, num_workers=num_workers)
@@ -11,7 +11,7 @@ import time
from datetime import datetime, timedelta
from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException
from .callback import DistCallbackManager, CallbackException
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
@@ -39,18 +39,36 @@ def get_local_rank():

class DistTrainer():
    def __init__(self, model, train_data, optimizer, loss, callbacks=None,
    def __init__(self, train_data, model, optimizer=None, loss=None,
                 callbacks_all=None, callbacks_master=None,
                 batch_size_per_gpu=8, n_epochs=1,
                 num_workers=1, drop_last=False,
                 num_data_workers=1, drop_last=False,
                 update_every=1, print_every=10, validate_every=-1,
                 save_every=-1, save_path=None,
                 logging_level=logging.INFO,
                 fp16='', backend='nccl', init_method=None):
                 save_every=-1, save_path=None, device='auto',
                 fp16='', backend=None, init_method=None):
        assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if backend is None:
            backend = 'nccl' if device == 'cuda' else 'gloo'

        # init distributed
        if device == 'cuda':
            torch.cuda.set_device(get_local_rank())
            self.device = torch.device("cuda", get_local_rank())
        else:
            self.device = torch.device(device)

        dist.init_process_group(backend=backend, init_method=init_method)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()  # unique id for each process

        self.model = model
        self.train_data = train_data
        self.batch_size_per_gpu = int(batch_size_per_gpu)
        self.n_epochs = int(n_epochs)
        self.num_workers = int(num_workers)
        self.num_data_workers = int(num_data_workers)
        self.drop_last = drop_last
        self.update_every = int(update_every)
        self.print_every = int(print_every)
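The local rank used above comes from `torch.distributed.launch`, which spawns one process per GPU and passes a per-process index via `--local_rank` (newer versions also set the `LOCAL_RANK` env var). A minimal sketch of what a helper like `get_local_rank` does, not fastNLP's exact implementation:

```python
import os
from argparse import ArgumentParser

def get_local_rank():
    # torch.distributed.launch appends --local_rank=<i> to each spawned process
    parser = ArgumentParser()
    parser.add_argument('--local_rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', 0)))
    args, _ = parser.parse_known_args()
    return args.local_rank
```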
@@ -62,16 +80,13 @@ class DistTrainer():
        self.init_method = init_method
        self.backend = backend
        self.local_rank = get_local_rank()
        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
        self._forward_func = model.forward
        self.callback_manager = DistCallbackManager(
            env={"trainer": self}, callbacks_all=callbacks_all,
            callbacks_master=callbacks_master)

        assert torch.cuda.is_available(), "Distributed Trainer requires cuda to be enabled."
        # init distributed
        torch.cuda.set_device(self.local_rank)
        self.device = torch.device("cuda", self.local_rank)
        dist.init_process_group(backend=self.backend, init_method=self.init_method)

        model.to(self.device)
        optimizer = self.get_optimizer(optimizer)
        optimizer = self._get_optimizer(optimizer)

        # init fp16, must before DataParallel init
        if len(self.fp16):
@@ -81,51 +96,48 @@ class DistTrainer():
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
            assert device == 'cuda', "Amp requires cuda device"
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)

        # init DataParallel
        self.model = DDP(model, device_ids=[self.local_rank],
                         output_device=self.local_rank)
        self.optimizer = optimizer
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()  # unique id for each process
        self.sampler = DistributedSampler(self.train_data)
        self.data_iterator = self.get_data_iter(self.train_data)
        self.n_steps = self.get_n_steps()
        self.data_iterator = self._get_data_iter(self.train_data)
        self.n_steps = self._get_n_steps()
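Note the ordering the comment insists on: `amp.initialize` must wrap the model and optimizer before the DDP wrap. A condensed sketch, assuming apex is installed, CUDA is available, and a process group is already initialized:

```python
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from apex import amp

local_rank = 0  # normally obtained via get_local_rank()
model = nn.Linear(10, 2).cuda(local_rank)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model, optimizer = amp.initialize(model, optimizer, opt_level='O1')    # fp16 first
model = DDP(model, device_ids=[local_rank], output_device=local_rank)  # then DDP
```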
        # Setup logging
        dist.barrier()
        self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
        if self.save_path:
            self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
        else:
            self.cp_save_path = None

        # use INFO in the master, WARN for others
        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S',
                            level=logging_level)
                            level=logging.INFO if self.is_master else logging.WARN)
        self.logger = logging.getLogger(__name__)

        self.logger.info("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
        self.logger.info("Setup Distributed Trainer")
        self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
            os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))

        if self.is_master:
            self.logger.info('Total epochs: %d' % self.n_epochs)
            self.logger.info('Total steps: %d' % self.n_steps)
            self.logger.info('Num instances per GPU: %d' % self.batch_size_per_gpu)
            self.logger.info('Total batch_size: %d' % (self.batch_size_per_gpu * dist.get_world_size()))
            self.logger.info('Total num of samples: %d' % len(self.train_data))
            self.logger.info("Num of callbacks: {}".format(len(self.callback_manager.callbacks)))
            self.logger.info(
                "Use callbacks: {}".format([repr(cb) for cb in self.callback_manager.callbacks]))

        # only master process save model
        if self.save_path:
            self.save_path = os.path.join(
                self.save_path,
                datetime.now().strftime('%m_%d_%y-%H_%M_%S') + '-' + str(os.getpid()))
        self.logger.info("Num of processes: {}".format(self.world_size))
        self.logger.info("Use device: {}".format(device))
        self.logger.info("Training with fp16: {}, optimization level: {}".format(
            len(self.fp16) > 0, self.fp16 if self.fp16 else None))
    def get_n_steps(self):
    def _get_n_steps(self):
        batch_size = self.world_size * self.batch_size_per_gpu
        return (len(self.train_data) // batch_size + int(
            len(self.train_data) % batch_size != 0) * int(self.drop_last == 0)) * self.n_epochs

    def get_data_iter(self, dataset):
    def _get_data_iter(self, dataset):
        if isinstance(dataset, DataSet):
            return DataSetIter(
                dataset=dataset, batch_size=self.batch_size_per_gpu,
                num_workers=self.num_workers, sampler=self.sampler,
                num_workers=self.num_data_workers, sampler=self.sampler,
                drop_last=self.drop_last
            )
        elif isinstance(dataset, BatchIter):
@@ -133,7 +145,7 @@ class DistTrainer():
        else:
            raise TypeError("train_data type {} not supported".format(type(dataset)))
    def get_optimizer(self, optimizer):
    def _get_optimizer(self, optimizer):
        if isinstance(optimizer, torch.optim.Optimizer):
            return optimizer
        elif isinstance(optimizer, Optimizer):
@@ -148,37 +160,50 @@ class DistTrainer():
        return self.rank == 0

    def train(self, on_exception='auto'):
        start_time = time.time()
        results = {}
        if self.n_epochs <= 0:
            if self.is_master:
                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
            results['seconds'] = 0.
            return results
        if self.is_master:
        try:
            self.logger.info("###### Training epochs started ######")
            self.logger.info('Total epochs: %d' % self.n_epochs)
            self.logger.info('Total steps: %d' % self.n_steps)
            self.logger.info('Num instances per GPU: %d' % self.batch_size_per_gpu)
            self.logger.info('Total batch_size: %d' % (self.batch_size_per_gpu * dist.get_world_size()))
            self.logger.info('Total num of samples: %d' % len(self.train_data))
            self.logger.info("Num of callbacks for all workers: {}".format(
                len(self.callback_manager.callbacks_all)))
            self.logger.info("Num of callbacks for master workers: {}".format(
                len(self.callback_manager.callbacks_master)))
            self.logger.info("Callbacks for all workers: {}".format(
                [repr(cb) for cb in self.callback_manager.callbacks_all]))
            self.logger.info("Callbacks for master workers: {}".format(
                [repr(cb) for cb in self.callback_manager.callbacks_master]))

            start_time = time.time()
            results = {}
            if self.n_epochs <= 0:
                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
                results['seconds'] = 0.
                return results

        try:
            self.callback_manager.on_train_begin()
            self._train()
            self.callback_manager.on_train_end()
        except BaseException as e:
            self.callback_manager.on_exception(e)
            if on_exception == 'auto':
                if not isinstance(e, (CallbackException, KeyboardInterrupt)):
            try:
                self.callback_manager.on_train_begin()
                self._train()
                self.callback_manager.on_train_end()
            except BaseException as e:
                self.callback_manager.on_exception(e)
                if on_exception == 'auto':
                    if not isinstance(e, (CallbackException, KeyboardInterrupt)):
                        raise e
                    else:
                        self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
                elif on_exception == 'raise':
                    raise e
                else:
                    self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
            elif on_exception == 'raise':
                raise e

        results['seconds'] = round(time.time() - start_time, 2)
        if self.is_master:
            results['seconds'] = round(time.time() - start_time, 2)
            self.logger.info("###### Train finished ######")
            self.logger.info('Total train time: {} seconds.'.format(results['seconds']))
            return results
        return results
        finally:
            self.close()
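A hedged end-to-end sketch of driving this trainer, mirroring the test at the bottom of this diff; save it as a script and launch it with `python -m torch.distributed.launch --nproc_per_node 2 script.py`:

```python
import numpy as np
from fastNLP import BCELoss, DataSet, Instance
from fastNLP.core.dist_trainer import DistTrainer
from fastNLP.models.base_model import NaiveClassifier

# a tiny fake binary-classification dataset, as in the tests
data = DataSet([Instance(x=[float(v), float(v)], y=[float(v > 0)])
                for v in np.random.randn(200)])
data.set_input('x')
data.set_target('y')

trainer = DistTrainer(
    data, NaiveClassifier(2, 1),
    optimizer=None,  # None falls back to the trainer's default optimizer
    loss=BCELoss(pred='predict', target='y'),
    n_epochs=2, batch_size_per_gpu=8,
)
trainer.train()
```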
    def _train(self):
        if self.fp16:
@@ -187,7 +212,7 @@ class DistTrainer():
        self.step = 0
        self.epoch = 0
        self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
                         leave=False, dynamic_ncols=True, disable=not self.is_master)
        pbar = self.pbar
        avg_loss = 0
        data_iterator = self.data_iterator
@@ -238,18 +263,17 @@ class DistTrainer():
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                self.n_steps)
                    if self.is_master:
                        self.logger.info(eval_str)
                    self.logger.info(eval_str)
                    self.callback_manager.on_validation()
                    dist.barrier()

                if self.save_path and \
                if self.cp_save_path and \
                        self.save_every > 0 and \
                        self.step % self.save_every == 0:
                    self.save_check_point()
            # ================= mini-batch end ==================== #
            if self.save_path and self.save_every < 0:
            if self.save_every < 0 and self.cp_save_path:
                self.save_check_point()
            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
@@ -287,16 +311,15 @@ class DistTrainer():
        return loss.mean()

    def save_check_point(self, only_params=False):
        # only master save models
        if self.is_master:
            if not os.path.exists(self.save_path):
                os.makedirs(self.save_path)
            path = os.path.join(self.save_path, 'checkpoint-{}.bin'.format(self.step))
            os.makedirs(self.cp_save_path, exist_ok=True)
            path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step))
            self.logger.info("Save checkpoint to {}".format(path))
            model_to_save = self.model.module
            if only_params:
                model_to_save = model_to_save.state_dict()
            torch.save(model_to_save, path)
        dist.barrier()

    def close(self):
        dist.destroy_process_group()
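Because `save_check_point` stores `self.model.module` (the model unwrapped from DDP), reloading needs no distributed setup. A companion sketch, with an illustrative path and a hypothetical `build_model` helper:

```python
import torch

obj = torch.load('save/checkpoints/06_17_2019-20_15/checkpoint-100.bin',
                 map_location='cpu')
if isinstance(obj, dict):   # saved with only_params=True -> a state_dict
    model = build_model()   # reconstruct the architecture yourself
    model.load_state_dict(obj)
else:                       # default: the whole pickled module
    model = obj
```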
@@ -431,13 +431,13 @@ class Trainer(object):
        super(Trainer, self).__init__()
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

        # check metrics and dev_data
        if (not metrics) and dev_data is not None:
            raise ValueError("No metric for dev_data evaluation.")
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None.")

        # check update every
        assert update_every >= 1, "update_every must be no less than 1."
        self.update_every = int(update_every)
@@ -447,7 +447,7 @@ class Trainer(object):
            raise ValueError("save_path can only be None or `str`.")

        # prepare evaluate
        metrics = _prepare_metrics(metrics)

        # parse metric_key
        # increase_better is True. It means the exp result gets better if the indicator increases.
        # It is true by default.
@@ -546,7 +546,7 @@ class Trainer(object):
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
        else:
            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

        self.use_tqdm = use_tqdm
        self.pbar = None
        self.print_every = abs(self.print_every)
@@ -558,10 +558,10 @@ class Trainer(object):
                                 batch_size=self.batch_size,
                                 device=None,  # device is handled by the code above
                                 verbose=0)

        self.step = 0
        self.start_time = None  # start timestamp
        self.callback_manager = CallbackManager(env={"trainer": self},
                                                callbacks=callbacks)
@@ -597,7 +597,7 @@ class Trainer(object):
        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
        start_time = time.time()
        print("training epochs started " + self.start_time, flush=True)

        try:
            self.callback_manager.on_train_begin()
            self._train()
@@ -610,7 +610,7 @@ class Trainer(object):
                    raise e
            elif on_exception == 'raise':
                raise e

        if self.dev_data is not None and self.best_dev_perf is not None:
            print(
                "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -628,9 +628,9 @@ class Trainer(object):
        finally:
            pass
        results['seconds'] = round(time.time() - start_time, 2)

        return results

    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
@@ -656,21 +656,21 @@ class Trainer(object):
                # negative sampling; replace unknown; re-weight batch_y
                self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                prediction = self._data_forward(self.model, batch_x)

                # edit prediction
                self.callback_manager.on_loss_begin(batch_y, prediction)
                loss = self._compute_loss(prediction, batch_y).mean()
                avg_loss += loss.item()
                loss = loss / self.update_every

                # Is loss NaN or inf? requires_grad = False
                self.callback_manager.on_backward_begin(loss)
                self._grad_backward(loss)
                self.callback_manager.on_backward_end()

                self._update()
                self.callback_manager.on_step_end()

                if self.step % self.print_every == 0:
                    avg_loss = float(avg_loss) / self.print_every
                    if self.use_tqdm:
@@ -684,7 +684,7 @@ class Trainer(object):
                        pbar.set_postfix_str(print_output)
                    avg_loss = 0
                self.callback_manager.on_batch_end()

                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                        and self.dev_data is not None:
@@ -693,20 +693,20 @@ class Trainer(object):
                              self.n_steps) + \
                        self.tester._format_eval_results(eval_res)
                    pbar.write(eval_str + '\n')
            # ================= mini-batch end ==================== #

            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
        # =============== epochs end =================== #
        pbar.close()
        self.pbar = None
    # ============ tqdm end ============== #
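The loop above fixes the hook order a custom callback can rely on: on_batch_begin, forward, on_loss_begin, on_backward_begin, backward, on_backward_end, update, on_step_end, on_batch_end. A minimal custom callback hooking into it (the class itself is illustrative, not fastNLP API):

```python
from fastNLP.core.callback import Callback

class LossLogger(Callback):
    """Print the raw loss every 100 steps."""
    def on_backward_begin(self, loss):
        if self.step % 100 == 0:  # self.step is proxied from the trainer
            print('step {}: loss {:.4f}'.format(self.step, loss.item()))
```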
    def _do_validation(self, epoch, step):
        self.callback_manager.on_valid_begin()
        res = self.tester.test()

        is_better_eval = False
        if self._better_eval_result(res):
            if self.save_path is not None:
@@ -721,7 +721,7 @@ class Trainer(object):
        # get validation results; adjust optimizer
        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
        return res

    def _mode(self, model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.
@@ -733,14 +733,14 @@ class Trainer(object):
            model.eval()
        else:
            model.train()

    def _update(self):
        """Perform weight update on a model.
        """
        if self.step % self.update_every == 0:
            self.optimizer.step()

    def _data_forward(self, network, x):
        x = _build_args(self._forward_func, **x)
        y = network(**x)
@@ -748,7 +748,7 @@ class Trainer(object):
            raise TypeError(
                f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
        return y

    def _grad_backward(self, loss):
        """Compute gradients via the chain rule (backpropagation).
@@ -759,7 +759,7 @@ class Trainer(object):
        if (self.step - 1) % self.update_every == 0:
            self.model.zero_grad()
        loss.backward()
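`_grad_backward` and `_update` together implement gradient accumulation: the loss is pre-scaled by `update_every` (see `_train` above), gradients are zeroed at the start of each window, and the optimizer steps once per window. The same pattern in isolation, with illustrative numbers:

```python
import torch
from torch import nn

model = nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
update_every = 4

for step in range(1, 13):                      # steps are 1-based, as in Trainer
    if (step - 1) % update_every == 0:         # window start: clear old grads
        model.zero_grad()
    loss = model(torch.randn(3, 2)).sum() / update_every  # scale the loss
    loss.backward()                            # grads accumulate over the window
    if step % update_every == 0:               # one optimizer step per window
        opt.step()
```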
    def _compute_loss(self, predict, truth):
        """Compute loss given prediction and ground truth.
@@ -768,7 +768,7 @@ class Trainer(object):
        :return: a scalar
        """
        return self.losser(predict, truth)

    def _save_model(self, model, model_name, only_param=False):
        """Save a state_dict or model that carries no GPU/device info.

        :param model:
@@ -791,7 +791,7 @@ class Trainer(object):
            model.cpu()
            torch.save(model, model_path)
            model.to(self._model_device)

    def _load_model(self, model, model_name, only_param=False):
        # returns a bool indicating whether the model was reloaded successfully
        if self.save_path is not None:
@@ -809,7 +809,7 @@ class Trainer(object):
        else:
            return False
        return True

    def _better_eval_result(self, metrics):
        """Check if the current epoch yields better validation results.
@@ -835,6 +835,9 @@ class Trainer(object):
            is_better = False
        return is_better

    @property
    def is_master(self):
        return True


DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2
@@ -4,7 +4,7 @@ import numpy as np
import torch.cuda
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import CrossEntropyLoss
from fastNLP import CrossEntropyLoss, BCELoss
from fastNLP import SGD
from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
from fastNLP.models.base_model import NaiveClassifier
@@ -12,6 +12,7 @@ import shutil
import os
import subprocess
from argparse import ArgumentParser
from fastNLP.core.callback import EchoCallback
def prepare_fake_dataset():
    mean = np.array([-3, -3])
@@ -36,6 +37,26 @@ def prepare_fake_dataset2(*args, size=100):
def set_rng_seed(seed):
    np.random.seed(seed)

def prepare_env():
    def prepare_fake_dataset():
        mean = np.array([-3, -3])
        cov = np.array([[1, 0], [0, 1]])
        class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

        mean = np.array([3, 3])
        cov = np.array([[1, 0], [0, 1]])
        class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

        data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
                           [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
        return data_set

    data_set = prepare_fake_dataset()
    data_set.set_input("x")
    data_set.set_target("y")
    model = NaiveClassifier(2, 1)
    return data_set, model
class TestDistTrainer(unittest.TestCase):
    save_path = './save_cp'
@@ -84,23 +105,35 @@ class TestDistTrainer(unittest.TestCase):
        if trainer.is_master and os.path.exists(self.save_path):
            shutil.rmtree(self.save_path)

    def run3(self):
        data_set, model = prepare_env()
        trainer = DistTrainer(
            data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
            n_epochs=3, print_every=50,
            callbacks_all=[EchoCallback('callbacks_all')],
            callbacks_master=[EchoCallback('callbacks_master')]
        )
        trainer.train()

    def run_dist(self, run_id):
        if torch.cuda.is_available():
            ngpu = min(4, torch.cuda.device_count())
            ngpu = min(2, torch.cuda.device_count())
            path = __file__
            cmd = ['python', '-m', 'torch.distributed.launch',
                   '--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
            print(' '.join(cmd))
            retcode = subprocess.call(cmd)
            if retcode:
                raise RuntimeError('subprocess got non-zero exit status %d' % retcode)
            subprocess.check_call(cmd, timeout=60.0)

    def test1(self):
    def test_normal_run(self):
        self.run_dist(1)

    def test2(self):
    def test_fp16(self):
        self.run_dist(2)

    def test_callback(self):
        self.run_dist(3)

if __name__ == '__main__':
    runner = TestDistTrainer()
    parser = ArgumentParser()
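The `__main__` block is cut off here; `run_dist` re-launches this very file under `torch.distributed.launch` and selects a case via `--test`. A hypothetical completion consistent with that flag, not the actual file tail:

```python
parser.add_argument('--test', type=int, default=None)
args, _ = parser.parse_known_args()
if args.test is not None:
    getattr(runner, 'run%d' % args.test)()  # e.g. --test 3 -> runner.run3()
else:
    unittest.main()
```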