| @@ -79,7 +79,7 @@ class RichCallback(ProgressCallback): | |||||
| def on_train_begin(self, trainer): | def on_train_begin(self, trainer): | ||||
| self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | ||||
| completed=trainer.global_forward_batches/(trainer.total_batches+1e-6)) | |||||
| completed=trainer.global_forward_batches/(trainer.n_batches+1e-6)) | |||||
| def on_train_epoch_begin(self, trainer): | def on_train_epoch_begin(self, trainer): | ||||
| self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) | self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) | ||||
| @@ -190,7 +190,7 @@ class RawTextCallback(ProgressCallback): | |||||
| self.loss = 0 | self.loss = 0 | ||||
| text = f'Epoch:{trainer.cur_epoch_idx}/{trainer.n_epochs}, Batch:{trainer.batch_idx_in_epoch}, ' \ | text = f'Epoch:{trainer.cur_epoch_idx}/{trainer.n_epochs}, Batch:{trainer.batch_idx_in_epoch}, ' \ | ||||
| f'loss:{round(loss, self.loss_round_ndigit)}, ' \ | f'loss:{round(loss, self.loss_round_ndigit)}, ' \ | ||||
| f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.' | |||||
| f'finished {round(trainer.global_forward_batches/trainer.n_batches*100, 2)}%.' | |||||
| logger.info(text) | logger.info(text) | ||||
| def on_evaluate_end(self, trainer, results): | def on_evaluate_end(self, trainer, results): | ||||
| @@ -251,7 +251,7 @@ class TqdmCallback(ProgressCallback): | |||||
| def on_train_begin(self, trainer): | def on_train_begin(self, trainer): | ||||
| self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | ||||
| bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]', | bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]', | ||||
| initial=trainer.global_forward_batches/(trainer.total_batches+1e-6)) | |||||
| initial=trainer.global_forward_batches/(trainer.n_batches+1e-6)) | |||||
| def on_train_epoch_begin(self, trainer): | def on_train_epoch_begin(self, trainer): | ||||
| self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) | self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) | ||||
| @@ -41,7 +41,7 @@ class TorchWarmupCallback(Callback): | |||||
| return max((progress - 1.) / (self.warmup - 1.), 0.) | return max((progress - 1.) / (self.warmup - 1.), 0.) | ||||
| def on_train_begin(self, trainer): | def on_train_begin(self, trainer): | ||||
| self.t_steps = trainer.total_batches | |||||
| self.t_steps = trainer.n_batches | |||||
| if self.warmup >1: | if self.warmup >1: | ||||
| self.warmup = self.warmup / self.t_steps | self.warmup = self.warmup / self.t_steps | ||||
| self.t_steps = max(2, self.t_steps) # 不能小于2 | self.t_steps = max(2, self.t_steps) # 不能小于2 | ||||
| @@ -460,14 +460,15 @@ class _MetricsWrapper: | |||||
| for metric in self._metrics: | for metric in self._metrics: | ||||
| args = [] | args = [] | ||||
| if not isinstance(batch, dict): | if not isinstance(batch, dict): | ||||
| logger.warning_once( | |||||
| logger.rank_zero_warning( | |||||
| f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " | f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " | ||||
| f"the output of model to update metric.") | |||||
| f"the output of model to update metric.", once=True) | |||||
| else: | else: | ||||
| args.append(batch) | args.append(batch) | ||||
| if not isinstance(outputs, dict): | if not isinstance(outputs, dict): | ||||
| raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | ||||
| f" return a dict from your model or use `output_mapping` to convert it into dict type.") | |||||
| f" return a dict from your model or use `output_mapping` to convert it into dict " | |||||
| f"type.") | |||||
| if isinstance(metric, Metric): | if isinstance(metric, Metric): | ||||
| # 这样在 auto_param_call 报错的时候才清晰。 | # 这样在 auto_param_call 报错的时候才清晰。 | ||||
| auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) | auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) | ||||
| @@ -110,7 +110,7 @@ class Trainer(TrainerEventTrigger): | |||||
| 对于使用 ``TorchDDPDriver`` 的更多细节,请见 :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`。 | 对于使用 ``TorchDDPDriver`` 的更多细节,请见 :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`。 | ||||
| :param n_epochs: 训练总共的 epoch 的数量,默认为 20; | |||||
| :param n_epochs: 训练总共的 epoch 的数量,默认为 20;也可以通过 ``n_batches`` 参数设置总共迭代多少个 ``batch`` 。 | |||||
| :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | ||||
| 为 None; | 为 None; | ||||
| :param batch_step_fn: 定制每次训练时前向运行一个 batch 的数据所执行的函数。该函数应接受两个参数为 ``trainer`` 和 ``batch``, | :param batch_step_fn: 定制每次训练时前向运行一个 batch 的数据所执行的函数。该函数应接受两个参数为 ``trainer`` 和 ``batch``, | ||||
| @@ -237,6 +237,8 @@ class Trainer(TrainerEventTrigger): | |||||
| 注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效; | 注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效; | ||||
| :param n_batches: 迭代多少个 ``batch`` 的训练结束。当该值不为 -1 时,将直接忽略 ``n_epochs`` 的值。 | |||||
| :param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 None; | :param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 None; | ||||
| .. note:: | .. note:: | ||||
| @@ -356,6 +358,7 @@ class Trainer(TrainerEventTrigger): | |||||
| fp16: bool = False, | fp16: bool = False, | ||||
| monitor: Union[str, Callable] = None, | monitor: Union[str, Callable] = None, | ||||
| larger_better: bool = True, | larger_better: bool = True, | ||||
| n_batches: int = -1, | |||||
| marker: Optional[str] = None, | marker: Optional[str] = None, | ||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| @@ -426,6 +429,7 @@ class Trainer(TrainerEventTrigger): | |||||
| model_wo_auto_param_call=model_wo_auto_param_call, | model_wo_auto_param_call=model_wo_auto_param_call, | ||||
| accumulation_steps=accumulation_steps, | accumulation_steps=accumulation_steps, | ||||
| fp16=fp16, | fp16=fp16, | ||||
| n_batches=n_batches, | |||||
| marker=marker, | marker=marker, | ||||
| **kwargs | **kwargs | ||||
| ) | ) | ||||
| @@ -444,12 +448,12 @@ class Trainer(TrainerEventTrigger): | |||||
| # 初始化 state,包括提供给用户的接口和我们自己使用的接口; | # 初始化 state,包括提供给用户的接口和我们自己使用的接口; | ||||
| self.state = State() | self.state = State() | ||||
| self.trainer_state = TrainerState( | self.trainer_state = TrainerState( | ||||
| n_epochs=n_epochs, | |||||
| n_epochs=n_epochs if n_batches!=-1 else None, | |||||
| cur_epoch_idx=0, | cur_epoch_idx=0, | ||||
| global_forward_batches=0, | global_forward_batches=0, | ||||
| batch_idx_in_epoch=0, | batch_idx_in_epoch=0, | ||||
| num_batches_per_epoch=None, # 会在具体的 train_batch_loop 中进行初始化; | num_batches_per_epoch=None, # 会在具体的 train_batch_loop 中进行初始化; | ||||
| total_batches=None | |||||
| n_batches=n_batches | |||||
| ) | ) | ||||
| if metrics is None and evaluate_dataloaders is not None: | if metrics is None and evaluate_dataloaders is not None: | ||||
| @@ -598,14 +602,18 @@ class Trainer(TrainerEventTrigger): | |||||
| self.dataloader = _TruncatedDataLoader(self.dataloader, num_train_batch_per_epoch) | self.dataloader = _TruncatedDataLoader(self.dataloader, num_train_batch_per_epoch) | ||||
| self.num_batches_per_epoch = len(self.dataloader) | self.num_batches_per_epoch = len(self.dataloader) | ||||
| self.total_batches = self.num_batches_per_epoch * self.n_epochs | |||||
| if self.n_batches == -1: | |||||
| self.n_batches = self.num_batches_per_epoch * self.n_epochs | |||||
| else: | |||||
| self.n_epochs = (self.n_batches+self.num_batches_per_epoch-1)//self.num_batches_per_epoch | |||||
| self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | ||||
| try: | try: | ||||
| self.on_train_begin() | self.on_train_begin() | ||||
| self.driver.barrier() | self.driver.barrier() | ||||
| self.driver.zero_grad() | self.driver.zero_grad() | ||||
| while self.cur_epoch_idx < self.n_epochs: | |||||
| while self.cur_epoch_idx < self.n_epochs and self.global_forward_batches < self.n_batches: | |||||
| # 这个是防止在 Trainer.load_checkpoint 之后还没结束当前 epoch 又继续 save | # 这个是防止在 Trainer.load_checkpoint 之后还没结束当前 epoch 又继续 save | ||||
| self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | ||||
| self.driver.set_model_mode("train") | self.driver.set_model_mode("train") | ||||
| @@ -1367,15 +1375,15 @@ class Trainer(TrainerEventTrigger): | |||||
| self.trainer_state.num_batches_per_epoch = num_batches_per_epoch | self.trainer_state.num_batches_per_epoch = num_batches_per_epoch | ||||
| @property | @property | ||||
| def total_batches(self) -> int: | |||||
| def n_batches(self) -> int: | |||||
| r""" | r""" | ||||
| :return: 返回整体的训练中实际会训练多少个 batch 的数据; | :return: 返回整体的训练中实际会训练多少个 batch 的数据; | ||||
| """ | """ | ||||
| return self.trainer_state.total_batches | |||||
| return self.trainer_state.n_batches | |||||
| @total_batches.setter | |||||
| def total_batches(self, total_batches: int): | |||||
| self.trainer_state.total_batches = total_batches | |||||
| @n_batches.setter | |||||
| def n_batches(self, n_batches: int): | |||||
| self.trainer_state.n_batches = n_batches | |||||
| """ driver property """ | """ driver property """ | ||||
| @@ -50,7 +50,7 @@ class TrainerState: | |||||
| :param global_forward_batches: 当前模型总共 forward 了多少个 step; | :param global_forward_batches: 当前模型总共 forward 了多少个 step; | ||||
| :param batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | :param batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | ||||
| :param num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | :param num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | ||||
| :param total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | |||||
| :param n_batches: 完整训练过程会 forward 的 step 数量,注意未显式设置时 n_batches = num_batches_per_epoch * n_epochs; | |||||
| """ | """ | ||||
| n_epochs: Optional[int] = None # 无论如何重新算 | n_epochs: Optional[int] = None # 无论如何重新算 | ||||
| @@ -61,7 +61,7 @@ class TrainerState: | |||||
| num_batches_per_epoch: Optional[int] = None # 无论如何重新算 | num_batches_per_epoch: Optional[int] = None # 无论如何重新算 | ||||
| total_batches: Optional[int] = None # 无论如何重新算 | |||||
| n_batches: Optional[int] = None # 无论如何重新算 | |||||
| def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
| r""" | r""" | ||||
| @@ -156,7 +156,6 @@ import _pickle as pickle | |||||
| from copy import deepcopy | from copy import deepcopy | ||||
| from typing import Optional, List, Callable, Union, Dict, Any, Mapping | from typing import Optional, List, Callable, Union, Dict, Any, Mapping | ||||
| from types import LambdaType | from types import LambdaType | ||||
| from subprocess import DEVNULL | |||||
| import sys | import sys | ||||
| import time | import time | ||||
| @@ -170,6 +169,7 @@ from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress | |||||
| from fastNLP.core.utils.tqdm_progress import f_tqdm_progress | from fastNLP.core.utils.tqdm_progress import f_tqdm_progress | ||||
| from ..log import logger | from ..log import logger | ||||
| from fastNLP.core.utils.dummy_class import DummyClass | from fastNLP.core.utils.dummy_class import DummyClass | ||||
| from ..utils.utils import _get_fun_msg | |||||
| progress_bars = { | progress_bars = { | ||||
| @@ -780,8 +780,8 @@ class DataSet: | |||||
| apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | ||||
| progress_bar=progress_bar) | progress_bar=progress_bar) | ||||
| # 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | # 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | ||||
| if not isinstance(apply_out[0], dict): | |||||
| raise Exception("The result of func is not a dict") | |||||
| if not isinstance(apply_out[0], Mapping): | |||||
| raise Exception(f"The result of func:{_get_fun_msg(func)} is not a dict, but of type {type(apply_out[0])}") | |||||
| for key, value in apply_out[0].items(): | for key, value in apply_out[0].items(): | ||||
| results[key] = [value] | results[key] = [value] | ||||
| @@ -789,7 +789,8 @@ class DataSet: | |||||
| try: | try: | ||||
| for idx, per_out in enumerate(apply_out[1:]): | for idx, per_out in enumerate(apply_out[1:]): | ||||
| if len(set(results.keys()) - set(per_out.keys())): | if len(set(results.keys()) - set(per_out.keys())): | ||||
| raise ApplyResultException("apply results have different fields", idx + 1) | |||||
| raise ApplyResultException(f"Apply results have different fields:{set(results.keys())} and " | |||||
| f"{set(per_out.keys())}", idx + 1) | |||||
| for key, value in per_out.items(): | for key, value in per_out.items(): | ||||
| results[key].append(value) | results[key].append(value) | ||||
| @@ -169,7 +169,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| :param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
| """ | """ | ||||
| def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | ||||
| drop_last: bool = False, seed: int = 0, **kwargs): | |||||
| drop_last: bool = False, seed: int = None, **kwargs): | |||||
| super().__init__() | super().__init__() | ||||
| self.dataset = dataset | self.dataset = dataset | ||||
| @@ -120,7 +120,7 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
| def add_task( | def add_task( | ||||
| self, | self, | ||||
| description: str, | |||||
| description: str = 'Progress', | |||||
| start: bool = True, | start: bool = True, | ||||
| total: float = 100.0, | total: float = 100.0, | ||||
| completed: int = 0, | completed: int = 0, | ||||
| @@ -7,7 +7,7 @@ __all__ = [] | |||||
| import json | import json | ||||
| import csv | import csv | ||||
| # from ..core import log | |||||
| from ..core import logger | |||||
| def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | ||||
| @@ -81,7 +81,7 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True): | |||||
| yield line_idx, _res | yield line_idx, _res | ||||
| def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): | |||||
| def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True, drophash=True): | |||||
| r""" | r""" | ||||
| Construct a generator to read conll items. | Construct a generator to read conll items. | ||||
| @@ -91,6 +91,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): | |||||
| :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None | :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None | ||||
| :param dropna: whether to ignore and drop invalid data, | :param dropna: whether to ignore and drop invalid data, | ||||
| :if False, raise ValueError when reading invalid data. default: True | :if False, raise ValueError when reading invalid data. default: True | ||||
| :param drophash: 是否丢掉以 # 开头的 line 。 | |||||
| :return: generator, every time yield (line number, conll item) | :return: generator, every time yield (line number, conll item) | ||||
| """ | """ | ||||
| @@ -121,7 +122,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): | |||||
| sample = [] | sample = [] | ||||
| continue | continue | ||||
| raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) | raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) | ||||
| elif line.startswith('#'): | |||||
| elif line.startswith('#') and drophash: | |||||
| continue | continue | ||||
| else: | else: | ||||
| sample.append(line.split(sep)) if sep else sample.append(line.split()) | sample.append(line.split(sep)) if sep else sample.append(line.split()) | ||||
| @@ -52,13 +52,14 @@ class ConllLoader(Loader): | |||||
| """ | """ | ||||
| def __init__(self, headers, sep=None, indexes=None, dropna=True): | |||||
| def __init__(self, headers, sep=None, indexes=None, dropna=True, drophash=True): | |||||
| r""" | r""" | ||||
| :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | ||||
| :param list sep: 指定分隔符,默认为 ``None`` ,此时按任意空白字符切分 | :param list sep: 指定分隔符,默认为 ``None`` ,此时按任意空白字符切分 | ||||
| :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | ||||
| :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | ||||
| :param bool drophash: 是否忽略以 ``#`` 开头的行。Default: ``True`` | |||||
| """ | """ | ||||
| super(ConllLoader, self).__init__() | super(ConllLoader, self).__init__() | ||||
| if not isinstance(headers, (list, tuple)): | if not isinstance(headers, (list, tuple)): | ||||
| @@ -66,6 +67,7 @@ class ConllLoader(Loader): | |||||
| 'invalid headers: {}, should be list of strings'.format(headers)) | 'invalid headers: {}, should be list of strings'.format(headers)) | ||||
| self.headers = headers | self.headers = headers | ||||
| self.dropna = dropna | self.dropna = dropna | ||||
| self.drophash = drophash | |||||
| self.sep=sep | self.sep=sep | ||||
| if indexes is None: | if indexes is None: | ||||
| self.indexes = list(range(len(self.headers))) | self.indexes = list(range(len(self.headers))) | ||||
| @@ -82,7 +84,8 @@ class ConllLoader(Loader): | |||||
| :return: DataSet | :return: DataSet | ||||
| """ | """ | ||||
| ds = DataSet() | ds = DataSet() | ||||
| for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna): | |||||
| for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna, | |||||
| drophash=self.drophash): | |||||
| ins = {h: data[i] for i, h in enumerate(self.headers)} | ins = {h: data[i] for i, h in enumerate(self.headers)} | ||||
| ds.append(Instance(**ins)) | ds.append(Instance(**ins)) | ||||
| return ds | return ds | ||||
| @@ -32,4 +32,4 @@ def test_torch_warmup_callback(warmup, schedule, accumulation_steps): | |||||
| elif schedule == 'constant': | elif schedule == 'constant': | ||||
| assert np.allclose(0.1, kwargs['optimizers'].param_groups[0]['lr']) | assert np.allclose(0.1, kwargs['optimizers'].param_groups[0]['lr']) | ||||
| assert len(r_callback.lrs)<=trainer.total_batches//accumulation_steps+1 | |||||
| assert len(r_callback.lrs)<=trainer.n_batches//accumulation_steps+1 | |||||
| @@ -55,4 +55,4 @@ class RecordAccumulationStepsCallback_Torch(Callback): | |||||
| def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
| print(f"\n equal num: {self.equal}.\n") | print(f"\n equal num: {self.equal}.\n") | ||||
| print(f"\ntotal_batch_num: {trainer.total_batches}.\n") | |||||
| print(f"\ntotal_batch_num: {trainer.n_batches}.\n") | |||||