@@ -50,20 +50,20 @@ def prepare_callbacks(callbacks, progress_bar):
             raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
         _callbacks += callbacks
-    has_no_progress = False
+    has_no_progress = True
     for _callback in _callbacks:
         if isinstance(_callback, ProgressCallback):
-            has_no_progress = True
-    if not has_no_progress:
+            has_no_progress = False
+    if has_no_progress and progress_bar is not None:
         callback = choose_progress_callback(progress_bar)
         if callback is not None:
             _callbacks.append(callback)
-    elif progress_bar is not None and progress_bar != 'auto':
-        logger.warning(f"Since you have passed in ProgressBar callback, progress_bar will be ignored.")
-    if has_no_progress and progress_bar is None:
-        rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output "
-                                       "during training.")
+            has_no_progress = False
+    elif has_no_progress is False and progress_bar not in ('auto', None):
+        logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.")
+    if has_no_progress:
+        logger.rank_zero_warning("No progress bar is provided, there will be no progress output during training.")
     return _callbacks
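The substance of this hunk is an inverted flag: the old code set `has_no_progress = True` when a `ProgressCallback` *was* found, so the final "no progress bar" warning fired in exactly the wrong case. A minimal runnable sketch of the corrected flow, using hypothetical stubs for `ProgressCallback` and `choose_progress_callback` (the real ones live elsewhere in fastNLP):

```python
# Sketch only: stand-in stubs so the corrected flag logic can run in isolation.
class ProgressCallback:  # hypothetical stub for fastNLP's ProgressCallback
    pass

def choose_progress_callback(progress_bar):  # hypothetical stub
    return ProgressCallback() if progress_bar in ('auto', 'rich', 'raw') else None

def prepare_callbacks(callbacks, progress_bar):
    _callbacks = list(callbacks or [])
    has_no_progress = True                       # assume none until one is seen
    for _callback in _callbacks:
        if isinstance(_callback, ProgressCallback):
            has_no_progress = False              # the user already supplied one
    if has_no_progress and progress_bar is not None:
        callback = choose_progress_callback(progress_bar)
        if callback is not None:
            _callbacks.append(callback)
            has_no_progress = False              # we just added one ourselves
    if has_no_progress:
        print("No progress bar is provided, there will be no progress output during training.")
    return _callbacks

assert any(isinstance(c, ProgressCallback) for c in prepare_callbacks([], 'auto'))
assert not prepare_callbacks([], None)           # no bar requested: warns, stays empty
```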
@@ -87,17 +87,20 @@ class LoadBestModelCallback(HasMonitorCallback):
             trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)

     def on_train_end(self, trainer):
-        logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...")
-        if self.real_save_folder:
-            trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
-                               model_load_fn=self.model_load_fn)
-        else:
-            self.buffer.seek(0)
-            trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
-        if self.delete_after_after:
-            trainer.driver.barrier()
-            self._delete_folder()
-            trainer.driver.barrier()
+        if abs(self.monitor_value) != float('inf'):  # if it is still inf, no evaluation ever ran
+            if self.real_save_folder:
+                logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...")
+                trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
+                                   model_load_fn=self.model_load_fn)
+            else:
+                logger.info(
+                    f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...")
+                self.buffer.seek(0)
+                trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
+        if self.delete_after_after:
+            trainer.driver.barrier()
+            self._delete_folder()
+            trainer.driver.barrier()

     def _delete_folder(self):
         if self.real_save_folder:
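The new `abs(self.monitor_value) != float('inf')` guard relies on a sentinel convention: the best monitor value starts at +inf (smaller is better) or -inf (larger is better) and is only ever replaced by a real evaluation result, so a value that is still infinite means no evaluation ran and there is nothing to load. A tiny self-contained illustration of that convention:

```python
import math

# Sentinel convention assumed by the callback: start at +/-inf depending on
# the comparison direction; any real evaluation result overwrites it.
best = float('inf')                  # smaller-is-better monitor, never updated
print(math.isinf(best))              # True  -> skip loading the "best" model

best = min(best, 0.42)               # one evaluation happened
print(abs(best) != float('inf'))     # True  -> safe to load the checkpoint
```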
@@ -138,8 +138,6 @@ class PaddleTensorPadder(Padder):
         shapes = [field.shape for field in batch_field]
         max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
-        if isinstance(dtype, np.dtype):
-            print(dtype)
         tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
         for i, field in enumerate(batch_field):
             slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
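The padder's technique is backend-agnostic: allocate one tensor of the batch-wide maximum shape filled with `pad_val`, then copy each sample into its leading sub-slice. A numpy rendering of the same idea (numpy stands in for paddle here so the sketch runs without a Paddle install):

```python
import numpy as np

def pad_stack(batch_field, pad_val=0, dtype=np.int64):
    # Same idea as PaddleTensorPadder: one allocation at the max shape,
    # then a sliced copy per sample.
    shapes = [field.shape for field in batch_field]
    max_shape = [len(batch_field)] + [max(sizes) for sizes in zip(*shapes)]
    tensor = np.full(max_shape, fill_value=pad_val, dtype=dtype)
    for i, field in enumerate(batch_field):
        slices = (i,) + tuple(slice(0, s) for s in shapes[i])
        tensor[slices] = field
    return tensor

batch = [np.ones((2, 3)), np.ones((1, 4))]
print(pad_stack(batch).shape)  # (2, 2, 4): padded to the per-axis maxima
```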
@@ -4,11 +4,11 @@ __all__ = [
 ]

-from typing import Optional, Callable, Sequence, Union, Tuple, Dict, List
+from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List
 import inspect

 from fastNLP.core.dataset import DataSet
 from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 # from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
@@ -79,35 +79,30 @@ class TorchDataLoader(DataLoader):
         if sampler is None and batch_sampler is None:
             sampler = RandomSampler(dataset, shuffle=shuffle)
             shuffle = False
-        super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
-                         batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
-                         pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
-                         multiprocessing_context=multiprocessing_context, generator=generator,
-                         prefetch_factor=prefetch_factor,
-                         persistent_workers=persistent_workers)
         if isinstance(collate_fn, str):
             if collate_fn == 'auto':
                 if isinstance(dataset.dataset, DataSet):  # a fastNLP DataSet is used
-                    self._collate_fn = dataset.dataset.collator
-                    self._collate_fn.set_backend(backend="torch")
+                    collate_fn = dataset.dataset.collator
+                    collate_fn.set_backend(backend="torch")
                 else:
-                    self._collate_fn = Collator(backend="torch")
+                    collate_fn = Collator(backend="torch")
             else:
                 raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
-        elif isinstance(collate_fn, Callable):
-            if collate_fn is not default_collate:
-                self._collate_fn = collate_fn
-        else:
-            self._collate_fn = default_collate
-        self.cur_indices_batch = None
+        super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
+                         batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
+                         pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+                         multiprocessing_context=multiprocessing_context, generator=generator,
+                         prefetch_factor=prefetch_factor,
+                         persistent_workers=persistent_workers)
+        self.cur_batch_indices = None

     def __iter__(self):
-        # If there is no auto_collator and no custom collate_fn, fall back to the DataLoader's own collate_fn to batch the data.
-        # if len(self._collate_fn.get_collators()) == 0:
-        #     self._collate_fn.add_collator(self.collate_fn)
-        self.collate_fn = indice_collate_wrapper(self._collate_fn)
+        self.collate_fn = indice_collate_wrapper(self.collate_fn)
        for indices, data in super().__iter__():
            self.cur_batch_indices = indices
            yield data
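After this refactor the resolved `collate_fn` is handed straight to the parent `DataLoader`, and the wrapping happens lazily: `__iter__` wraps whatever `collate_fn` the parent already holds, so every batch internally travels as `(indices, data)` and the loader exposes the indices via `cur_batch_indices`. A standalone sketch of the trick, assuming only torch plus a toy dataset that yields `(index, sample)` pairs (which is what fastNLP's internal dataset wrapper provides):

```python
from torch.utils.data import DataLoader

class _IndexedDataset:
    """Toy map-style dataset whose items carry their own index."""
    def __init__(self, data):
        self.data = data
    def __getitem__(self, idx):
        return idx, self.data[idx]          # (index, sample) pairs
    def __len__(self):
        return len(self.data)

def indexed_collate(func):
    def _wrapper(tuple_data):
        indices, items = zip(*tuple_data)   # peel indices off the samples
        return list(indices), func(list(items))
    return _wrapper

dl = DataLoader(_IndexedDataset([10, 20, 30, 40]), batch_size=2,
                collate_fn=indexed_collate(lambda xs: xs))
for indices, batch in dl:
    print(indices, batch)                   # e.g. [0, 1] [10, 20]
```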
@@ -132,12 +127,26 @@ class TorchDataLoader(DataLoader):
             form, and its output will be used directly as the padded result.
         :return: the Collator itself
         """
-        if isinstance(self._collate_fn, Collator):
-            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
-            return self._collate_fn
+        collator = self._get_collator()
+        if isinstance(collator, Collator):
+            collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+            return collator
         else:
             raise ValueError("set_pad() is only allowed when the collate_fn is a fastNLP Collator.")

+    def _get_collator(self):
+        """
+        Return the Collator object if collate_fn is one, or wraps one; otherwise return None.
+
+        :return:
+        """
+        collator = None
+        if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
+            collator = self.collate_fn.__wrapped__
+        elif isinstance(self.collate_fn, Collator):
+            collator = self.collate_fn
+        return collator
     def set_ignore(self, *field_names) -> Collator:
         """
         Fields that should not be output can be set here; the set fields will be ignored in the batch output.

@@ -149,9 +158,10 @@ class TorchDataLoader(DataLoader):
             If __getitem__ returns a Sequence, '_0' and '_1' can be used to refer to the 0th and 1st elements of the sequence.
         :return: the Collator itself
         """
-        if isinstance(self._collate_fn, Collator):
-            self._collate_fn.set_ignore(*field_names)
-            return self._collate_fn
+        collator = self._get_collator()
+        if isinstance(collator, Collator):
+            collator.set_ignore(*field_names)
+            return collator
         else:
             raise ValueError("set_ignore() is only allowed when the collate_fn is a fastNLP Collator.")
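With `_get_collator()` seeing through the wrapper, the usual configuration chain keeps working even after iteration has wrapped `collate_fn`. A usage sketch (the field names `'x'` and `'y'` are illustrative, and this assumes torch plus a fastNLP build that exposes `prepare_torch_dataloader`):

```python
from fastNLP import DataSet, prepare_torch_dataloader

ds = DataSet({'x': [[1, 2], [3, 4, 5]], 'y': [0, 1]})
dl = prepare_torch_dataloader(ds, batch_size=2, shuffle=False)
dl.set_pad('x', pad_val=-1)   # returns the underlying Collator
dl.set_ignore('y')            # 'y' is dropped from every batch
for batch in dl:
    print(batch['x'])         # tensor([[ 1,  2, -1], [ 3,  4,  5]])
```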
@@ -164,7 +174,8 @@ class TorchDataLoader(DataLoader):
         return self.cur_batch_indices


-def prepare_torch_dataloader(ds_or_db,
+def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
                              batch_size: int = 16,
                              shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -197,7 +208,8 @@ def prepare_torch_dataloader(ds_or_db,
     :param non_train_sampler: the Sampler used for non-'train' datasets, and for the second and later datasets when a Sequence is passed
     :param non_train_batch_size:
     """
-    from fastNLP.io.data_bundle import DataBundle
+    from fastNLP.io import DataBundle
     if isinstance(ds_or_db, DataSet):
         dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
                              shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
@@ -208,7 +220,7 @@ def prepare_torch_dataloader(ds_or_db,
                              )
         return dl

-    elif type(ds_or_db, DataBundle):
+    elif isinstance(ds_or_db, DataBundle):
         dl_bundle = {}
         for name, ds in ds_or_db.iter_datasets():
             if 'train' in name:
@@ -10,12 +10,25 @@ def indice_collate_wrapper(func):

     :param func: the function to wrap
     :return:
     """
+    if func.__name__ == '_indice_collate_wrapper':  # already wrapped, return as-is
+        return func

-    def wrapper(tuple_data):
+    def _indice_collate_wrapper(tuple_data):  # functools.wraps must not be used here, otherwise the check above cannot detect an existing wrapper
         indice, ins_list = [], []
         for idx, ins in tuple_data:
             indice.append(idx)
             ins_list.append(ins)
         return indice, func(ins_list)

-    return wrapper
+    _indice_collate_wrapper.__wrapped__ = func  # record the wrapped function
+    return _indice_collate_wrapper
+
+
+if __name__ == '__main__':
+    def demo(*args, **kwargs):
+        pass
+
+    d = indice_collate_wrapper(demo)
+    print(d.__name__)
+    print(d.__wrapped__)
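Two details here work as a pair: the `__name__` check makes the wrapper idempotent, and the hand-set `__wrapped__` attribute is what `_get_collator()` in fdl.py uses to see through the wrapper. That is also why `functools.wraps` is avoided: it would copy `func.__name__` onto the wrapper and defeat the idempotence check. A minimal sketch of this unwrap-one-level pattern, with a stand-in `Collator` class:

```python
class Collator:  # stand-in for fastNLP's Collator
    def __call__(self, ins_list):
        return ins_list

def indice_collate_wrapper(func):
    # getattr guards against callables without __name__, e.g. class instances
    if getattr(func, '__name__', '') == '_indice_collate_wrapper':
        return func                      # idempotent: never wrap twice
    def _indice_collate_wrapper(tuple_data):
        indice, ins_list = zip(*tuple_data)
        return list(indice), func(list(ins_list))
    _indice_collate_wrapper.__wrapped__ = func
    return _indice_collate_wrapper

def get_collator(collate_fn):            # mirrors TorchDataLoader._get_collator
    if hasattr(collate_fn, '__wrapped__') and isinstance(collate_fn.__wrapped__, Collator):
        return collate_fn.__wrapped__
    elif isinstance(collate_fn, Collator):
        return collate_fn
    return None

c = Collator()
wrapped = indice_collate_wrapper(c)
assert indice_collate_wrapper(wrapped) is wrapped   # no double wrapping
assert get_collator(wrapped) is c                   # detected through the wrapper
```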
@@ -8,6 +8,7 @@ __all__ = [

 from collections import Counter
 from typing import Any, Union, List, Callable

+from ..log import logger
 import numpy as np
@@ -21,7 +22,7 @@ class FieldArray:
             try:
                 _content = list(_content)
             except BaseException as e:
-                print(f"Cannot convert content(of type:{type(content)}) into list.")
+                logger.error(f"Cannot convert content(of type:{type(content)}) into list.")
                 raise e
         self.name = name
         self.content = _content
@@ -87,7 +88,7 @@ class FieldArray:
                 try:
                     new_contents.append(cell.split(sep))
                 except Exception as e:
-                    print(f"Exception happens when process value in index {index}.")
+                    logger.error(f"Exception happens when processing value at index {index}.")
                     raise e
         return self._after_process(new_contents, inplace=inplace)
@@ -111,7 +111,7 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):

     def warning_once(self, msg, *args, **kwargs):
         """
-        A warning message will only be warned once
+        Identical warning messages will only be warned once

         :param msg:
         :param args:
@@ -124,6 +124,22 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
                 self._log(WARNING, msg, args, **kwargs)
             self._warning_msgs.add(msg)
+    def rank_zero_warning(self, msg, *args, **kwargs):
+        """
+        Only emit the warning on rank 0.
+
+        :param msg:
+        :param args:
+        :param kwargs:
+        :return:
+        """
+        if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0':
+            if msg not in self._warning_msgs:
+                if self.isEnabledFor(WARNING):
+                    # kwargs = self._add_rank_info(kwargs)
+                    self._log(WARNING, msg, args, **kwargs)
+                self._warning_msgs.add(msg)

     def warn(self, msg, *args, **kwargs):
         if self.isEnabledFor(WARNING):
             kwargs = self._add_rank_info(kwargs)
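`rank_zero_warning` stacks two guards: only the process whose `FASTNLP_GLOBAL_RANK` is 0 logs, and each distinct message is logged at most once. The same pattern in a self-contained sketch (`GLOBAL_RANK` below is a stand-in for fastNLP's `FASTNLP_GLOBAL_RANK` environment constant):

```python
import logging
import os

_seen = set()

def rank_zero_warning_once(log: logging.Logger, msg: str):
    # Warn only on rank 0, and only once per distinct message:
    # the same two guards used by FastNLPLogger.rank_zero_warning.
    if os.environ.get('GLOBAL_RANK', '0') == '0' and msg not in _seen:
        log.warning(msg)
        _seen.add(msg)

logging.basicConfig()
log = logging.getLogger('demo')
rank_zero_warning_once(log, 'no progress bar')   # printed
rank_zero_warning_once(log, 'no progress bar')   # suppressed
```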
@@ -156,8 +156,9 @@ class FRichProgress(Progress, metaclass=Singleton):
         super().stop_task(task_id)
         super().remove_task(task_id)
         self.refresh()  # refresh so the bar does not linger
-        if len(self._tasks) == 0:
-            super().stop()
+        # Commented out because repeated dataset.apply calls caused spurious automatic line wrapping; it was previously kept so the bar would disappear after evaluate finished.
+        # if len(self._tasks) == 0:
+        #     self.live.stop()

     def start(self) -> None:
         super().start()
@@ -15,6 +15,7 @@ from functools import wraps

 from fastNLP.core.dataset import DataSet
 from fastNLP.core.utils.utils import Option
 from fastNLP.core.utils.utils import _is_iterable
+from .log import logger
 import io
@@ -56,7 +57,7 @@ def _check_build_status(func):
         if self.rebuild is False:
             self.rebuild = True
             if self.max_size is not None and len(self.word_count) >= self.max_size:
-                print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
-                      "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
-                    self.max_size, func.__name__))
+                logger.warning("Vocabulary has reached the max size {} when calling {} method. "
+                               "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
+                    self.max_size, func.__name__))
         return func(self, *args, **kwargs)
@@ -322,7 +323,7 @@ class Vocabulary(object):
                 for f_n, n_f_n in zip(field_name, new_field_name):
                     dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n)
             except Exception as e:
-                print("When processing the `{}` dataset, the following error occurred.".format(idx))
+                logger.error("When processing the `{}` dataset, the following error occurred.".format(idx))
                 raise e
         else:
             raise RuntimeError("Only DataSet type is allowed.")
@@ -378,7 +379,7 @@ class Vocabulary(object):
             try:
                 dataset.apply(construct_vocab)
             except BaseException as e:
-                print("When processing the `{}` dataset, the following error occurred:".format(idx))
+                logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
                 raise e
         else:
             raise TypeError("Only DataSet type is allowed.")
@@ -10,7 +10,7 @@ from typing import Union, List, Callable

 from ..core.dataset import DataSet
 from fastNLP.core.vocabulary import Vocabulary
-# from ..core._logger import _logger
+from fastNLP.core import logger


 class DataBundle:
@@ -72,7 +72,7 @@ class DataBundle:
         else:
             error_msg = f'DataBundle do NOT have DataSet named {name}. ' \
                         f'It should be one of {self.datasets.keys()}.'
-            print(error_msg)
+            logger.error(error_msg)
             raise KeyError(error_msg)

     def delete_dataset(self, name: str):
@@ -97,7 +97,7 @@ class DataBundle:
         else:
             error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \
                         f'It should be one of {self.vocabs.keys()}.'
-            print(error_msg)
+            logger.error(error_msg)
             raise KeyError(error_msg)

     def delete_vocab(self, field_name: str):
@@ -117,7 +117,7 @@ def _read_conll(path, encoding='utf-8', sep=None, indexes=None, dropna=True):
                     yield line_idx, res
                 except Exception as e:
                     if dropna:
-                        print('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
+                        logger.error('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
                         sample = []
                         continue
                     raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
@@ -132,5 +132,5 @@ def _read_conll(path, encoding='utf-8', sep=None, indexes=None, dropna=True):
         except Exception as e:
             if dropna:
                 return
-            print('invalid instance ends at line: {}'.format(line_idx))
+            logger.error('invalid instance ends at line: {}'.format(line_idx))
             raise e
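Both `_read_conll` blocks follow the same `dropna` contract: a malformed sample is either logged and skipped, or surfaced as an exception. A toy sketch of the first block's behaviour, with integer parsing standing in for CoNLL parsing:

```python
def read_items(lines, dropna=True):
    # dropna=True: log and skip bad samples; dropna=False: re-raise.
    for line_idx, line in enumerate(lines):
        try:
            yield line_idx, int(line)   # stand-in for the real CoNLL parsing
        except Exception as e:
            if dropna:
                print('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
                continue
            raise e

print(list(read_items(['1', 'oops', '3'])))  # [(0, 1), (2, 3)]
```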
@@ -29,6 +29,7 @@ import warnings

 from .loader import Loader
 from fastNLP.core.dataset import Instance, DataSet
+from ...core import logger
 # from ...core._logger import log
@@ -86,7 +87,8 @@ class CLSBaseLoader(Loader):
                 if raw_words:
                     ds.append(Instance(raw_words=raw_words, target=target))
         except Exception as e:
-            print(f'Load file `{path}` failed for `{e}`')
+            logger.error(f'Failed to load `{path}`.')
             raise e
         return ds