#!/usr/bin/python
# -*- coding: utf-8 -*-

# __author__="Danqing Wang"

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import sys
import time

import numpy as np
import torch

from fastNLP.core.const import Const
from fastNLP.io.model_io import ModelSaver
from fastNLP.core.callback import Callback, EarlyStopError

from tools.logger import *


class TrainCallback(Callback):
    def __init__(self, hps, patience=3, quit_all=True):
        super().__init__()
        self._hps = hps
        self.patience = patience
        self.wait = 0

        if not isinstance(quit_all, bool):
            raise ValueError("In KeyboardInterrupt, the quit_all argument must be a bool.")
        self.quit_all = quit_all

    def on_epoch_begin(self):
        self.epoch_start_time = time.time()

    # def on_loss_begin(self, batch_y, predict_y):
    #     """
    #     :param batch_y: dict
    #         input_len: [batch, N]
    #     :param predict_y: dict
    #         p_sent: [batch, N, 2]
    #     :return:
    #     """
    #     input_len = batch_y[Const.INPUT_LEN]
    #     batch_y[Const.TARGET] = batch_y[Const.TARGET] * ((1 - input_len) * -100)
    #     predict_y["p_sent"] = predict_y["p_sent"] * input_len.unsqueeze(-1)
    #     logger.debug(predict_y["p_sent"][0:5,:,:])

    def on_backward_begin(self, loss):
        """
        :param loss: scalar loss tensor
        :return:
        """
        # Abort training as soon as the loss becomes NaN/inf, dumping the
        # per-parameter gradient sums to help locate the offending layer.
        if not np.isfinite(loss.item()):
            logger.error("train Loss is not finite. Stopping.")
            logger.info(loss)
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    logger.info(name)
                    logger.info(param.grad.data.sum())
            raise Exception("train Loss is not finite. Stopping.")

    def on_backward_end(self):
        if self._hps.grad_clip:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self._hps.max_grad_norm)

    def on_epoch_end(self):
        logger.info(' | end of epoch {:3d} | time: {:5.2f}s | '
                    .format(self.epoch, (time.time() - self.epoch_start_time)))

    def on_valid_begin(self):
        self.valid_start_time = time.time()

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        logger.info(' | end of valid {:3d} | time: {:5.2f}s | '
                    .format(self.epoch, (time.time() - self.valid_start_time)))

        # early stopping: once the validation metric has failed to improve for
        # `patience` consecutive evaluations, checkpoint the model and abort
        if not is_better_eval:
            if self.wait == self.patience:
                train_dir = os.path.join(self._hps.save_root, "train")
                save_file = os.path.join(train_dir, "earlystop.pkl")
                saver = ModelSaver(save_file)
                saver.save_pytorch(self.model)
                logger.info('[INFO] Saving early stop model to %s', save_file)
                raise EarlyStopError("Early stopping raised.")
            else:
                self.wait += 1
        else:
            self.wait = 0

        # learning-rate decay: lr / (epoch + 1), floored at 5e-6
        if self._hps.lr_descent:
            new_lr = max(5e-6, self._hps.lr / (self.epoch + 1))
            for param_group in list(optimizer.param_groups):
                param_group['lr'] = new_lr
            logger.info("[INFO] The learning rate now is %f", new_lr)

    def on_exception(self, exception):
        if isinstance(exception, KeyboardInterrupt):
            logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
            train_dir = os.path.join(self._hps.save_root, "train")
            save_file = os.path.join(train_dir, "earlystop.pkl")
            saver = ModelSaver(save_file)
            saver.save_pytorch(self.model)
            logger.info('[INFO] Saving early stop model to %s', save_file)

            if self.quit_all is True:
                sys.exit(0)  # exit the program directly
            else:
                pass
        else:
            raise exception  # re-raise unrecognized exceptions
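

# ------------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original file): this callback is
# meant to be passed to a fastNLP Trainer via its ``callbacks`` argument. The
# ``hps`` object here is a hypothetical hyper-parameter holder; the callback
# only reads hps.save_root, hps.grad_clip, hps.max_grad_norm, hps.lr_descent,
# and hps.lr from it. ``train_set`` and ``model`` stand in for the caller's
# dataset and model.
#
#     from fastNLP import Trainer
#
#     trainer = Trainer(train_data=train_set, model=model,
#                       callbacks=[TrainCallback(hps, patience=3, quit_all=True)])
#     trainer.train()
# ------------------------------------------------------------------------------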