#!/usr/bin/python
# -*- coding: utf-8 -*-
# __author__="Danqing Wang"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import sys
import time

import numpy as np
import torch

from fastNLP.core.const import Const
from fastNLP.io.model_io import ModelSaver
from fastNLP.core.callback import Callback, EarlyStopError
from fastNLP.core._logger import logger


class TrainCallback(Callback):
    """fastNLP training callback bundling loss sanity checks, gradient
    clipping, per-epoch checkpointing, early stopping and lr decay.

    :param hps: hyper-parameter namespace; must provide ``save_root``,
        ``grad_clip``, ``max_grad_norm``, ``lr_descent`` and ``lr``.
    :param patience: number of consecutive non-improving validations
        tolerated before :class:`EarlyStopError` is raised.
    :param quit_all: if True, ``sys.exit(0)`` on KeyboardInterrupt after
        saving a checkpoint; if False, swallow the interrupt.
    :raises ValueError: if ``quit_all`` is not a bool.
    """

    def __init__(self, hps, patience=3, quit_all=True):
        super().__init__()
        self._hps = hps
        self.patience = patience
        self.wait = 0  # validations since last improvement
        self.train_loss = 0.0  # running sum of per-step losses for this epoch
        # Sentinel high value: any realistic first-epoch average beats it.
        self.prev_train_avg_loss = 1000.0
        self.train_dir = os.path.join(self._hps.save_root, "train")

        # Fix: isinstance() instead of type() comparison; "arguemnt" typo.
        if not isinstance(quit_all, bool):
            raise ValueError("In KeyBoardInterrupt, quit_all argument must be a bool.")
        self.quit_all = quit_all

    def on_epoch_begin(self):
        """Record the epoch start time and flag the model as training."""
        self.epoch_start_time = time.time()
        # NOTE(review): ``Train`` is a custom attribute on the project model,
        # not nn.Module.train() — confirm against the model class.
        self.model.Train = True

    def on_backward_begin(self, loss):
        """Abort on a non-finite loss (dumping per-parameter gradient sums
        for debugging) and accumulate the running training loss.

        :param loss: scalar loss tensor for the current step.
        :raises Exception: if the loss is NaN or infinite.
        """
        # Fix: original ``(np.isfinite(loss.data)).numpy()`` is fragile and
        # was broken across lines; ``loss.item()`` yields the Python scalar.
        if not np.isfinite(loss.item()):
            logger.error("train Loss is not finite. Stopping.")
            logger.info(loss)
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    logger.info(name)
                    logger.info(param.grad.data.sum())
            raise Exception("train Loss is not finite. Stopping.")
        # Fix: accumulate a plain float, not a tensor, into the float sum.
        self.train_loss += loss.item()

    def on_backward_end(self):
        """Clip gradients (when enabled) after backward, before the step."""
        if self._hps.grad_clip:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self._hps.max_grad_norm)
            # Scoped under the clip branch: emptying the CUDA cache on every
            # step is costly and only relevant alongside clipping here.
            torch.cuda.empty_cache()

    def on_epoch_end(self):
        """Log epoch statistics, save an ``earlystop`` checkpoint when the
        average train loss worsened, and always save a per-epoch checkpoint."""
        epoch_avg_loss = self.train_loss / self.n_steps
        logger.info(' | end of epoch {:3d} | time: {:5.2f}s | train loss: {:5.6f}'
                    .format(self.epoch, (time.time() - self.epoch_start_time),
                            epoch_avg_loss))
        if self.prev_train_avg_loss < epoch_avg_loss:
            # Loss went up: keep a recovery checkpoint (best-so-far weights).
            save_file = os.path.join(self.train_dir, "earlystop.pkl")
            self.save_model(save_file)
        else:
            self.prev_train_avg_loss = epoch_avg_loss
        self.train_loss = 0.0

        # save epoch
        save_file = os.path.join(self.train_dir, "epoch_%d.pkl" % self.epoch)
        self.save_model(save_file)

    def on_valid_begin(self):
        """Record the validation start time and flag the model as evaluating."""
        self.valid_start_time = time.time()
        self.model.Train = False

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        """Handle early stopping and learning-rate decay after validation.

        :param eval_result: metrics dict produced by the tester (unused here).
        :param metric_key: name of the metric driving model selection (unused).
        :param optimizer: optimizer whose lr may be decayed in place.
        :param is_better_eval: whether this validation improved on the best.
        :raises EarlyStopError: once ``patience`` is exhausted.
        """
        logger.info(' | end of valid {:3d} | time: {:5.2f}s | '
                    .format(self.epoch, (time.time() - self.valid_start_time)))

        # early stop
        if not is_better_eval:
            if self.wait == self.patience:
                # Consistency fix: reuse self.train_dir from __init__ instead
                # of rebuilding the same path here.
                save_file = os.path.join(self.train_dir, "earlystop.pkl")
                self.save_model(save_file)
                raise EarlyStopError("Early stopping raised.")
            else:
                self.wait += 1
        else:
            self.wait = 0

        # lr descent: harmonic decay with a 5e-6 floor
        if self._hps.lr_descent:
            new_lr = max(5e-6, self._hps.lr / (self.epoch + 1))
            # Fix: no need to copy param_groups with list() just to iterate.
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr
            logger.info("[INFO] The learning rate now is %f", new_lr)

    def on_exception(self, exception):
        """On Ctrl-C, checkpoint the model and optionally exit the process;
        re-raise anything else.

        :param exception: the exception caught by the trainer loop.
        """
        if isinstance(exception, KeyboardInterrupt):
            logger.error("[Error] Caught keyboard interrupt on worker. "
                         "Stopping supervisor...")
            save_file = os.path.join(self.train_dir, "earlystop.pkl")
            self.save_model(save_file)

            if self.quit_all is True:
                sys.exit(0)  # exit the whole process immediately
            else:
                pass
        else:
            raise exception  # re-raise any unexpected error

    def save_model(self, save_file):
        """Persist the current model weights to ``save_file`` via ModelSaver.

        :param save_file: destination path for the pickled model.
        """
        saver = ModelSaver(save_file)
        saver.save_pytorch(self.model)
        logger.info('[INFO] Saving model to %s', save_file)