From a6ff3f8cc3038efc7033a6be166aea932c6a11be Mon Sep 17 00:00:00 2001 From: yh_cc Date: Tue, 12 Apr 2022 13:49:03 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E5=B0=86=E5=B8=A6=E6=9C=89Monitor?= =?UTF-8?q?=E7=9A=84callback=E9=83=BD=E6=8A=BD=E8=B1=A1=E4=B8=BAHasMonitor?= =?UTF-8?q?Callback=EF=BC=8C=E5=B9=B6=E7=94=B1=E8=BF=99=E4=B8=AA=E7=88=B6?= =?UTF-8?q?=E7=B1=BB=E8=BF=9B=E8=A1=8Cmonitor=E7=9A=84=E8=AE=BE=E7=BD=AE?= =?UTF-8?q?=E5=92=8C=E6=A3=80=E9=AA=8C=E7=9A=84;=202.=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E4=BB=8ETrainer=E4=B8=AD=E8=AE=BE=E7=BD=AEmonitor=E7=BB=99?= =?UTF-8?q?=E6=89=80=E6=9C=89=E7=9A=84Callback=E4=BD=BF=E7=94=A8;3.?= =?UTF-8?q?=E6=96=B0=E5=A2=9EEarlyStopCallback.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/__init__.py | 4 +- fastNLP/core/callbacks/callback.py | 84 +++++++++++++++++- fastNLP/core/callbacks/checkpoint_callback.py | 87 +++++++------------ fastNLP/core/callbacks/early_stop_callback.py | 61 +++++++++++++ .../callbacks/load_best_model_callback.py | 29 +++---- fastNLP/core/callbacks/progress_callback.py | 46 +++------- fastNLP/core/callbacks/utils.py | 30 ++++--- fastNLP/core/controllers/trainer.py | 14 +++ fastNLP/core/utils/exceptions.py | 10 +++ tests/core/callbacks/test_utils.py | 7 +- 10 files changed, 246 insertions(+), 126 deletions(-) create mode 100644 fastNLP/core/callbacks/early_stop_callback.py create mode 100644 fastNLP/core/utils/exceptions.py diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index a47ab998..fc5d9d5b 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -10,7 +10,8 @@ __all__ = [ 'ProgressCallback', 'RichCallback', "LRSchedCallback", - 'LoadBestModelCallback' + 'LoadBestModelCallback', + "EarlyStopCallback" ] @@ -21,4 +22,5 @@ from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallb from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback from .lr_scheduler_callback import LRSchedCallback from .load_best_model_callback import LoadBestModelCallback +from .early_stop_callback import EarlyStopCallback diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index b2d99b51..4b553a1f 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -1,11 +1,15 @@ -from typing import Union, Callable, Dict, Optional +from typing import Union, Callable, Dict, Optional, Any +from abc import ABC __all__ = [ 'Callback', ] from .callback_events import Events, EventsList, Filter +from .utils import _get_monitor_value from fastNLP.core.callbacks.callback_events import _SingleEventState +from fastNLP.core.log import logger +from fastNLP.core.utils import apply_to_collection class Callback: @@ -150,4 +154,82 @@ class _CallbackWrapper(Callback): return self.fn.__name__ +class CanItemDataType(ABC): + """ + 检测可以进行传输的对象。 + + """ + + @classmethod + def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: + if cls is CanItemDataType: + item = getattr(subclass, 'item', None) + return callable(item) + return NotImplemented + + +class HasMonitorCallback(Callback): + def __init__(self, monitor, larger_better, must_have_monitor=False): + self.set_monitor(monitor, larger_better) + self.must_have_moinitor = must_have_monitor + + def set_monitor(self, monitor, larger_better): + self.monitor = str(monitor) if monitor is not None else None + self.larger_better = bool(larger_better) + if larger_better: + self.monitor_value = float('-inf') + else: + self.monitor_value = float('inf') + self._real_monitor = self.monitor + + def on_after_trainer_initialized(self, trainer, driver): + """ + 如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 + 同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 + + :param trainer: + :param driver: + :return: + """ + if self.monitor is None and trainer.monitor is not None: + self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) + if self.must_have_moinitor and self.monitor is None: + raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " + f"You can set it in the initialization or through Trainer.") + + def get_monitor_value(self, results:Dict)->float: + """ + 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 + :param results: + :return: + """ + if len(results)==0: + return 0 + # 保证所有的 tensor 都被转换为了 python 特定的类型 + results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) + use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, + real_monitor=self._real_monitor, + res=results) + if self._real_monitor != use_monitor: # 发生了替换需要打印 + logger.warning( + f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " + f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") + self._real_monitor = use_monitor + return monitor_value + + def is_better_monitor_value(self, monitor_value: float, keep_if_better=True): + """ + 检测 monitor_value 是否是更好的 + + :param monitor_value: + :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 + :return: + """ + better = False + if (self.larger_better and monitor_value > self.monitor_value) or \ + (not self.larger_better and monitor_value < self.monitor_value): + better = True + if keep_if_better: + self.monitor_value = monitor_value + return better \ No newline at end of file diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 12b6a9e6..839a9522 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -5,12 +5,12 @@ __all__ = [ import os from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping from pathlib import Path -from abc import ABC import sys +from copy import deepcopy import fastNLP -from .callback import Callback, Filter +from .callback import Callback, HasMonitorCallback from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_LAUNCH_TIME @@ -18,22 +18,7 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir from fastNLP.core.utils import apply_to_collection -class CanItemDataType(ABC): - """ - 检测可以进行传输的对象。 - - """ - - @classmethod - def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: - if cls is CanItemDataType: - item = getattr(subclass, 'item', None) - return callable(item) - return NotImplemented - - - -class CheckpointCallback(Callback): +class CheckpointCallback(HasMonitorCallback): def __init__( self, monitor, @@ -48,13 +33,8 @@ class CheckpointCallback(Callback): model_save_fn: Optional[Callable] = None, **kwargs, ): - # 我们新加了逻辑,如果 checkpoint callback 自己没有设置 monitor 和 larger_better,那么我们会将其在 trainer 中的设置赋值给它们; - # if monitor is None and save_topk is not None: - # raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.") - - if monitor is not None and not isinstance(monitor, str): - raise ValueError("Parameter `monitor` should be of 'str' type.") - + super().__init__(monitor=monitor, larger_better=larger_better, + must_have_monitor=save_topk is not None) if save_folder is None: logger.warning( "Parameter `path` is None, and we will use the current work directory to find and load your model.") @@ -92,13 +72,12 @@ class CheckpointCallback(Callback): "`BaseException` type.") else: save_on_exception = [] - self.monitor = monitor + self.save_folder = Path(save_folder) self.save_every_n_epochs = save_every_n_epochs self.save_every_n_batches = save_every_n_batches self.save_last = save_last self.save_topk = save_topk - self.larger_better = larger_better self.only_state_dict = only_state_dict self.model_save_fn = model_save_fn self.save_on_exception = save_on_exception @@ -108,12 +87,6 @@ class CheckpointCallback(Callback): self._topk_model = {} self._topn = 0 # 表示目前已经保存了几个最好的模型; - # 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的 - # key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 - # 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; - # 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; - self._real_monitor = self.monitor - # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) @@ -121,20 +94,15 @@ class CheckpointCallback(Callback): synchronize_mkdir(self.timestamp_path) def on_after_trainer_initialized(self, trainer, driver): - if self.monitor is None: - if trainer.monitor is not None: - self.monitor = trainer.monitor - self.larger_better = trainer.larger_better - elif self.save_topk is not None: - raise RuntimeError("You are using `topk` mode, but you have not set the `monitor` value either in this" - "callback or in trainer.") - else: - self.monitor = None + if self.save_topk is not None: + super().on_after_trainer_initialized(trainer, driver) if self.save_topk is not None and trainer.evaluator is None: - raise RuntimeError("You are using `topk` mode, but there is no `evaluator` in trainer.") + logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.") - def on_validate_end(self, trainer, validate_res): - self._save_topk(trainer, validate_res) + def on_validate_end(self, trainer, results): + if len(results) == 0: + return + self._save_topk(trainer, results) def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: @@ -157,7 +125,7 @@ class CheckpointCallback(Callback): def on_sanity_check_end(self, trainer, sanity_check_res): # 主要核对一下 monitor 是否存在。 - self._get_validate_metric(sanity_check_res) + self.get_monitor_value(results=sanity_check_res) def on_save_checkpoint(self, trainer) -> Dict: """ @@ -168,8 +136,7 @@ class CheckpointCallback(Callback): states = {} states['timestamp_path'] = str(self.timestamp_path.absolute()) - states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType, - function=lambda x:x.item()) + states['_topk_model'] = deepcopy(self._topk_model) states['save_topk'] = 0 if self.save_topk is None else self.save_topk states['_real_monitor'] = self._real_monitor return states @@ -190,30 +157,30 @@ class CheckpointCallback(Callback): self._topk_model.update(self._topk_model) self._real_monitor = states["real_monitor"] - def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): + def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): """ 根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 :param trainer: - :param validate_res: + :param results: :return: """ if self.save_topk is not None: - _metric_value = self._get_validate_metric(validate_res) + monitor_value = self.get_monitor_value(results=results) folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ - f"-{self._real_monitor}_{_metric_value}" + f"-{self._real_monitor}_{monitor_value}" _should_save = False if self._topn < self.save_topk: - self._topk_model[folder_name] = _metric_value + self._topk_model[folder_name] = monitor_value self._topn += 1 _should_save = True else: _least_valuable_model = (min if self.larger_better else max)(self._topk_model, key=lambda x: self._topk_model[x]) - if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ - (self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): - self._topk_model[folder_name] = _metric_value + if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \ + (self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]): + self._topk_model[folder_name] = monitor_value _should_save = True self._topk_model.pop(_least_valuable_model) synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) @@ -249,7 +216,11 @@ class CheckpointCallback(Callback): :return: """ use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) + if self._real_monitor != use_monitor: + logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), " + f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") self._real_monitor = use_monitor + return value @property @@ -277,7 +248,7 @@ class ModelCheckpointCallback(CheckpointCallback): 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。 + 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 @@ -324,7 +295,7 @@ class TrainerCheckpointCallback(CheckpointCallback): 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。 + 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py new file mode 100644 index 00000000..602236f7 --- /dev/null +++ b/fastNLP/core/callbacks/early_stop_callback.py @@ -0,0 +1,61 @@ +__all__ = [ + 'EarlyStopCallback' +] + +from typing import Dict + +from .callback import HasMonitorCallback +from fastNLP.core.utils.exceptions import EarlyStopException + + +class EarlyStopCallback(HasMonitorCallback): + def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10): + """ + + :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 + :param larger_better: monitor 的值是否是越大越好。 + :param patience: 多少次 validate 不没有提升就停止。 + """ + super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) + self.wait = 0 + self.patience = patience + + def on_validate_end(self, trainer, results): + if len(results)==0: + return + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): + self.wait = 0 + else: + self.wait += 1 + + def on_fetch_data_begin(self, trainer): + # 当是 step validate 的时候,下一步执行的就是这个, 所以在这里检查。 + if self.wait >= self.patience: + raise EarlyStopException(f"After {self.wait} validations, no improvement for " + f"metric `{self._real_monitor}`") + + def on_train_epoch_begin(self, trainer): + # 当是 epoch validate 的时候,下一步执行的就是这个, 所以在这里检查。 + if self.wait >= self.patience: + raise EarlyStopException(f"After {self.wait} validations, no improvement for " + f"metric `{self._real_monitor}`(best value: {self.monitor_value})") + + def on_save_checkpoint(self, trainer) -> Dict: + states = { + 'patience': self.patience, + 'wait': self.wait, + 'monitor': self.monitor, + 'monitor_value': self.monitor_value + } + return states + + def on_load_checkpoint(self, trainer, states): + self.patience = states['patience'] + self.wait = states['wait'] + self.monitor = states['monitor'] + self.monitor_value = float(states['monitor_value']) + + def callback_name(self): + return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' + diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index e7b94f8c..9a4bb65f 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -4,8 +4,7 @@ __all__ = [ import os from typing import Optional, Callable -from .callback import Callback -from .utils import _get_monitor_value +from .callback import HasMonitorCallback from io import BytesIO import shutil @@ -14,15 +13,15 @@ from fastNLP.core.log import logger from fastNLP.envs import all_rank_call -class LoadBestModelCallback(Callback): - def __init__(self, monitor:str, larger_better:bool = True, only_state_dict:bool = True, +class LoadBestModelCallback(HasMonitorCallback): + def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True, save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, model_load_fn:Optional[Callable] = None, delete_after_train:bool = True): """ 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 - :param str monitor: 监控的 metric 值。 + :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 :param larger_better: 该 metric 值是否是越大越好。 :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 @@ -33,6 +32,7 @@ class LoadBestModelCallback(Callback): 请在函数内完成对模型的加载。 :param delete_after_train: 在训练结束后是否删掉模型。 """ + super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) if model_load_fn is not None: assert callable(model_load_fn), "`model_load_fn` must be a callable object." assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time." @@ -56,15 +56,11 @@ class LoadBestModelCallback(Callback): self.real_save_folder = None self.buffer = BytesIO() - self.monitor = monitor - self.larger_better = larger_better self.save_folder = save_folder self.only_state_dict = only_state_dict self.model_save_fn = model_save_fn self.model_load_fn = model_load_fn self.delete_after_after = delete_after_train - self._real_monitor = None - self.monitor_value = float('-inf') if larger_better else float('inf') def on_after_trainer_initialized(self, trainer, driver): if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: @@ -76,13 +72,16 @@ class LoadBestModelCallback(Callback): raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " f"save best model when launch using script.") + super().on_after_trainer_initialized(trainer, driver) + + def on_sanity_check_end(self, trainer, sanity_check_res): + self.get_monitor_value(sanity_check_res) + def on_validate_end(self, trainer, results): - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=results) - if (monitor_value < self.monitor_value and self.larger_better is False) or \ - (monitor_value > self.monitor_value and self.larger_better): - self.monitor_value = monitor_value + if len(results)==0: + return + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): if self.real_save_folder: trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, model_save_fn=self.model_save_fn) diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 633fbb09..756d236b 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -8,7 +8,7 @@ __all__ = [ 'RichCallback' ] -from .callback import Callback +from .callback import HasMonitorCallback from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.utils import f_rich_progress from fastNLP.core.log import logger @@ -28,15 +28,13 @@ def choose_progress_callback(progress_bar:str): return None -class ProgressCallback(Callback): +class ProgressCallback(HasMonitorCallback): def on_train_end(self, trainer): f_rich_progress.stop() def on_sanity_check_end(self, trainer, sanity_check_res): if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=sanity_check_res) + self.get_monitor_value(sanity_check_res) class RichCallback(ProgressCallback): @@ -46,28 +44,22 @@ class RichCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 - :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 + :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 :param larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ - super().__init__() + super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) self.print_every = print_every self.progress_bar = f_rich_progress self.task2id = {} self.loss = 0 self.loss_round_ndigit = loss_round_ndigit - self.monitor = monitor - self.larger_better = larger_better - if larger_better: - self.monitor_value = float('-inf') - else: - self.monitor_value = float('inf') - self._real_monitor = monitor self.format_json = format_json def on_after_trainer_initialized(self, trainer, driver): if not self.progress_bar.disable: self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0) + super(RichCallback, self).on_after_trainer_initialized(trainer, driver) def on_train_begin(self, trainer): self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, @@ -109,16 +101,12 @@ class RichCallback(ProgressCallback): text_style = '' characters = '-' if self.monitor is not None: - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=results) - if (self.larger_better and monitor_value > self.monitor_value) or \ - (not self.larger_better and monitor_value < self.monitor_value): + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): if abs(self.monitor_value) != float('inf'): rule_style = 'spring_green3' text_style = '[bold]' characters = '+' - self.monitor_value = monitor_value self.progress_bar.print() self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " f"Batch:{trainer.batch_idx_in_epoch}", @@ -151,18 +139,12 @@ class RawTextCallback(ProgressCallback): :param larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ - super().__init__() + super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) self.print_every = print_every self.task2id = {} self.loss = 0 self.loss_round_ndigit = loss_round_ndigit - self.monitor = monitor - self.larger_better = larger_better - if larger_better: - self.monitor_value = float('-inf') - else: - self.monitor_value = float('inf') - self._real_monitor = monitor + self.set_monitor(monitor, larger_better) self.format_json = format_json self.num_signs = 10 @@ -189,14 +171,10 @@ class RawTextCallback(ProgressCallback): base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' text = '' if self.monitor is not None: - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=results) - if (self.larger_better and monitor_value > self.monitor_value) or \ - (not self.larger_better and monitor_value < self.monitor_value): + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): if abs(self.monitor_value) != float('inf'): text = '+'*self.num_signs + base_text + '+'*self.num_signs - self.monitor_value = monitor_value if len(text) == 0: text = '-'*self.num_signs + base_text + '-'*self.num_signs diff --git a/fastNLP/core/callbacks/utils.py b/fastNLP/core/callbacks/utils.py index 900aebf6..2720ba3f 100644 --- a/fastNLP/core/callbacks/utils.py +++ b/fastNLP/core/callbacks/utils.py @@ -19,23 +19,31 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->( if monitor in res: return monitor, res[monitor] + if real_monitor in res: + return real_monitor, res[real_monitor] + pairs = [] for idx, (key, value) in enumerate(res.items()): - match = SequenceMatcher(None, key, monitor).find_longest_match(0, len(key), 0, len(monitor)) - pairs.append((key, value, match.size, idx)) + match_size = _match_length(monitor, key) + pairs.append((key, value, match_size, idx)) pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True) key, value, match_size = pairs[0][:3] - if real_monitor is not None and real_monitor in res and real_monitor != key: - # 如果 real_monitor 比新找的更长就继续用之前的。 - match = SequenceMatcher(None, real_monitor, monitor).find_longest_match(0, len(real_monitor), 0, len(monitor)) - if match.size > match_size: - return real_monitor, res[real_monitor] + return key, value + - logger.warning(f"We can not find `{monitor}` in the evaluation result (with keys as {list(res.keys())}), " - f"we use the `{key}` as the monitor.") - real_monitor = key - return real_monitor, value +def _match_length(a:str, b:str)->int: + """ + 需要把长度短的放在前面 + + :param a: + :param b: + :return: + """ + short = a if len(a) < len(b) else b + long = a if len(a)>=len(b) else b + match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) + return match.size diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index d710f967..b360c6a0 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -25,6 +25,7 @@ from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, matc from fastNLP.envs import rank_zero_call from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_MODEL_FILENAME +from fastNLP.core.utils.exceptions import EarlyStopException class Trainer(TrainerEventTrigger): @@ -49,6 +50,8 @@ class Trainer(TrainerEventTrigger): output_mapping: Optional[Union[Callable, Dict]] = None, accumulation_steps: int = 1, fp16: bool = False, + monitor: str = None, + larger_better: bool = True, marker: Optional[str] = None, **kwargs ): @@ -102,6 +105,10 @@ class Trainer(TrainerEventTrigger): 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; :param fp16: 是否开启混合精度训练;默认为 False; + :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 + 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。 + :param larger_better: monitor 的值是否是越大越好。 :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; :param kwargs: 一些其它的可能需要的参数; torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; @@ -210,6 +217,8 @@ class Trainer(TrainerEventTrigger): self.evaluator = None self.epoch_validate = lambda *args, **kwargs: ... self.step_validate = lambda *args, **kwargs: ... + self.monitor = monitor + self.larger_better = larger_better if metrics is not None and validate_dataloaders is not None: if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") @@ -239,6 +248,7 @@ class Trainer(TrainerEventTrigger): else: # validate_every > 0 self._step_validate_filter = Filter(every=validate_every) + self.metrics = metrics self.validate_every = validate_every @@ -320,6 +330,10 @@ class Trainer(TrainerEventTrigger): self.driver.barrier() self.on_train_end() self.driver.barrier() + + except EarlyStopException as e: + logger.info(f"Catch early stop exception: {e.msg}.") + self.on_exception(e) except KeyboardInterrupt as e: self.driver.on_exception() self.on_exception(e) diff --git a/fastNLP/core/utils/exceptions.py b/fastNLP/core/utils/exceptions.py new file mode 100644 index 00000000..afedbcba --- /dev/null +++ b/fastNLP/core/utils/exceptions.py @@ -0,0 +1,10 @@ + +class EarlyStopException(BaseException): + r""" + 用于EarlyStop时从Trainer训练循环中跳出。 + + """ + + def __init__(self, msg): + super(EarlyStopException, self).__init__(msg) + self.msg = msg diff --git a/tests/core/callbacks/test_utils.py b/tests/core/callbacks/test_utils.py index 10aba0e0..fdec93e0 100644 --- a/tests/core/callbacks/test_utils.py +++ b/tests/core/callbacks/test_utils.py @@ -12,32 +12,27 @@ def test_get_monitor_value(): with Capturing() as output: monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) assert monitor == 'f1' and value==0.2 - assert 'We can not find' not in output[0] # 测试可以匹配,且选择更靠前的 res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) assert monitor=='acc#f1' and value==0.2 - assert 'We can not find' in output[0] # 测试monitor匹配不上,使用real_monitor res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: - monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#rec', res=res) + monitor, value = _get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res) assert monitor=='acc#rec' and value==0.3 - assert 'We can not find' not in output[0] # 测试monitor/real_monitor匹配不上, 重新选择 res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res) assert monitor=='acc#f1' and value==0.2 - assert 'We can not find' in output[0] # 测试partial的位置 res = {"acc#acc": 0.52, "loss#loss": 2} with Capturing() as output: monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res) assert monitor=='loss#loss' and value==2 - assert 'We can not find' in output[0]