| @@ -4,7 +4,8 @@ __all__ = [ | |||
| 'EventsList', | |||
| 'Filter', | |||
| 'CallbackManager', | |||
| 'CheckpointCallback', | |||
| 'ModelCheckpointCallback', | |||
| 'TrainerCheckpointCallback', | |||
| 'choose_progress_callback', | |||
| 'ProgressCallback', | |||
| 'RichCallback', | |||
| @@ -16,7 +17,7 @@ __all__ = [ | |||
| from .callback import Callback | |||
| from .callback_events import EventsList, Events, Filter | |||
| from .callback_manager import CallbackManager | |||
| from .checkpoint_callback import CheckpointCallback | |||
| from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||
| from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | |||
| from .lr_scheduler_callback import LRSchedCallback | |||
| from .load_best_model_callback import LoadBestModelCallback | |||
| @@ -8,7 +8,7 @@ __all__ = [ | |||
| from .callback_events import Events | |||
| from .callback import Callback | |||
| from .checkpoint_callback import CheckpointCallback | |||
| from .checkpoint_callback import TrainerCheckpointCallback | |||
| from .progress_callback import ProgressCallback, choose_progress_callback | |||
| from fastNLP.core.log import logger | |||
| @@ -98,7 +98,7 @@ class CallbackManager: | |||
| :return: | |||
| """ | |||
| for each_callback in self.class_callbacks: | |||
| if isinstance(each_callback, CheckpointCallback) and each_callback.is_trainer_checkpoint: | |||
| if isinstance(each_callback, TrainerCheckpointCallback): | |||
| self._has_trainer_checkpoint = True | |||
| self.dissect_one_callback(each_callback) | |||
| @@ -210,7 +210,7 @@ class CallbackManager: | |||
| each_callback.on_load_checkpoint(trainer, None) | |||
| @property | |||
| def has_trainer_chechpoint(self) -> bool: | |||
| def has_trainer_checkpoint(self) -> bool: | |||
| return self._has_trainer_checkpoint | |||
| @_transfer | |||
| @@ -1,12 +1,13 @@ | |||
| __all__ = [ | |||
| 'ModelCheckpointCallback', | |||
| 'TrainerCheckpointCallback' | |||
| ] | |||
| import os | |||
| from typing import Union, Optional, Callable, Dict, Sequence | |||
| from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | |||
| from pathlib import Path | |||
| from functools import partial | |||
| from time import sleep | |||
| from abc import ABC | |||
| import sys | |||
| __all__ = [ | |||
| 'CheckpointCallback' | |||
| ] | |||
| import fastNLP | |||
| from .callback import Callback, Filter | |||
| @@ -14,35 +15,37 @@ from fastNLP.core.callbacks.utils import _get_monitor_value | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||
| from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||
| from fastNLP.core.utils import apply_to_collection | |||
| class CheckpointCallback(Callback): | |||
| class CanItemDataType(ABC): | |||
| """ | |||
| 1. 因为只有 'Trainer' 才有 callback,因此评测 metric 实际上就是 validate 时干的事情; | |||
| 2. 默认 'save_last' 为 True,即 model_checkpoint 的默认逻辑是在每一个 epoch 下保存最后的一个模型,模型名字为 last.pth.tar; | |||
| 3. 理论上一个 model_checkpoint 的实例只会负责一个 monitor 的监视,如果用户在训练过程中指定了多个 monitor 的监视,例如 "acc1", | |||
| "acc2", ... 那么我们会为用户创建多个 model_checkpoint 的实例; | |||
| 4. 理论上,在实际保存的过程中,topk 模式和 固定频率保存的模式是完全独立的,我们确实应当采取一些措施至少保证两者的名字不一样; | |||
| 检测可以进行传输的对象。 | |||
| """ | |||
| @classmethod | |||
| def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||
| if cls is CanItemDataType: | |||
| item = getattr(subclass, 'item', None) | |||
| return callable(item) | |||
| return NotImplemented | |||
| class CheckpointCallback(Callback): | |||
| def __init__( | |||
| self, | |||
| monitor, | |||
| is_trainer_checkpoint: Optional[bool] = False, | |||
| save_folder: Optional[Union[str, Path]] = None, | |||
| save_every_n_epochs: Optional[int] = None, | |||
| save_every_n_global_batches: Optional[int] = None, | |||
| save_every_n_batches: Optional[int] = None, | |||
| save_last: bool = True, | |||
| save_topk: Optional[int] = None, | |||
| save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | |||
| larger_better: bool = True, | |||
| only_state_dict: bool = True, | |||
| model_save_fn: Optional[Callable] = None, | |||
| **kwargs, | |||
| ): | |||
| if monitor is None and save_topk is not None: | |||
| @@ -51,9 +54,6 @@ class CheckpointCallback(Callback): | |||
| if monitor is not None and not isinstance(monitor, str): | |||
| raise ValueError("Parameter `monitor` should be of 'str' type.") | |||
| if not isinstance(is_trainer_checkpoint, bool): | |||
| raise TypeError("Parameter 'is_trainer_checkpoint' can only be `bool` type.") | |||
| if save_folder is None: | |||
| logger.warning( | |||
| "Parameter `path` is None, and we will use the current work directory to find and load your model.") | |||
| @@ -67,15 +67,15 @@ class CheckpointCallback(Callback): | |||
| if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1: | |||
| raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.") | |||
| # 突然发现有一个骚操作在于 'Filter' 内部记载的状态值例如 'num_called' 是这个类全局的,而每次调用 __call__ 中输入的 | |||
| # 函数却是及时传入的,也就是说,我们可以保证 'Filter' 的正常控制频率的逻辑,然后每一次运行的函数都不一样; | |||
| self._filter_every_n_epochs = Filter(every=save_every_n_epochs) | |||
| else: | |||
| save_every_n_epochs = sys.maxsize # 使得没有数字可以整除 | |||
| if save_every_n_global_batches is not None: | |||
| if not isinstance(save_every_n_global_batches, int) or save_every_n_global_batches < 1: | |||
| if save_every_n_batches is not None: | |||
| if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1: | |||
| raise ValueError( | |||
| "parameter save_every_n_global_batches should be an int and greater than or equal to 1.") | |||
| self._filter_every_n_global_batches = Filter(every=save_every_n_global_batches) | |||
| "parameter save_every_n_batches should be an int and greater than or equal to 1.") | |||
| else: | |||
| save_every_n_batches = sys.maxsize # 使得没有数字可以整除 | |||
| if save_topk is not None: | |||
| if not isinstance(save_topk, int) or save_topk < 1: | |||
| @@ -89,12 +89,12 @@ class CheckpointCallback(Callback): | |||
| if not issubclass(exception, BaseException): | |||
| raise TypeError("Each exception in parameter `save_on_exception` can only be " | |||
| "`BaseException` type.") | |||
| else: | |||
| save_on_exception = [] | |||
| self.monitor = monitor | |||
| self.is_trainer_checkpoint = is_trainer_checkpoint | |||
| self.save_folder = Path(save_folder) | |||
| self.save_every_n_epochs = save_every_n_epochs | |||
| self.save_every_n_global_batches = save_every_n_global_batches | |||
| self.save_every_n_batches = save_every_n_batches | |||
| self.save_last = save_last | |||
| self.save_topk = save_topk | |||
| self.larger_better = larger_better | |||
| @@ -107,7 +107,7 @@ class CheckpointCallback(Callback): | |||
| self._topk_model = {} | |||
| self._topn = 0 # 表示目前已经保存了几个最好的模型; | |||
| # 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用模糊匹配找到的第一个 | |||
| # 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的 | |||
| # key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 | |||
| # 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; | |||
| # 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; | |||
| @@ -115,76 +115,83 @@ class CheckpointCallback(Callback): | |||
| # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | |||
| # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | |||
| self.log_filepath = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||
| self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||
| # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | |||
| synchronize_mkdir(self.log_filepath) | |||
| synchronize_mkdir(self.timestamp_path) | |||
| def on_validate_end(self, trainer, validate_res): | |||
| self._save_topk(trainer, validate_res) | |||
| def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | |||
| self._save_every_n_epochs(trainer) | |||
| self._save_last(trainer) | |||
| if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: | |||
| folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}' | |||
| self.save(trainer, folder_name=folder_name) | |||
| if self.save_last: | |||
| folder_name = f'{self.folder_prefix}-last' | |||
| self.save(trainer, folder_name=folder_name) | |||
| def on_train_batch_end(self, trainer): | |||
| self._save_every_n_global_batches(trainer) | |||
| if trainer.global_forward_batches % self.save_every_n_batches == 0: | |||
| folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' | |||
| self.save(trainer, folder_name=folder_name) | |||
| def on_exception(self, trainer, exception: BaseException): | |||
| if self.save_on_exception is not None and exception.__class__ in self.save_on_exception: | |||
| folder = self._get_checkpoint_real_save_folder(trainer=trainer, topk=False, metric=None) | |||
| folder = folder + f"_{exception.__class__.__name__}" | |||
| self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder=folder) | |||
| if exception.__class__ in self.save_on_exception: | |||
| folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ | |||
| f'exception_{exception.__class__.__name__}' | |||
| self.save(trainer=trainer, folder_name=folder_name) | |||
| def on_sanity_check_end(self, trainer, sanity_check_res): | |||
| # 主要核对一下 monitor 是否存在。 | |||
| self._get_validate_metric(sanity_check_res) | |||
| def on_save_checkpoint(self, trainer) -> Dict: | |||
| """ | |||
| 我们需要保存 CheckpointCallback 内部的几个 filter 的状态; | |||
| 保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 | |||
| topk_model的状态 | |||
| _real_monitor的值 | |||
| """ | |||
| states = {} | |||
| if self.save_every_n_epochs is not None: | |||
| states["_filter_every_n_epochs"] = self._filter_every_n_epochs.state_dict() | |||
| if self.save_every_n_global_batches is not None: | |||
| states["_filter_every_n_global_batches"] = self._filter_every_n_global_batches.state_dict() | |||
| states["real_monitor"] = self._real_monitor | |||
| states['timestamp_path'] = str(self.timestamp_path.absolute()) | |||
| states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType, | |||
| function=lambda x:x.item()) | |||
| states['save_topk'] = 0 if self.save_topk is None else self.save_topk | |||
| states['_real_monitor'] = self._real_monitor | |||
| return states | |||
| def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||
| if self.save_every_n_epochs is not None: | |||
| self._filter_every_n_epochs.load_state_dict(states["_filter_every_n_epochs"]) | |||
| if self.save_every_n_global_batches is not None: | |||
| self._filter_every_n_global_batches.load_state_dict(states["_filter_every_n_global_batches"]) | |||
| timestamp_path = states['timestamp_path'] | |||
| if not os.path.exists(timestamp_path): | |||
| logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to " | |||
| f" {self.timestamp_path.absolute()}.") | |||
| else: | |||
| logger.info(f"Resume to save in path: {timestamp_path}.") | |||
| self.timestamp_path = Path(timestamp_path) | |||
| _topk_model = states['_topk_model'] | |||
| save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | |||
| if save_topk is not None and self.save_topk is not None: | |||
| assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | |||
| f"as {save_topk}." | |||
| self._topk_model.update(self._topk_model) | |||
| self._real_monitor = states["real_monitor"] | |||
| def _save_every_n_epochs(self, trainer: "fastNLP.Trainer"): | |||
| if self.save_every_n_epochs is not None: | |||
| if self.is_trainer_checkpoint: | |||
| _fn_every_n_epochs = trainer.save | |||
| else: | |||
| _fn_every_n_epochs = trainer.save_model | |||
| _fn_every_n_epochs = partial(self._save_fn, trainer, False, None, _fn_every_n_epochs, None) | |||
| _fn_every_n_epochs = self._filter_every_n_epochs(_fn_every_n_epochs) | |||
| _fn_every_n_epochs() | |||
| def _save_every_n_global_batches(self, trainer: "fastNLP.Trainer"): | |||
| if self.save_every_n_global_batches is not None: | |||
| if self.is_trainer_checkpoint: | |||
| _fn_every_n_global_batches = trainer.save | |||
| else: | |||
| _fn_every_n_global_batches = trainer.save_model | |||
| _fn_every_n_global_batches = partial(self._save_fn, trainer, False, None, _fn_every_n_global_batches, None) | |||
| _fn_every_n_global_batches = self._filter_every_n_global_batches(_fn_every_n_global_batches) | |||
| _fn_every_n_global_batches() | |||
| def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): | |||
| """ | |||
| 根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 | |||
| :param trainer: | |||
| :param validate_res: | |||
| :return: | |||
| """ | |||
| if self.save_topk is not None: | |||
| _metric_value = self._get_validate_metric(validate_res) | |||
| _saved_name = self._get_checkpoint_real_save_folder(trainer=trainer, topk=True, metric=_metric_value) | |||
| folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | |||
| f"-{self._real_monitor}_{_metric_value}" | |||
| _should_save = False | |||
| if self._topn < self.save_topk: | |||
| self._topk_model[_saved_name] = _metric_value | |||
| self._topk_model[folder_name] = _metric_value | |||
| self._topn += 1 | |||
| _should_save = True | |||
| else: | |||
| @@ -192,39 +199,27 @@ class CheckpointCallback(Callback): | |||
| key=lambda x: self._topk_model[x]) | |||
| if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ | |||
| (self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): | |||
| self._topk_model[_saved_name] = _metric_value | |||
| self._topk_model[folder_name] = _metric_value | |||
| _should_save = True | |||
| self._topk_model.pop(_least_valuable_model) | |||
| synchronize_safe_rm(self.log_filepath.joinpath(_least_valuable_model)) | |||
| synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) | |||
| assert len(self._topk_model) == self.save_topk == self._topn | |||
| if _should_save: | |||
| self._save_fn(trainer=trainer, topk=True, metric=_metric_value, substitute_folder=_saved_name) | |||
| self.save(trainer, folder_name=folder_name) | |||
| def _save_last(self, trainer: "fastNLP.Trainer"): | |||
| if self.save_last: | |||
| self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder="last") | |||
| def _save_fn(self, trainer, topk: bool = False, metric: Optional[Union[int, float]] = None, | |||
| substitute_fn: Optional[Callable] = None, substitute_folder: Optional[str] = None): | |||
| # 首先根据当前的 epoch 和 batch 在 parent_path/FASTNLP_LAUNCH_TIME 下创建子文件夹 epoch-batch-monitor 或者 | |||
| # epoch-batch-monitor-monitor_value; | |||
| if substitute_folder is None: | |||
| folder = self.log_filepath.joinpath(self._get_checkpoint_real_save_folder(trainer, topk, metric)) | |||
| else: | |||
| folder = self.log_filepath.joinpath(substitute_folder) | |||
| def save(self, trainer, folder_name): | |||
| """ | |||
| 执行保存的函数,将数据保存在 save_folder/timestamp/folder_name 下。 | |||
| :param trainer: | |||
| :param folder_name: | |||
| :return: | |||
| """ | |||
| folder = self.timestamp_path.joinpath(folder_name) | |||
| synchronize_mkdir(folder) | |||
| # 然后再调用 trainer 的 save_model(用于保存模型)或者 save(用于断点重训)函数; | |||
| if substitute_fn is not None: | |||
| _fn = substitute_fn | |||
| else: | |||
| if self.is_trainer_checkpoint: | |||
| _fn = trainer.save | |||
| else: | |||
| _fn = trainer.save_model | |||
| _fn = getattr(trainer, self.save_fn_name) | |||
| _fn( | |||
| folder=folder, | |||
| only_state_dict=self.only_state_dict, | |||
| @@ -243,18 +238,95 @@ class CheckpointCallback(Callback): | |||
| self._real_monitor = use_monitor | |||
| return value | |||
| def _get_checkpoint_real_save_folder(self, trainer: "fastNLP.Trainer", topk: bool = False, | |||
| metric: Optional[Union[int, float]] = None) -> str: | |||
| @property | |||
| def folder_prefix(self): | |||
| raise NotImplementedError("The `folder_prefix` is not specified") | |||
| @property | |||
| def save_fn_name(self): | |||
| raise NotImplementedError("The `save_fn_name` is not specified.") | |||
| class ModelCheckpointCallback(CheckpointCallback): | |||
| """ | |||
| 保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||
| - save_folder/ | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 | |||
| - model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 | |||
| - model-last/ # 最后一个 epoch 的保存 | |||
| - model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||
| - model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
| model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | |||
| 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | |||
| :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。 | |||
| :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
| :param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
| :param save_every_n_batches: 多少个 batch 保存一次。 | |||
| :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||
| :param save_topk: 保存 monitor 结果 topK 个。 | |||
| :param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||
| :param larger_better: monitor 的值是否时越大越好。 | |||
| :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||
| :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
| 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
| :param kwargs: | |||
| """ | |||
| @property | |||
| def save_fn_name(self): | |||
| return 'save_model' | |||
| @property | |||
| def callback_name(self): | |||
| """ | |||
| 获取当前保存模型的真正地名字; | |||
| metric 参数仅当 mode 为 'topk' 时起作用; | |||
| 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||
| :return: | |||
| """ | |||
| cur_epoch_idx = trainer.cur_epoch_idx | |||
| global_forward_batches = trainer.global_forward_batches | |||
| _other = "" | |||
| if topk: | |||
| _other = f"_{metric}" | |||
| return f"epoch_{cur_epoch_idx}-global_batch_{global_forward_batches}-{self._real_monitor}{_other}" | |||
| return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
| @property | |||
| def folder_prefix(self): | |||
| return 'model' | |||
| class TrainerCheckpointCallback(CheckpointCallback): | |||
| """ | |||
| 保存 Trainer checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||
| - save_folder/ | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 | |||
| - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 | |||
| - trainer-last/ # 最后一个 epoch 的保存 | |||
| - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||
| - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
| model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | |||
| 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | |||
| :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。 | |||
| :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
| :param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
| :param save_every_n_batches: 多少个 batch 保存一次。 | |||
| :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||
| :param save_topk: 保存 monitor 结果 topK 个。 | |||
| :param save_on_exception: 在出异常信息时,是否保存。 | |||
| :param larger_better: monitor 的值是否时越大越好。 | |||
| :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无意义。 | |||
| :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
| 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
| :param kwargs: | |||
| """ | |||
| @property | |||
| def save_fn_name(self): | |||
| return 'save' | |||
| @property | |||
| def callback_name(self): | |||
| @@ -262,6 +334,8 @@ class CheckpointCallback(Callback): | |||
| 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||
| :return: | |||
| """ | |||
| return f"monitor-{self.monitor}#trainer_checkpoint-{self.is_trainer_checkpoint}#only_state_dict-{self.only_state_dict}" | |||
| return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
| @property | |||
| def folder_prefix(self): | |||
| return 'trainer' | |||
| @@ -31,7 +31,7 @@ class LoadBestModelCallback(Callback): | |||
| 请在函数内完成对模型的保存。 | |||
| :param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||
| 请在函数内完成对模型的加载。 | |||
| :param delete_after_train: 在加载了最佳模型之后是否删掉模型。 | |||
| :param delete_after_train: 在训练结束后是否删掉模型。 | |||
| """ | |||
| if model_load_fn is not None: | |||
| assert callable(model_load_fn), "`model_load_fn` must be a callable object." | |||
| @@ -133,17 +133,18 @@ class Evaluator: | |||
| self.driver.barrier() | |||
| def run(self, num_eval_batch_per_dl: int = -1) -> Dict: | |||
| def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | |||
| """ | |||
| 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | |||
| 如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||
| metric_indicator_name#metric_name | |||
| 如果存在多个数据集,一个metric的情况,key的命名规则是 | |||
| metric_indicator_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||
| 如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||
| metric_indicator_name#metric_name#dataloader_name | |||
| :param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 | |||
| 如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||
| metric_indicator_name#metric_name | |||
| 如果存在多个数据集,一个metric的情况,key的命名规则是 | |||
| metric_indicator_name#metric_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||
| 如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||
| metric_indicator_name#metric_name#dataloader_name | |||
| 其中 metric_indicator_name 可能不存在。 | |||
| :param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 | |||
| :return: | |||
| """ | |||
| assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | |||
| @@ -157,7 +158,6 @@ class Evaluator: | |||
| assert self.driver.has_test_dataloaders() | |||
| metric_results = {} | |||
| self.reset() | |||
| evaluate_context = self.driver.get_evaluate_context() | |||
| self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train') | |||
| @@ -251,7 +251,7 @@ class Trainer(TrainerEventTrigger): | |||
| self.driver.set_deterministic_dataloader(self.dataloader) | |||
| self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | |||
| reproducible=self.callback_manager.has_trainer_chechpoint) | |||
| reproducible=self.callback_manager.has_trainer_checkpoint) | |||
| self.set_grad_to_none = kwargs.get("set_grad_to_none", True) | |||
| self.on_after_trainer_initialized(self.driver) | |||
| @@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger): | |||
| raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | |||
| if self.evaluator is not None and num_eval_sanity_batch > 0: | |||
| logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.") | |||
| self.on_sanity_check_begin() | |||
| sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) | |||
| self.on_sanity_check_end(sanity_check_res) | |||
| @@ -509,7 +510,7 @@ class Trainer(TrainerEventTrigger): | |||
| :param folder: 保存模型的地址; | |||
| :param only_state_dict: 是否只保存模型的 `state_dict`; | |||
| :param save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; | |||
| :param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; | |||
| :param kwargs: 一些 driver 的保存模型的函数的参数另有其它; | |||
| """ | |||
| @@ -534,7 +535,16 @@ class Trainer(TrainerEventTrigger): | |||
| def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False, | |||
| model_load_fn: Optional[Callable] = None, **kwargs): | |||
| """ | |||
| 加载模型 | |||
| :param folder: 读取 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, | |||
| 直接将该 folder 传递到 model_load_fn 中。 | |||
| :param only_state_dict: 要读取的文件中是否仅包含模型权重。在 model_load_fn 不为 None 时,该参数无意义。 | |||
| :param model_load_fn: callable 的函数,接受一个 folder 作为参数,不返回任何内容。 | |||
| :param kwargs: | |||
| :return: | |||
| """ | |||
| self.on_load_model() | |||
| self.driver.barrier() | |||
| if not isinstance(folder, (io.BytesIO, BinaryIO)): | |||
| @@ -555,7 +565,13 @@ class Trainer(TrainerEventTrigger): | |||
| def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | |||
| r""" | |||
| 用于断点重训的保存函数; | |||
| 用于断点重训 Trainer 的保存函数; | |||
| :param folder: | |||
| :param only_state_dict: | |||
| :param model_save_fn: | |||
| :param kwargs: | |||
| :return: | |||
| """ | |||
| self.driver.barrier() | |||
| @@ -8,9 +8,8 @@ __all__ = [ | |||
| import _pickle as pickle | |||
| from copy import deepcopy | |||
| from typing import Optional, List, Callable, Union, Dict, Any | |||
| from typing import Optional, List, Callable, Union, Dict, Any, Mapping | |||
| from functools import partial | |||
| import warnings | |||
| import numpy as np | |||
| from threading import Thread | |||
| @@ -197,6 +196,20 @@ class DataSet: | |||
| else: | |||
| raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
| def __setitem__(self, key, value): | |||
| assert isinstance(key, int) and key<len(self) | |||
| assert isinstance(value, Instance) or isinstance(value, Mapping) | |||
| ins_keys = set(value.keys()) | |||
| ds_keys = set(self.get_field_names()) | |||
| if len(ins_keys - ds_keys) != 0: | |||
| raise KeyError(f"The following keys are not found in the Dataset:{list(ins_keys - ds_keys)}.") | |||
| if len(ds_keys - ins_keys) != 0: | |||
| raise KeyError(f"The following keys are not found in the Instance:{list(ds_keys - ins_keys)}.") | |||
| for field_name, field in self.field_arrays.items(): | |||
| field[key] = value[field_name] | |||
| def __getattribute__(self, item): | |||
| return object.__getattribute__(self, item) | |||
| @@ -813,6 +826,3 @@ class DataSet: | |||
| self.collate_fns.set_input(*field_names) | |||
| class IterableDataset: | |||
| pass | |||
| @@ -46,9 +46,6 @@ class FieldArray: | |||
| def __setitem__(self, idx: int, val: Any): | |||
| assert isinstance(idx, int) | |||
| if idx == -1: | |||
| idx = len(self) - 1 | |||
| assert 0 <= idx < len(self), f"0<= idx <{len(self)}, but idx is {idx}" | |||
| self.content[idx] = val | |||
| def get(self, indices: Union[int, List[int]]): | |||
| @@ -79,7 +76,7 @@ class FieldArray: | |||
| def split(self, sep: str = None, inplace: bool = True): | |||
| r""" | |||
| 依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 | |||
| 依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。 | |||
| :param sep: 分割符,如果为None则直接调用str.split()。 | |||
| :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
| @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod | |||
| from datetime import datetime | |||
| from pathlib import Path | |||
| from io import BytesIO | |||
| import json | |||
| __all__ = [ | |||
| 'Driver' | |||
| @@ -68,9 +69,12 @@ class Driver(ABC): | |||
| def set_sampler_epoch(self, dataloader, cur_epoch_idx): | |||
| r""" | |||
| 对于分布式的 sampler,例如 torch 的 DistributedSampler,其需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | |||
| dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler。 | |||
| :param dataloader: 需要设置 epoch 的 dataloader 。 | |||
| :param cur_epoch_idx: 当前是第几个 epoch; | |||
| """ | |||
| @abstractmethod | |||
| def train_step(self, batch): | |||
| """ | |||
| @@ -444,13 +448,14 @@ class Driver(ABC): | |||
| exc_type, exc_value, exc_traceback_obj = sys.exc_info() | |||
| _write_exc_info = { | |||
| 'exc_type': exc_type, | |||
| 'exc_value': exc_value, | |||
| 'time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')), | |||
| 'global_rank': getattr(self, "global_rank", None), | |||
| 'rank': self.get_local_rank(), | |||
| 'exc_type': str(exc_type.__name__), | |||
| 'exc_value': str(exc_value), | |||
| 'exc_time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')), | |||
| 'exc_global_rank': getattr(self, "global_rank", None), | |||
| 'exc_local_rank': self.get_local_rank(), | |||
| } | |||
| sys.stderr.write(str(_write_exc_info)+"\n") | |||
| sys.stderr.write("\nException info:\n") | |||
| sys.stderr.write(json.dumps(_write_exc_info, indent=2)+"\n") | |||
| sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n") | |||
| for pid in self._pids: | |||
| @@ -402,7 +402,7 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||
| if _TORCH_GREATER_EQUAL_1_8: | |||
| objs = [None for _ in range(dist.get_world_size(group))] | |||
| dist.all_gather_object(objs, obj) | |||
| apply_to_collection(obj, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 | |||
| objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 | |||
| return objs | |||
| group = group if group is not None else torch.distributed.group.WORLD | |||
| data = convert_to_tensors(obj, device=device) | |||
| @@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||
| # world_size 和 rank | |||
| if FASTNLP_BACKEND_LAUNCH in os.environ: | |||
| if device is not None: | |||
| logger.warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||
| logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||
| "up your script. And we will directly get the local device via " | |||
| "`os.environ['LOCAL_RANK']`.") | |||
| return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | |||
| @@ -39,11 +39,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||
| if isinstance(device, str): | |||
| device = torch.device(device) | |||
| elif isinstance(device, int): | |||
| if device < 0 and device != -1: | |||
| raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | |||
| if device >= _could_use_device_num: | |||
| if device < 0: | |||
| if device != -1: | |||
| raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | |||
| device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)] | |||
| elif device >= _could_use_device_num: | |||
| raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||
| device = torch.device(f"cuda:{device}") | |||
| else: | |||
| device = torch.device(f"cuda:{device}") | |||
| elif isinstance(device, Sequence): | |||
| device = list(set(device)) | |||
| for each in device: | |||
| @@ -62,7 +65,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||
| if not isinstance(device, List): | |||
| return TorchSingleDriver(model, device, **kwargs) | |||
| else: | |||
| logger.warning("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " | |||
| logger.info("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " | |||
| "`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter" | |||
| "`driver` as `TorchDDPDriver`.") | |||
| return TorchDDPDriver(model, device, **kwargs) | |||
| @@ -143,8 +143,6 @@ class TorchDriver(Driver): | |||
| :param filepath: 保存到哪个文件夹; | |||
| :param only_state_dict: 是否只保存权重; | |||
| :param model_save_fn: | |||
| :return: | |||
| """ | |||
| model = self.unwrap_model() | |||
| @@ -33,8 +33,7 @@ class TorchBackend(Backend): | |||
| if dist.is_initialized(): | |||
| if method is None: | |||
| raise AggregateMethodError(should_have_aggregate_method=True) | |||
| tensor = self._gather_all(tensor) | |||
| # tensor = self.all_gather_object(tensor) | |||
| tensor = fastnlp_torch_all_gather(tensor) | |||
| if isinstance(tensor[0], torch.Tensor): | |||
| tensor = torch.stack(tensor) | |||
| # 第一步, aggregate结果 | |||
| @@ -69,59 +68,6 @@ class TorchBackend(Backend): | |||
| def get_scalar(self, tensor) -> float: | |||
| return tensor.item() | |||
| @staticmethod | |||
| def _gather_all(result, group: Optional[Any] = None) -> List: | |||
| """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes. | |||
| Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case | |||
| tensors are padded, gathered and then trimmed to secure equal workload for all processes. | |||
| Args: | |||
| result: the value to sync | |||
| group: the process group to gather results from. Defaults to all processes (world) | |||
| Return: | |||
| gathered_result: list with size equal to the process group where | |||
| gathered_result[i] corresponds to result tensor from process i | |||
| """ | |||
| if group is None: | |||
| group = dist.group.WORLD | |||
| # convert tensors to contiguous format | |||
| result = result.contiguous() | |||
| world_size = dist.get_world_size(group) | |||
| dist.barrier(group=group) | |||
| # if the tensor is scalar, things are easy | |||
| if result.ndim == 0: | |||
| return _simple_gather_all_tensors(result, group, world_size) | |||
| # 1. Gather sizes of all tensors | |||
| local_size = torch.tensor(result.shape, device=result.device) | |||
| local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] | |||
| dist.all_gather(local_sizes, local_size, group=group) | |||
| max_size = torch.stack(local_sizes).max(dim=0).values | |||
| all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) | |||
| # 2. If shapes are all the same, then do a simple gather: | |||
| if all_sizes_equal: | |||
| return _simple_gather_all_tensors(result, group, world_size) | |||
| # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate | |||
| pad_dims = [] | |||
| pad_by = (max_size - local_size).detach().cpu() | |||
| for val in reversed(pad_by): | |||
| pad_dims.append(0) | |||
| pad_dims.append(val.item()) | |||
| result_padded = torch.nn.functional.pad(result, pad_dims) | |||
| gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] | |||
| dist.all_gather(gathered_result, result_padded, group) | |||
| for idx, item_size in enumerate(local_sizes): | |||
| slice_param = [slice(dim_size) for dim_size in item_size] | |||
| gathered_result[idx] = gathered_result[idx][slice_param] | |||
| return gathered_result | |||
| def tensor2numpy(self, tensor) -> np.array: | |||
| """ | |||
| 将对应的tensor转为numpy对象 | |||
| @@ -11,12 +11,12 @@ from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||
| class Element: | |||
| def __init__(self, value: float, aggregate_method, backend: Backend, name=None): | |||
| def __init__(self, name, value: float, aggregate_method, backend: Backend): | |||
| self.name = name | |||
| self.init_value = value | |||
| self.aggregate_method = aggregate_method | |||
| self.name = name | |||
| if backend == 'auto': | |||
| raise RuntimeError("You have to specify the backend.") | |||
| raise RuntimeError(f"You have to specify the backend for Element:{self.name}.") | |||
| elif isinstance(backend, AutoBackend): | |||
| self.backend = backend | |||
| else: | |||
| @@ -41,14 +41,9 @@ class Element: | |||
| msg = 'If you see this message, please report a bug.' | |||
| if self.name and e.should_have_aggregate_method: | |||
| msg = f"Element:{self.name} has no specified `aggregate_method`." | |||
| elif e.should_have_aggregate_method: | |||
| msg = "Element has no specified `aggregate_method`." | |||
| elif self.name and not e.should_have_aggregate_method: | |||
| msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \ | |||
| f'aggregate_method:{self.aggregate_method}.' | |||
| elif not e.should_have_aggregate_method: | |||
| msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \ | |||
| f'aggregate_method:{self.aggregate_method}.' | |||
| if e.only_warn: | |||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
| logger.warning(msg) | |||
| @@ -97,7 +92,7 @@ class Element: | |||
| def _check_value_when_call(self): | |||
| if self.value is None: | |||
| prefix = f'Element:`{self.name}`' if self.name else 'Element' | |||
| prefix = f'Element:`{self.name}`' | |||
| raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | |||
| "element, or use it after it being used by the `Metric.compute()` method.") | |||
| @@ -275,9 +270,10 @@ class Element: | |||
| """ | |||
| try: | |||
| if self._value is None: | |||
| prefix = f'Element:`{self.name}`' if self.name else 'Element' | |||
| prefix = f'Element:`{self.name}`' | |||
| raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | |||
| "element, or use it after it being used by the `Metric.compute()` method.") | |||
| return getattr(self._value, item) | |||
| except AttributeError as e: | |||
| logger.error(f"Element:{self.name} has no `{item}` attribute.") | |||
| raise e | |||
| @@ -35,7 +35,7 @@ class Metric: | |||
| def elements(self) -> dict: | |||
| return self._elements | |||
| def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element: | |||
| def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element: | |||
| """ | |||
| 注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 | |||
| tensor 直接进行加减乘除计算即可。 | |||
| @@ -57,11 +57,9 @@ class Metric: | |||
| else: | |||
| backend = AutoBackend(backend) | |||
| # 当name为None,默认为变量取得变量名 | |||
| if name is None: | |||
| name = f'ele_var_{len(self._elements)}' | |||
| assert name is not None and name not in self.elements | |||
| element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name) | |||
| element = Element(name=name, value=value, aggregate_method=aggregate_method, backend=backend) | |||
| self.elements[name] = element | |||
| setattr(self, name, element) | |||
| return element | |||
| @@ -219,6 +219,23 @@ class SpanFPreRecMetric(Metric): | |||
| def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | |||
| only_gross: bool = True, f_type='micro', | |||
| beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: | |||
| r""" | |||
| :param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||
| 在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||
| :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||
| :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||
| :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||
| :param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. | |||
| :param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label | |||
| :param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec | |||
| :param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||
| :param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
| :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() | |||
| 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 | |||
| :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, | |||
| 当 backend 不支持分布式时,该参数无意义。 | |||
| """ | |||
| super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
| if f_type not in ('micro', 'macro'): | |||
| raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
| @@ -255,7 +272,7 @@ class SpanFPreRecMetric(Metric): | |||
| for word, _ in tag_vocab: | |||
| word = word.lower() | |||
| if word != 'o': | |||
| word = word.split('-')[1] | |||
| word = word[2:] | |||
| if word in self._true_positives: | |||
| continue | |||
| self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) | |||
| @@ -266,8 +283,8 @@ class SpanFPreRecMetric(Metric): | |||
| evaluate_result = {} | |||
| if not self.only_gross or self.f_type == 'macro': | |||
| tags = set(self._false_negatives.keys()) | |||
| tags.update(set(self._false_positives.keys())) | |||
| tags.update(set(self._true_positives.keys())) | |||
| tags.update(self._false_positives.keys()) | |||
| tags.update(self._true_positives.keys()) | |||
| f_sum = 0 | |||
| pre_sum = 0 | |||
| rec_sum = 0 | |||
| @@ -275,6 +292,9 @@ class SpanFPreRecMetric(Metric): | |||
| tp = self._true_positives[tag].get_scalar() | |||
| fn = self._false_negatives[tag].get_scalar() | |||
| fp = self._false_positives[tag].get_scalar() | |||
| if tp == fn == fp == 0: | |||
| continue | |||
| f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | |||
| f_sum += f | |||
| pre_sum += pre | |||
| @@ -96,6 +96,7 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
| # start new | |||
| self.start() | |||
| self.console.show_cursor(show=True) | |||
| return self | |||
| def set_transient(self, transient: bool = True): | |||
| @@ -149,6 +150,9 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
| super().stop_task(task_id) | |||
| super().remove_task(task_id) | |||
| def start(self) -> None: | |||
| super().start() | |||
| self.console.show_cursor(show=True) | |||
| if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
| f_rich_progress = FRichProgress().new_progess( | |||
| @@ -161,7 +165,7 @@ if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
| TextColumn("{task.fields[post_desc]}", justify="right"), | |||
| transient=True, | |||
| disable=False, | |||
| speed_estimate_period=10 | |||
| speed_estimate_period=1 | |||
| ) | |||
| else: | |||
| f_rich_progress = DummyFRichProgress() | |||
| @@ -44,6 +44,9 @@ __all__ = [ | |||
| ] | |||
| def get_fn_arg_names(fn: Callable) -> List[str]: | |||
| r""" | |||
| 返回一个函数的所有参数的名字; | |||
| @@ -150,7 +150,7 @@ def seed_jittor_global_seed(global_seed): | |||
| pass | |||
| def dump_fastnlp_backend(default:bool = False): | |||
| def dump_fastnlp_backend(default:bool = False, backend=None): | |||
| """ | |||
| 将 fastNLP 的设置写入到 ~/.fastNLP/envs/ 文件夹下, | |||
| 若 default 为 True,则保存的文件为 ~/.fastNLP/envs/default.json 。 | |||
| @@ -162,6 +162,7 @@ def dump_fastnlp_backend(default:bool = False): | |||
| 会保存的环境变量为 FASTNLP_BACKEND 。 | |||
| :param default: | |||
| :param backend: 保存使用的 backend 为哪个值,允许的值有 ['torch', 'paddle', 'jittor']。如果为 None ,则使用环境变量中的值。 | |||
| :return: | |||
| """ | |||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
| @@ -176,10 +177,16 @@ def dump_fastnlp_backend(default:bool = False): | |||
| os.makedirs(os.path.dirname(env_path), exist_ok=True) | |||
| envs = {} | |||
| if FASTNLP_BACKEND in os.environ: | |||
| envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] | |||
| assert backend in SUPPORT_BACKENDS, f"fastNLP only supports {SUPPORT_BACKENDS} right now." | |||
| if backend is None: | |||
| if FASTNLP_BACKEND in os.environ: | |||
| envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] | |||
| else: | |||
| envs[FASTNLP_BACKEND] = backend | |||
| if len(envs): | |||
| with open(env_path, 'w', encoding='utf8') as f: | |||
| json.dump(fp=f, obj=envs) | |||
| print(f"Writing the default fastNLP backend:{envs[FASTNLP_BACKEND]} to {env_path}.") | |||
| else: | |||
| raise RuntimeError("No backend specified.") | |||
| @@ -48,7 +48,8 @@ def set_env_on_import_paddle(): | |||
| # TODO jittor may need set this | |||
| def set_env_on_import_jittor(): | |||
| # todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH | |||
| pass | |||
| if 'log_silent' not in os.environ: | |||
| os.environ['log_silent'] = '1' | |||
| def set_env_on_import(): | |||
| @@ -64,7 +65,7 @@ def set_env_on_import(): | |||
| # fastNLP 内部使用的一些变量 | |||
| if FASTNLP_LAUNCH_TIME not in os.environ: | |||
| cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%M_%f')}" | |||
| cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')}" | |||
| os.environ[FASTNLP_LAUNCH_TIME] = cur_time | |||
| # 设置对应的值 | |||
| @@ -8,7 +8,7 @@ import torch.distributed as dist | |||
| from pathlib import Path | |||
| import re | |||
| from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback | |||
| from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||
| from fastNLP.core.controllers.trainer import Trainer | |||
| from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||
| @@ -80,16 +80,23 @@ def test_model_checkpoint_callback_1( | |||
| version, | |||
| only_state_dict | |||
| ): | |||
| # def test_model_checkpoint_callback_1( | |||
| # model_and_optimizers: TrainerParameters, | |||
| # driver='torch_ddp', | |||
| # device=[0, 1], | |||
| # version=1, | |||
| # only_state_dict=True | |||
| # ): | |||
| path = Path.cwd().joinpath(f"test_model_checkpoint") | |||
| path.mkdir(exist_ok=True, parents=True) | |||
| if version == 0: | |||
| callbacks = [ | |||
| CheckpointCallback( | |||
| ModelCheckpointCallback( | |||
| monitor="acc", | |||
| save_folder=path, | |||
| save_every_n_epochs=1, | |||
| save_every_n_global_batches=123, # 避免和 epoch 的保存重复; | |||
| save_every_n_batches=123, # 避免和 epoch 的保存重复; | |||
| save_topk=None, | |||
| save_last=False, | |||
| save_on_exception=None, | |||
| @@ -98,11 +105,11 @@ def test_model_checkpoint_callback_1( | |||
| ] | |||
| elif version == 1: | |||
| callbacks = [ | |||
| CheckpointCallback( | |||
| ModelCheckpointCallback( | |||
| monitor="acc", | |||
| save_folder=path, | |||
| save_every_n_epochs=3, | |||
| save_every_n_global_batches=None, | |||
| save_every_n_batches=None, | |||
| save_topk=2, | |||
| save_last=True, | |||
| save_on_exception=None, | |||
| @@ -121,7 +128,6 @@ def test_model_checkpoint_callback_1( | |||
| input_mapping=model_and_optimizers.input_mapping, | |||
| output_mapping=model_and_optimizers.output_mapping, | |||
| metrics=model_and_optimizers.metrics, | |||
| n_epochs=10, | |||
| callbacks=callbacks, | |||
| output_from_new_proc="all" | |||
| @@ -134,31 +140,31 @@ def test_model_checkpoint_callback_1( | |||
| if version == 0: | |||
| if driver == "torch": | |||
| assert "epoch_10-global_batch_250-acc" in all_saved_model_paths | |||
| assert "epoch_4-global_batch_123-acc" in all_saved_model_paths | |||
| assert "model-epoch_10" in all_saved_model_paths | |||
| assert "model-epoch_4-batch_123" in all_saved_model_paths | |||
| epoch_save_path = all_saved_model_paths["epoch_10-global_batch_250-acc"] | |||
| step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"] | |||
| epoch_save_path = all_saved_model_paths["model-epoch_10"] | |||
| step_save_path = all_saved_model_paths["model-epoch_4-batch_123"] | |||
| assert len(all_saved_model_paths) == 12 | |||
| # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||
| else: | |||
| assert "epoch_6-global_batch_78-acc" in all_saved_model_paths | |||
| assert "epoch_9-global_batch_123-acc" in all_saved_model_paths | |||
| assert "model-epoch_6" in all_saved_model_paths | |||
| assert "model-epoch_9-batch_123" in all_saved_model_paths | |||
| epoch_save_path = all_saved_model_paths["epoch_6-global_batch_78-acc"] | |||
| step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"] | |||
| epoch_save_path = all_saved_model_paths["model-epoch_6"] | |||
| step_save_path = all_saved_model_paths["model-epoch_9-batch_123"] | |||
| assert len(all_saved_model_paths) == 11 | |||
| all_state_dicts = [epoch_save_path, step_save_path] | |||
| elif version == 1: | |||
| pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") | |||
| pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | |||
| if driver == "torch": | |||
| assert "epoch_9-global_batch_225-acc" in all_saved_model_paths | |||
| assert "last" in all_saved_model_paths | |||
| assert "model-epoch_9" in all_saved_model_paths | |||
| assert "model-last" in all_saved_model_paths | |||
| aLL_topk_folders = [] | |||
| for each_folder_name in all_saved_model_paths: | |||
| each_folder_name = pattern.findall(each_folder_name) | |||
| @@ -166,15 +172,15 @@ def test_model_checkpoint_callback_1( | |||
| aLL_topk_folders.append(each_folder_name[0]) | |||
| assert len(aLL_topk_folders) == 2 | |||
| epoch_save_path = all_saved_model_paths["epoch_9-global_batch_225-acc"] | |||
| last_save_path = all_saved_model_paths["last"] | |||
| epoch_save_path = all_saved_model_paths["model-epoch_9"] | |||
| last_save_path = all_saved_model_paths["model-last"] | |||
| topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||
| assert len(all_saved_model_paths) == 6 | |||
| # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||
| else: | |||
| assert "epoch_9-global_batch_117-acc" in all_saved_model_paths | |||
| assert "last" in all_saved_model_paths | |||
| assert "model-epoch_9" in all_saved_model_paths | |||
| assert "model-last" in all_saved_model_paths | |||
| aLL_topk_folders = [] | |||
| for each_folder_name in all_saved_model_paths: | |||
| @@ -183,8 +189,8 @@ def test_model_checkpoint_callback_1( | |||
| aLL_topk_folders.append(each_folder_name[0]) | |||
| assert len(aLL_topk_folders) == 2 | |||
| epoch_save_path = all_saved_model_paths["epoch_9-global_batch_117-acc"] | |||
| last_save_path = all_saved_model_paths["last"] | |||
| epoch_save_path = all_saved_model_paths["model-epoch_9"] | |||
| last_save_path = all_saved_model_paths["model-last"] | |||
| topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||
| assert len(all_saved_model_paths) == 6 | |||
| @@ -212,7 +218,7 @@ def test_model_checkpoint_callback_1( | |||
| finally: | |||
| synchronize_safe_rm(path) | |||
| # pass | |||
| pass | |||
| if dist.is_initialized(): | |||
| dist.destroy_process_group() | |||
| @@ -238,11 +244,11 @@ def test_model_checkpoint_callback_2( | |||
| raise NotImplementedError | |||
| callbacks = [ | |||
| CheckpointCallback( | |||
| ModelCheckpointCallback( | |||
| monitor="acc1", | |||
| save_folder=path, | |||
| save_every_n_epochs=None, | |||
| save_every_n_global_batches=None, | |||
| save_every_n_batches=None, | |||
| save_topk=None, | |||
| save_last=False, | |||
| save_on_exception=NotImplementedError, | |||
| @@ -279,12 +285,12 @@ def test_model_checkpoint_callback_2( | |||
| all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||
| if driver == "torch": | |||
| assert "epoch_4-global_batch_100-acc_NotImplementedError" in all_saved_model_paths | |||
| exception_model_path = all_saved_model_paths["epoch_4-global_batch_100-acc_NotImplementedError"] | |||
| assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths | |||
| exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"] | |||
| # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||
| else: | |||
| assert "epoch_4-global_batch_52-acc_NotImplementedError" in all_saved_model_paths | |||
| exception_model_path = all_saved_model_paths["epoch_4-global_batch_52-acc_NotImplementedError"] | |||
| assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths | |||
| exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"] | |||
| assert len(all_saved_model_paths) == 1 | |||
| all_state_dicts = [exception_model_path] | |||
| @@ -332,12 +338,11 @@ def test_trainer_checkpoint_callback_1( | |||
| if version == 0: | |||
| callbacks = [ | |||
| CheckpointCallback( | |||
| TrainerCheckpointCallback( | |||
| monitor="acc", | |||
| is_trainer_checkpoint=True, | |||
| save_folder=path, | |||
| save_every_n_epochs=7, | |||
| save_every_n_global_batches=123, # 避免和 epoch 的保存重复; | |||
| save_every_n_batches=123, # 避免和 epoch 的保存重复; | |||
| save_topk=None, | |||
| save_last=False, | |||
| save_on_exception=None, | |||
| @@ -346,12 +351,11 @@ def test_trainer_checkpoint_callback_1( | |||
| ] | |||
| elif version == 1: | |||
| callbacks = [ | |||
| CheckpointCallback( | |||
| TrainerCheckpointCallback( | |||
| monitor="acc", | |||
| is_trainer_checkpoint=True, | |||
| save_folder=path, | |||
| save_every_n_epochs=None, | |||
| save_every_n_global_batches=None, | |||
| save_every_n_batches=None, | |||
| save_topk=2, | |||
| save_last=True, | |||
| save_on_exception=None, | |||
| @@ -383,31 +387,31 @@ def test_trainer_checkpoint_callback_1( | |||
| if version == 0: | |||
| if driver == "torch": | |||
| assert "epoch_7-global_batch_175-acc" in all_saved_model_paths | |||
| assert "epoch_4-global_batch_123-acc" in all_saved_model_paths | |||
| assert "trainer-epoch_7" in all_saved_model_paths | |||
| assert "trainer-epoch_4-batch_123" in all_saved_model_paths | |||
| epoch_save_path = all_saved_model_paths["epoch_7-global_batch_175-acc"] | |||
| step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"] | |||
| epoch_save_path = all_saved_model_paths["trainer-epoch_7"] | |||
| step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"] | |||
| assert len(all_saved_model_paths) == 3 | |||
| # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||
| else: | |||
| assert "epoch_7-global_batch_91-acc" in all_saved_model_paths | |||
| assert "epoch_9-global_batch_123-acc" in all_saved_model_paths | |||
| assert "trainer-epoch_7" in all_saved_model_paths | |||
| assert "trainer-epoch_9-batch_123" in all_saved_model_paths | |||
| epoch_save_path = all_saved_model_paths["epoch_7-global_batch_91-acc"] | |||
| step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"] | |||
| epoch_save_path = all_saved_model_paths["trainer-epoch_7"] | |||
| step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"] | |||
| assert len(all_saved_model_paths) == 2 | |||
| all_state_dicts = [epoch_save_path, step_save_path] | |||
| elif version == 1: | |||
| pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") | |||
| pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | |||
| # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||
| if driver == "torch": | |||
| assert "last" in all_saved_model_paths | |||
| assert "trainer-last" in all_saved_model_paths | |||
| aLL_topk_folders = [] | |||
| for each_folder_name in all_saved_model_paths: | |||
| each_folder_name = pattern.findall(each_folder_name) | |||
| @@ -415,13 +419,13 @@ def test_trainer_checkpoint_callback_1( | |||
| aLL_topk_folders.append(each_folder_name[0]) | |||
| assert len(aLL_topk_folders) == 2 | |||
| last_save_path = all_saved_model_paths["last"] | |||
| last_save_path = all_saved_model_paths["trainer-last"] | |||
| topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||
| assert len(all_saved_model_paths) == 3 | |||
| # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||
| else: | |||
| assert "last" in all_saved_model_paths | |||
| assert "trainer-last" in all_saved_model_paths | |||
| aLL_topk_folders = [] | |||
| for each_folder_name in all_saved_model_paths: | |||
| @@ -430,7 +434,7 @@ def test_trainer_checkpoint_callback_1( | |||
| aLL_topk_folders.append(each_folder_name[0]) | |||
| assert len(aLL_topk_folders) == 2 | |||
| last_save_path = all_saved_model_paths["last"] | |||
| last_save_path = all_saved_model_paths["trainer-last"] | |||
| topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||
| assert len(all_saved_model_paths) == 3 | |||
| @@ -474,10 +478,11 @@ def test_trainer_checkpoint_callback_2( | |||
| device, | |||
| version | |||
| ): | |||
| pytest.skip("Skip transformers test for now.") | |||
| path = Path.cwd().joinpath(f"test_model_checkpoint") | |||
| path.mkdir(exist_ok=True, parents=True) | |||
| import transformers | |||
| import transformers # 版本4.16.2 | |||
| import torch | |||
| from torchmetrics import Accuracy | |||
| from transformers import AutoModelForSequenceClassification | |||
| @@ -587,12 +592,11 @@ def test_trainer_checkpoint_callback_2( | |||
| if version == 0: | |||
| callbacks = [ | |||
| CheckpointCallback( | |||
| TrainerCheckpointCallback( | |||
| monitor="acc", | |||
| is_trainer_checkpoint=True, | |||
| save_folder=path, | |||
| save_every_n_epochs=None, | |||
| save_every_n_global_batches=50, | |||
| save_every_n_batches=50, | |||
| save_topk=None, | |||
| save_last=False, | |||
| save_on_exception=None, | |||
| @@ -601,12 +605,11 @@ def test_trainer_checkpoint_callback_2( | |||
| ] | |||
| elif version == 1: | |||
| callbacks = [ | |||
| CheckpointCallback( | |||
| TrainerCheckpointCallback( | |||
| monitor="acc", | |||
| is_trainer_checkpoint=True, | |||
| save_folder=path, | |||
| save_every_n_epochs=None, | |||
| save_every_n_global_batches=None, | |||
| save_every_n_batches=None, | |||
| save_topk=1, | |||
| save_last=True, | |||
| save_on_exception=None, | |||
| @@ -638,27 +641,27 @@ def test_trainer_checkpoint_callback_2( | |||
| if version == 0: | |||
| if driver == "torch": | |||
| assert "epoch_1-global_batch_200-acc" in all_saved_model_paths | |||
| assert "trainer-epoch_1-batch_200" in all_saved_model_paths | |||
| epoch_save_path = all_saved_model_paths["epoch_1-global_batch_200-acc"] | |||
| epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"] | |||
| assert len(all_saved_model_paths) == 4 | |||
| # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||
| else: | |||
| assert "epoch_1-global_batch_100-acc" in all_saved_model_paths | |||
| assert "trainer-epoch_1-batch_100" in all_saved_model_paths | |||
| epoch_save_path = all_saved_model_paths["epoch_1-global_batch_100-acc"] | |||
| epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_100"] | |||
| assert len(all_saved_model_paths) == 2 | |||
| all_state_dicts = [epoch_save_path] | |||
| elif version == 1: | |||
| pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") | |||
| pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | |||
| # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||
| if driver == "torch": | |||
| assert "last" in all_saved_model_paths | |||
| assert "trainer-last" in all_saved_model_paths | |||
| aLL_topk_folders = [] | |||
| for each_folder_name in all_saved_model_paths: | |||
| each_folder_name = pattern.findall(each_folder_name) | |||
| @@ -666,13 +669,13 @@ def test_trainer_checkpoint_callback_2( | |||
| aLL_topk_folders.append(each_folder_name[0]) | |||
| assert len(aLL_topk_folders) == 1 | |||
| last_save_path = all_saved_model_paths["last"] | |||
| last_save_path = all_saved_model_paths["trainer-last"] | |||
| topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||
| assert len(all_saved_model_paths) == 2 | |||
| # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||
| else: | |||
| assert "last" in all_saved_model_paths | |||
| assert "trainer-last" in all_saved_model_paths | |||
| aLL_topk_folders = [] | |||
| for each_folder_name in all_saved_model_paths: | |||
| @@ -681,7 +684,7 @@ def test_trainer_checkpoint_callback_2( | |||
| aLL_topk_folders.append(each_folder_name[0]) | |||
| assert len(aLL_topk_folders) == 1 | |||
| last_save_path = all_saved_model_paths["last"] | |||
| last_save_path = all_saved_model_paths["trainer-last"] | |||
| topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||
| assert len(all_saved_model_paths) == 2 | |||
| @@ -105,6 +105,20 @@ class TestDataSetMethods(unittest.TestCase): | |||
| self.assertTrue(isinstance(field_array, FieldArray)) | |||
| self.assertEqual(len(field_array), 40) | |||
| def test_setitem(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| ds.add_field('i', list(range(len(ds)))) | |||
| assert ds.get_field('i').content == list(range(len(ds))) | |||
| import random | |||
| random.shuffle(ds) | |||
| import numpy as np | |||
| np.random.shuffle(ds) | |||
| assert ds.get_field('i').content != list(range(len(ds))) | |||
| ins1 = ds[1] | |||
| ds[2] = ds[1] | |||
| assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y'] | |||
| def test_get_item_error(self): | |||
| with self.assertRaises(RuntimeError): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
| @@ -99,7 +99,7 @@ def _test(local_rank: int, | |||
| assert my_result == sklearn_metric | |||
| class SpanFPreRecMetricTest(unittest.TestCase): | |||
| class TestSpanFPreRecMetric: | |||
| def test_case1(self): | |||
| from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans | |||
| @@ -136,33 +136,31 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
| fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||
| fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | |||
| fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||
| bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||
| -0.3782, 0.8240], | |||
| [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||
| bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||
| -0.3782, 0.8240], | |||
| [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||
| -0.3562, -1.4116], | |||
| [1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||
| 2.0023, 0.7075], | |||
| [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||
| [ 1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||
| 2.0023, 0.7075], | |||
| [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||
| 0.3832, -0.1540], | |||
| [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||
| [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||
| -1.3508, -0.9513], | |||
| [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||
| -0.0842, -0.4294]], | |||
| [ | |||
| [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||
| -1.4138, -0.8853], | |||
| [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||
| -1.0726, 0.0364], | |||
| [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||
| -0.8836, -0.9320], | |||
| [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||
| -1.6857, 1.1571], | |||
| [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||
| -0.5837, 1.0184], | |||
| [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||
| -0.9025, 0.0864]] | |||
| ] | |||
| ]) | |||
| [ 1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||
| -0.0842, -0.4294]], | |||
| [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||
| -1.4138, -0.8853], | |||
| [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||
| -1.0726, 0.0364], | |||
| [ 0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||
| -0.8836, -0.9320], | |||
| [ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||
| -1.6857, 1.1571], | |||
| [ 1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||
| -0.5837, 1.0184], | |||
| [ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||
| -0.9025, 0.0864]]]) | |||
| bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]]) | |||
| fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6]) | |||
| expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, | |||
| @@ -254,7 +252,7 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
| # print(expected_metric) | |||
| metric_value = metric.get_metric() | |||
| for key, value in expected_metric.items(): | |||
| self.assertAlmostEqual(value, metric_value[key], places=5) | |||
| np.allclose(value, metric_value[key]) | |||
| def test_auto_encoding_type_infer(self): | |||
| # 检查是否可以自动check encode的类型 | |||
| @@ -271,9 +269,8 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
| vocab.add_word('o') | |||
| vocabs[encoding_type] = vocab | |||
| for e in ['bio', 'bioes', 'bmeso']: | |||
| with self.subTest(e=e): | |||
| metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) | |||
| assert metric.encoding_type == e | |||
| metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) | |||
| assert metric.encoding_type == e | |||
| bmes_vocab = _generate_tags('bmes') | |||
| vocab = Vocabulary() | |||
| @@ -286,7 +283,7 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
| vocab = Vocabulary() | |||
| for i in range(10): | |||
| vocab.add_word(str(i)) | |||
| with self.assertRaises(Exception): | |||
| with pytest.raises(Exception): | |||
| metric = SpanFPreRecMetric(vocab) | |||
| def test_encoding_type(self): | |||
| @@ -305,21 +302,20 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
| vocab.add_word('o') | |||
| vocabs[encoding_type] = vocab | |||
| for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']): | |||
| with self.subTest(e1=e1, e2=e2): | |||
| if e1 == e2: | |||
| if e1 == e2: | |||
| metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) | |||
| else: | |||
| s2 = set(e2) | |||
| s2.update(set(e1)) | |||
| if s2 == set(e2): | |||
| continue | |||
| with pytest.raises(AssertionError): | |||
| metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) | |||
| else: | |||
| s2 = set(e2) | |||
| s2.update(set(e1)) | |||
| if s2 == set(e2): | |||
| continue | |||
| with self.assertRaises(AssertionError): | |||
| metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) | |||
| for encoding_type in ['bio', 'bioes', 'bmeso']: | |||
| with self.assertRaises(AssertionError): | |||
| with pytest.raises(AssertionError): | |||
| metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') | |||
| with self.assertWarns(Warning): | |||
| with pytest.warns(Warning): | |||
| vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) | |||
| metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | |||
| vocab = Vocabulary().add_word_lst(list('bmes')) | |||