| @@ -1,11 +1,11 @@ | |||
| """ | |||
| 大部分用于的 NLP 任务神经网络都可以看做由编码 :mod:`~fastNLP.modules.encoder` 、 | |||
| 聚合 :mod:`~fastNLP.modules.aggregator` 、解码 :mod:`~fastNLP.modules.decoder` 三种模块组成。 | |||
| 解码 :mod:`~fastNLP.modules.decoder` 两种模块组成。 | |||
| .. image:: figures/text_classification.png | |||
| :mod:`~fastNLP.modules` 中实现了 fastNLP 提供的诸多模块组件,可以帮助用户快速搭建自己所需的网络。 | |||
| 三种模块的功能和常见组件如下: | |||
| 两种模块的功能和常见组件如下: | |||
| +-----------------------+-----------------------+-----------------------+ | |||
| | module type | functionality | example | | |||
| @@ -13,9 +13,6 @@ | |||
| | encoder | 将输入编码为具有具 | embedding, RNN, CNN, | | |||
| | | 有表示能力的向量 | transformer | | |||
| +-----------------------+-----------------------+-----------------------+ | |||
| | aggregator | 从多个向量中聚合信息 | self-attention, | | |||
| | | | max-pooling | | |||
| +-----------------------+-----------------------+-----------------------+ | |||
| | decoder | 将具有某种表示意义的 | MLP, CRF | | |||
| | | 向量解码为需要的输出 | | | |||
| | | 形式 | | | |||
| @@ -46,10 +43,8 @@ __all__ = [ | |||
| "allowed_transitions", | |||
| ] | |||
| from . import aggregator | |||
| from . import decoder | |||
| from . import encoder | |||
| from .aggregator import * | |||
| from .decoder import * | |||
| from .dropout import TimestepDropout | |||
| from .encoder import * | |||
| @@ -22,7 +22,14 @@ __all__ = [ | |||
| "VarRNN", | |||
| "VarLSTM", | |||
| "VarGRU" | |||
| "VarGRU", | |||
| "MaxPool", | |||
| "MaxPoolWithMask", | |||
| "AvgPool", | |||
| "AvgPoolWithMask", | |||
| "MultiHeadAttention", | |||
| ] | |||
| from ._bert import BertModel | |||
| from .bert import BertWordPieceEncoder | |||
| @@ -34,3 +41,6 @@ from .lstm import LSTM | |||
| from .star_transformer import StarTransformer | |||
| from .transformer import TransformerEncoder | |||
| from .variational_rnn import VarRNN, VarLSTM, VarGRU | |||
| from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask | |||
| from .attention import MultiHeadAttention | |||
| @@ -45,8 +45,7 @@ class DotAttention(nn.Module): | |||
| class MultiHeadAttention(nn.Module): | |||
| """ | |||
| 别名::class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.aggregator.attention.MultiHeadAttention` | |||
| 别名::class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.encoder.attention.MultiHeadAttention` | |||
| :param input_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||
| :param key_size: int, 每个head的维度大小。 | |||
| @@ -2,35 +2,22 @@ | |||
| import os | |||
| from torch import nn | |||
| import torch | |||
| from ...io.file_utils import _get_base_url, cached_path | |||
| from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
| from ._bert import _WordPieceBertModel, BertModel | |||
| class BertWordPieceEncoder(nn.Module): | |||
| """ | |||
| 读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。 | |||
| :param fastNLP.Vocabulary vocab: 词表 | |||
| :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased`` | |||
| :param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | |||
| :param bool requires_grad: 是否需要gradient。 | |||
| """ | |||
| def __init__(self, model_dir_or_name:str='en-base-uncased', layers:str='-1', | |||
| requires_grad:bool=False): | |||
| def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1', | |||
| requires_grad: bool=False): | |||
| super().__init__() | |||
| PRETRAIN_URL = _get_base_url('bert') | |||
| PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||
| 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
| 'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
| 'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||
| 'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||
| 'cn': 'bert-base-chinese-29d0a84a.zip', | |||
| 'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||
| 'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
| 'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||
| 'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
| } | |||
| if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||
| model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
| @@ -89,4 +76,4 @@ class BertWordPieceEncoder(nn.Module): | |||
| outputs = self.model(word_pieces, token_type_ids) | |||
| outputs = torch.cat([*outputs], dim=-1) | |||
| return outputs | |||
| return outputs | |||
| @@ -1,7 +1,8 @@ | |||
| __all__ = [ | |||
| "MaxPool", | |||
| "MaxPoolWithMask", | |||
| "AvgPool" | |||
| "AvgPool", | |||
| "AvgPoolWithMask" | |||
| ] | |||
| import torch | |||
| import torch.nn as nn | |||
| @@ -9,7 +10,7 @@ import torch.nn as nn | |||
| class MaxPool(nn.Module): | |||
| """ | |||
| 别名::class:`fastNLP.modules.MaxPool` :class:`fastNLP.modules.aggregator.pooling.MaxPool` | |||
| 别名::class:`fastNLP.modules.MaxPool` :class:`fastNLP.modules.encoder.pooling.MaxPool` | |||
| Max-pooling模块。 | |||
| @@ -58,7 +59,7 @@ class MaxPool(nn.Module): | |||
| class MaxPoolWithMask(nn.Module): | |||
| """ | |||
| 别名::class:`fastNLP.modules.MaxPoolWithMask` :class:`fastNLP.modules.aggregator.pooling.MaxPoolWithMask` | |||
| 别名::class:`fastNLP.modules.MaxPoolWithMask` :class:`fastNLP.modules.encoder.pooling.MaxPoolWithMask` | |||
| 带mask矩阵的max pooling。在做max-pooling的时候不会考虑mask值为0的位置。 | |||
| """ | |||
| @@ -98,7 +99,7 @@ class KMaxPool(nn.Module): | |||
| class AvgPool(nn.Module): | |||
| """ | |||
| 别名::class:`fastNLP.modules.AvgPool` :class:`fastNLP.modules.aggregator.pooling.AvgPool` | |||
| 别名::class:`fastNLP.modules.AvgPool` :class:`fastNLP.modules.encoder.pooling.AvgPool` | |||
| 给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size] | |||
| """ | |||
| @@ -125,7 +126,7 @@ class AvgPool(nn.Module): | |||
| class AvgPoolWithMask(nn.Module): | |||
| """ | |||
| 别名::class:`fastNLP.modules.AvgPoolWithMask` :class:`fastNLP.modules.aggregator.pooling.AvgPoolWithMask` | |||
| 别名::class:`fastNLP.modules.AvgPoolWithMask` :class:`fastNLP.modules.encoder.pooling.AvgPoolWithMask` | |||
| 给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size], pooling | |||
| 的时候只会考虑mask为1的位置 | |||
| @@ -9,7 +9,7 @@ | |||
| # 任务复现 | |||
| ## Text Classification (文本分类) | |||
| - still in progress | |||
| - [Text Classification 文本分类任务复现](text_classification) | |||
| ## Matching (自然语言推理/句子匹配) | |||
| @@ -21,11 +21,11 @@ | |||
| ## Coreference resolution (指代消解) | |||
| - still in progress | |||
| - [Coreference resolution 指代消解任务复现](coreference_resolution) | |||
| ## Summarization (摘要) | |||
| - still in progress | |||
| - [BertSum](Summmarization) | |||
| ## ... | |||