| @@ -1,3 +1,6 @@ | |||||
| """ | |||||
| 正在开发中的分布式训练代码 | |||||
| """ | |||||
| import torch | import torch | ||||
| import torch.cuda | import torch.cuda | ||||
| import torch.optim | import torch.optim | ||||
| @@ -41,7 +44,8 @@ def get_local_rank(): | |||||
| class DistTrainer(): | class DistTrainer(): | ||||
| """Distributed Trainer that support distributed and mixed precision training | |||||
| """ | |||||
| Distributed Trainer that support distributed and mixed precision training | |||||
| """ | """ | ||||
| def __init__(self, train_data, model, optimizer=None, loss=None, | def __init__(self, train_data, model, optimizer=None, loss=None, | ||||
| callbacks_all=None, callbacks_master=None, | callbacks_all=None, callbacks_master=None, | ||||
| @@ -1,4 +1,3 @@ | |||||
| from abc import abstractmethod | from abc import abstractmethod | ||||
| import torch | import torch | ||||
| @@ -9,6 +8,10 @@ from ..core.sampler import SequentialSampler | |||||
| from ..core.utils import _move_model_to_device, _get_model_device | from ..core.utils import _move_model_to_device, _get_model_device | ||||
| from .embedding import TokenEmbedding | from .embedding import TokenEmbedding | ||||
| __all__ = [ | |||||
| "ContextualEmbedding" | |||||
| ] | |||||
| class ContextualEmbedding(TokenEmbedding): | class ContextualEmbedding(TokenEmbedding): | ||||
| def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0): | def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0): | ||||
| @@ -1,7 +1,9 @@ | |||||
| """ | """ | ||||
| 用于读入和处理和保存 config 文件 | 用于读入和处理和保存 config 文件 | ||||
| .. todo:: | |||||
| .. todo:: | |||||
| 这个模块中的类可能被抛弃? | 这个模块中的类可能被抛弃? | ||||
| """ | """ | ||||
| __all__ = [ | __all__ = [ | ||||
| "ConfigLoader", | "ConfigLoader", | ||||
| @@ -1,12 +1,12 @@ | |||||
| from typing import Dict, Union | from typing import Dict, Union | ||||
| from .loader import Loader | from .loader import Loader | ||||
| from ... import DataSet | |||||
| from ...core.dataset import DataSet | |||||
| from ..file_reader import _read_conll | from ..file_reader import _read_conll | ||||
| from ... import Instance | |||||
| from ...core.instance import Instance | |||||
| from .. import DataBundle | from .. import DataBundle | ||||
| from ..utils import check_loader_paths | from ..utils import check_loader_paths | ||||
| from ... import Const | |||||
| from ...core.const import Const | |||||
| class ConllLoader(Loader): | class ConllLoader(Loader): | ||||
| @@ -1,6 +1,6 @@ | |||||
| from .loader import Loader | from .loader import Loader | ||||
| from ...core import DataSet, Instance | |||||
| from ...core.dataset import DataSet | |||||
| from ...core.instance import Instance | |||||
| class CWSLoader(Loader): | class CWSLoader(Loader): | ||||
| @@ -1,4 +1,4 @@ | |||||
| from ... import DataSet | |||||
| from ...core.dataset import DataSet | |||||
| from .. import DataBundle | from .. import DataBundle | ||||
| from ..utils import check_loader_paths | from ..utils import check_loader_paths | ||||
| from typing import Union, Dict | from typing import Union, Dict | ||||
| @@ -1,12 +1,12 @@ | |||||
| import warnings | import warnings | ||||
| from .loader import Loader | from .loader import Loader | ||||
| from .json import JsonLoader | from .json import JsonLoader | ||||
| from ...core import Const | |||||
| from ...core.const import Const | |||||
| from .. import DataBundle | from .. import DataBundle | ||||
| import os | import os | ||||
| from typing import Union, Dict | from typing import Union, Dict | ||||
| from ...core import DataSet | |||||
| from ...core import Instance | |||||
| from ...core.dataset import DataSet | |||||
| from ...core.instance import Instance | |||||
| class MNLILoader(Loader): | class MNLILoader(Loader): | ||||
| @@ -4,13 +4,14 @@ from ..base_loader import DataBundle | |||||
| from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
| from ...core.const import Const | from ...core.const import Const | ||||
| from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | ||||
| from ...core import DataSet, Instance | |||||
| from ...core.dataset import DataSet | |||||
| from ...core.instance import Instance | |||||
| from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance | from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance | ||||
| from .pipe import Pipe | from .pipe import Pipe | ||||
| import re | import re | ||||
| nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | ||||
| from ...core import cache_results | |||||
| from ...core.utils import cache_results | |||||
| class _CLSPipe(Pipe): | class _CLSPipe(Pipe): | ||||
| """ | """ | ||||
| @@ -1,7 +1,7 @@ | |||||
| from .pipe import Pipe | from .pipe import Pipe | ||||
| from .. import DataBundle | from .. import DataBundle | ||||
| from .utils import iob2, iob2bioes | from .utils import iob2, iob2bioes | ||||
| from ... import Const | |||||
| from ...core.const import Const | |||||
| from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader | from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader | ||||
| from .utils import _indexize, _add_words_field | from .utils import _indexize, _add_words_field | ||||
| @@ -2,8 +2,8 @@ import math | |||||
| from .pipe import Pipe | from .pipe import Pipe | ||||
| from .utils import get_tokenizer | from .utils import get_tokenizer | ||||
| from ...core import Const | |||||
| from ...core import Vocabulary | |||||
| from ...core.const import Const | |||||
| from ...core.vocabulary import Vocabulary | |||||
| from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader | from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader | ||||
| @@ -1,6 +1,6 @@ | |||||
| from typing import List | from typing import List | ||||
| from ...core import Vocabulary | |||||
| from ...core import Const | |||||
| from ...core.vocabulary import Vocabulary | |||||
| from ...core.const import Const | |||||
| def iob2(tags:List[str])->List[str]: | def iob2(tags:List[str])->List[str]: | ||||
| """ | """ | ||||
| @@ -51,7 +51,7 @@ class ChineseNERLoader(DataSetLoader): | |||||
| :param paths: | :param paths: | ||||
| :param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>] | :param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>] | ||||
| :param bool, trigrams: 是否包含trigram feature,[a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>] | :param bool, trigrams: 是否包含trigram feature,[a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>] | ||||
| :return: DataBundle | |||||
| :return: ~fastNLP.io.DataBundle | |||||
| 包含以下的fields | 包含以下的fields | ||||
| raw_chars: List[str] | raw_chars: List[str] | ||||
| chars: List[int] | chars: List[int] | ||||