@@ -10,8 +10,72 @@ The core module implements fastNLP's core framework; commonly used functionality can be imported directly from the fastNLP package.
For commonly used functionality, you only need to look in :doc:`fastNLP`. If you want to know what each submodule does, you can find each submodule's detailed documentation below.
"""
__all__ = [
    "DataSet",
    "Instance",
    "FieldArray",
    "Padder",
    "AutoPadder",
    "EngChar2DPadder",
    "Vocabulary",
    "DataSetIter",
    "BatchIter",
    "TorchLoaderIter",
    "Const",
    "Tester",
    "Trainer",
    "cache_results",
    "seq_len_to_mask",
    "get_seq_len",
    "logger",
    "Callback",
    "GradientClipCallback",
    "EarlyStopCallback",
    "FitlogCallback",
    "EvaluateCallback",
    "LRScheduler",
    "ControlC",
    "LRFinder",
    "TensorboardCallback",
    "WarmupCallback",
    "SaveModelCallback",
    "EchoCallback",
    "TesterCallback",
    "CallbackException",
    "EarlyStopError",
    "LossFunc",
    "CrossEntropyLoss",
    "L1Loss",
    "BCELoss",
    "NLLLoss",
    "LossInForward",
    "AccuracyMetric",
    "SpanFPreRecMetric",
    "ExtractiveQAMetric",
    "Optimizer",
    "SGD",
    "Adam",
    "AdamW",
    "SequentialSampler",
    "BucketSampler",
    "RandomSampler",
    "Sampler",
]
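
A quick sketch of the intended import style (hedged: assumes the names listed in ``__all__`` above are re-exported by the top-level ``fastNLP`` package; the toy data is illustrative)::

    from fastNLP import DataSet, Instance, Vocabulary

    ds = DataSet()
    ds.append(Instance(raw_words="this is an example".split()))
    vocab = Vocabulary()
    vocab.update("this is an example".split())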
from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
    LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
@@ -28,4 +92,3 @@ from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
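
The ``seq_len_to_mask`` helper imported above turns a batch of sequence lengths into a padding mask. A minimal sketch (hedged: the exact mask dtype depends on the fastNLP version)::

    import torch
    from fastNLP import seq_len_to_mask

    seq_len = torch.LongTensor([3, 5])
    mask = seq_len_to_mask(seq_len)   # shape (2, 5); positions beyond each length are masked out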
@@ -1,15 +1,15 @@
"""undocumented"""

__all__ = [
    'logger',
]
import logging
import logging.config
import torch
import _pickle as pickle
import os
import sys
import warnings
ROOT_NAME = 'fastNLP'

try:
@@ -25,7 +25,7 @@ if tqdm is not None:
    class TqdmLoggingHandler(logging.Handler):
        def __init__(self, level=logging.INFO):
            super().__init__(level)

        def emit(self, record):
            try:
                msg = self.format(record)
@@ -59,14 +59,14 @@ def _add_file_handler(logger, path, level='INFO'):
        if os.path.abspath(path) == h.baseFilename:
            # file path already added
            return

    # File Handler
    if os.path.exists(path):
        assert os.path.isfile(path)
        warnings.warn('log already exists in {}'.format(path))
    dirname = os.path.abspath(os.path.dirname(path))
    os.makedirs(dirname, exist_ok=True)

    file_handler = logging.FileHandler(path, mode='a')
    file_handler.setLevel(_get_level(level))
    file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
@@ -87,7 +87,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
            break
    if stream_handler is not None:
        logger.removeHandler(stream_handler)

    # Stream Handler
    if stdout == 'plain':
        stream_handler = logging.StreamHandler(sys.stdout)
@@ -95,7 +95,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
        stream_handler = TqdmLoggingHandler(level)
    else:
        stream_handler = None

    if stream_handler is not None:
        stream_formatter = logging.Formatter('%(message)s')
        stream_handler.setLevel(level)
@@ -103,38 +103,40 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
        logger.addHandler(stream_handler)


class FastNLPLogger(logging.getLoggerClass()):
    def __init__(self, name):
        super().__init__(name)

    def add_file(self, path='./log.txt', level='INFO'):
        """add log output file and level"""
        _add_file_handler(self, path, level)

    def set_stdout(self, stdout='tqdm', level='INFO'):
        """set stdout format and level"""
        _set_stdout_handler(self, stdout, level)


logging.setLoggerClass(FastNLPLogger)

# print(logging.getLoggerClass())
# print(logging.getLogger())

def _init_logger(path=None, stdout='tqdm', level='INFO'):
    """initialize logger"""
    level = _get_level(level)

    # logger = logging.getLogger()
    logger = logging.getLogger(ROOT_NAME)
    logger.propagate = False
    logger.setLevel(level)
    _set_stdout_handler(logger, stdout, level)

    # File Handler
    if path is not None:
        _add_file_handler(logger, path, level)

    return logger
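
A minimal sketch of configuring this logger from user code, using the ``add_file`` and ``set_stdout`` helpers defined above (hedged: the log path is illustrative)::

    from fastNLP import logger

    logger.add_file('./log.txt', level='INFO')   # also write records to a file
    logger.set_stdout('tqdm', level='INFO')      # print through tqdm so progress bars stay intact
    logger.info('training started')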
@@ -1,11 +1,14 @@
"""undocumented"""

__all__ = []

import threading

import torch
from torch import nn
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather


def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@@ -27,11 +30,11 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()
    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
@@ -47,20 +50,20 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
@@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
    :param output_device: the output_device used by nn.DataParallel
    :return:
    """
    def wrapper(network, *inputs, **kwargs):
        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
        if len(device_ids) == 1:
@@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
        replicas = replicate(network, device_ids[:len(inputs)])
        outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
        return gather(outputs, output_device)

    return wrapper
@@ -99,4 +104,4 @@ def _model_contains_inner_module(model):
    if isinstance(model, nn.Module):
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return True
    return False
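
A minimal sketch of what ``_model_contains_inner_module`` checks (hedged: the toy module is illustrative)::

    import torch.nn as nn

    net = nn.Linear(4, 2)
    assert _model_contains_inner_module(net) is False
    assert _model_contains_inner_module(nn.DataParallel(net)) is True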
@@ -1,3 +1,13 @@
"""
.. todo::
    doc
"""

__all__ = [
    "Const"
]


class Const:
    """
    Naming constants for fields in fastNLP.
@@ -25,47 +35,47 @@ class Const:
    LOSS = 'loss'
    RAW_WORD = 'raw_words'
    RAW_CHAR = 'raw_chars'

    @staticmethod
    def INPUTS(i):
        """Get the name of the i-th ``INPUT``"""
        i = int(i) + 1
        return Const.INPUT + str(i)

    @staticmethod
    def CHAR_INPUTS(i):
        """Get the name of the i-th ``CHAR_INPUT``"""
        i = int(i) + 1
        return Const.CHAR_INPUT + str(i)

    @staticmethod
    def RAW_WORDS(i):
        """Get the name of the i-th ``RAW_WORD``"""
        i = int(i) + 1
        return Const.RAW_WORD + str(i)

    @staticmethod
    def RAW_CHARS(i):
        """Get the name of the i-th ``RAW_CHAR``"""
        i = int(i) + 1
        return Const.RAW_CHAR + str(i)

    @staticmethod
    def INPUT_LENS(i):
        """Get the name of the i-th ``INPUT_LEN``"""
        i = int(i) + 1
        return Const.INPUT_LEN + str(i)

    @staticmethod
    def OUTPUTS(i):
        """Get the name of the i-th ``OUTPUT``"""
        i = int(i) + 1
        return Const.OUTPUT + str(i)

    @staticmethod
    def TARGETS(i):
        """Get the name of the i-th ``TARGET``"""
        i = int(i) + 1
        return Const.TARGET + str(i)

    @staticmethod
    def LOSSES(i):
        """Get the name of the i-th ``LOSS``"""
@@ -1,29 +1,29 @@
"""undocumented
Distributed training code, still under development.
"""
import logging
import os
import time
from datetime import datetime, timedelta
from functools import partial

import torch
import torch.cuda
import torch.distributed as dist
import torch.optim
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from ._logger import logger
from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
from .utils import _build_args
from .utils import _get_func_signature
from .utils import _move_dict_value_to_device

__all__ = [
    'get_local_rank',
@@ -1,18 +1,25 @@
"""
.. todo::
    doc
"""

__all__ = [
    "Padder",
    "AutoPadder",
    "EngChar2DPadder",
]

from abc import abstractmethod
from collections import Counter
from copy import deepcopy
from numbers import Number
from typing import Any

import numpy as np
import torch

from ._logger import logger
from .utils import _is_iterable
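
A minimal sketch of where these padders plug in (hedged: ``DataSet.set_padder`` and the ``pad_val`` argument are assumed from the usual fastNLP API; the field name is illustrative)::

    from fastNLP import DataSet, AutoPadder

    ds = DataSet({'words': [[1, 2], [3, 4, 5]]})
    ds.set_padder('words', AutoPadder(pad_val=0))   # pad 'words' with zeros up to the batch max length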


class SetInputOrTargetException(Exception):
@@ -1,13 +1,15 @@
"""undocumented"""

__all__ = [
    "Predictor"
]

from collections import defaultdict

import torch

from . import DataSet
from . import DataSetIter
from . import SequentialSampler
from .utils import _build_args, _move_dict_value_to_device, _get_model_device
@@ -21,7 +23,7 @@ class Predictor(object):
    :param torch.nn.Module network: the model used to perform prediction
    """

    def __init__(self, network):
        if not isinstance(network, torch.nn.Module):
            raise ValueError(
@@ -29,7 +31,7 @@ class Predictor(object):
        self.network = network
        self.batch_size = 1
        self.batch_output = []

    def predict(self, data: DataSet, seq_len_field_name=None):
        """Run inference with an already trained model.

@@ -41,27 +43,27 @@ class Predictor(object):
            raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
            raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
        prev_training = self.network.training
        self.network.eval()
        network_device = _get_model_device(self.network)
        batch_output = defaultdict(list)
        data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)

        if hasattr(self.network, "predict"):
            predict_func = self.network.predict
        else:
            predict_func = self.network.forward

        with torch.no_grad():
            for batch_x, _ in data_iterator:
                _move_dict_value_to_device(batch_x, _, device=network_device)
                refined_batch_x = _build_args(predict_func, **batch_x)
                prediction = predict_func(**refined_batch_x)

                if seq_len_field_name is not None:
                    seq_lens = batch_x[seq_len_field_name].tolist()

                for key, value in prediction.items():
                    value = value.cpu().numpy()
                    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
@@ -74,6 +76,6 @@ class Predictor(object):
                        batch_output[key].extend(tmp_batch)
                    else:
                        batch_output[key].append(value)

        self.network.train(prev_training)
        return batch_output
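
A minimal usage sketch (hedged: the toy model and field name are illustrative, not from the fastNLP docs)::

    import torch
    from fastNLP import DataSet
    from fastNLP.core.predictor import Predictor

    class ToyModel(torch.nn.Module):
        def forward(self, x):
            return {'pred': x.sum(dim=1, keepdim=True)}

    ds = DataSet({'x': [[1.0, 2.0], [3.0, 4.0]]})
    ds.set_input('x')
    preds = Predictor(ToyModel()).predict(ds)   # dict mapping output name -> collected predictions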
@@ -1,16 +1,22 @@
"""
.. todo::
    doc
"""

__all__ = [
    "Vocabulary",
    "VocabularyOption",
]

from collections import Counter
from functools import partial
from functools import wraps

import numpy as np

from ._logger import logger
from .dataset import DataSet
from .utils import Option
from .utils import _is_iterable


class VocabularyOption(Option):
    def __init__(self,
@@ -51,7 +57,7 @@ def _check_build_status(func):
            self.rebuild = True
            if self.max_size is not None and len(self.word_count) >= self.max_size:
                logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
                            "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
                                self.max_size, func.__name__))
        return func(self, *args, **kwargs)
@@ -199,7 +205,7 @@ class Vocabulary(object):
            self.build_reverse_vocab()
            self.rebuild = False
        return self

    def build_reverse_vocab(self):
        """
        Build the `index to word` dict from the `word to index` dict.
@@ -279,19 +285,19 @@ class Vocabulary(object):
            if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
                raise RuntimeError("Only support field with 2 dimensions.")
            return [[self.to_index(c) for c in w] for w in field]

        new_field_name = new_field_name or field_name

        if type(new_field_name) == type(field_name):
            if isinstance(new_field_name, list):
                assert len(new_field_name) == len(field_name), "new_field_name should have the same number of " \
                                                               "elements as field_name."
            elif isinstance(new_field_name, str):
                field_name = [field_name]
                new_field_name = [new_field_name]
            else:
                raise TypeError("field_name and new_field_name can only be str or List[str].")

        for idx, dataset in enumerate(datasets):
            if isinstance(dataset, DataSet):
                try:
@@ -377,7 +383,7 @@ class Vocabulary(object):
        :return: bool
        """
        return word in self._no_create_word

    def to_index(self, w):
        """
        Convert a word to its index. A word that is not recorded in the vocabulary is treated as unknown;
        if ``unknown=None``, a ``ValueError`` is raised::
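
A minimal sketch of the behaviour described above (hedged: ``update`` usage assumed from the surrounding methods)::

    from fastNLP import Vocabulary

    vocab = Vocabulary()
    vocab.update(['fastNLP', 'is', 'fun'])
    idx = vocab.to_index('fastNLP')     # index of a known word
    unk = vocab.to_index('never-seen')  # falls back to the unknown index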
@@ -8,15 +8,17 @@ __all__ = [
]

from abc import abstractmethod

import torch

from .embedding import TokenEmbedding
from ..core import logger
from ..core.batch import DataSetIter
from ..core.dataset import DataSet
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from ..core.vocabulary import Vocabulary


class ContextualEmbedding(TokenEmbedding):
    def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):