@@ -52,8 +52,8 @@ __all__ = [
    "cache_results"
]
__version__ = '0.4.0'
from .core import *
from . import models
from . import modules
__version__ = '0.4.0'
@@ -2,6 +2,10 @@
The batch module implements the Batch class required by fastNLP.
"""
__all__ = [
    "Batch"
]
import atexit
from queue import Empty, Full
@@ -11,10 +15,6 @@ import torch.multiprocessing as mp
from .sampler import RandomSampler
__all__ = [
    "Batch"
]
_python_is_exit = False
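# Hedged usage sketch for the Batch class exported above. The constructor
# keywords (batch_size, sampler) and the (batch_x, batch_y) dict pairs follow
# the 0.4-era API as I understand it; treat them as assumptions, not as
# something this diff confirms.
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler

def iterate_batches(dataset):
    # Each step yields two dicts mapping field names to padded tensors:
    # batch_x holds the input fields, batch_y the target fields.
    for batch_x, batch_y in Batch(dataset, batch_size=16, sampler=RandomSampler()):
        print(list(batch_x.keys()), list(batch_y.keys()))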
@@ -49,6 +49,18 @@ The callback module implements fastNLP's many callback classes, which are used to enhance :class:
    trainer.train()
"""
__all__ = [
    "Callback",
    "GradientClipCallback",
    "EarlyStopCallback",
    "TensorboardCallback",
    "LRScheduler",
    "ControlC",
    "CallbackException",
    "EarlyStopError"
]
import os
import torch
@@ -62,18 +74,6 @@ except:
from ..io.model_io import ModelSaver, ModelLoader
__all__ = [
    "Callback",
    "GradientClipCallback",
    "EarlyStopCallback",
    "TensorboardCallback",
    "LRScheduler",
    "ControlC",
    "CallbackException",
    "EarlyStopError"
]
class Callback(object):
    """
@@ -272,6 +272,10 @@
"""
__all__ = [
    "DataSet"
]
import _pickle as pickle
import warnings
@@ -282,10 +286,6 @@ from .field import FieldArray
from .instance import Instance
from .utils import _get_func_signature
__all__ = [
    "DataSet"
]
class DataSet(object):
    """
@@ -3,10 +3,6 @@ The field module implements FieldArray and several Padders. FieldArray is :class:`~fas
For the underlying principles, see :doc:`fastNLP.core.dataset`
"""
from copy import deepcopy
import numpy as np
__all__ = [
    "FieldArray",
    "Padder",
@@ -14,6 +10,10 @@ __all__ = [
    "EngChar2DPadder"
]
from copy import deepcopy
import numpy as np
class FieldArray(object):
    """
@@ -2,6 +2,18 @@
The losses module defines the various loss functions required by fastNLP, generally used as arguments to :class:`~fastNLP.Trainer`.
"""
__all__ = [
    "LossBase",
    "LossFunc",
    "LossInForward",
    "CrossEntropyLoss",
    "BCELoss",
    "L1Loss",
    "NLLLoss"
]
import inspect
from collections import defaultdict
@@ -15,18 +27,6 @@ from .utils import _check_arg_dict_list
from .utils import _check_function_or_method
from .utils import _get_func_signature
__all__ = [
    "LossBase",
    "LossFunc",
    "LossInForward",
    "CrossEntropyLoss",
    "BCELoss",
    "L1Loss",
    "NLLLoss"
]
class LossBase(object):
    """
@@ -2,6 +2,13 @@
The metrics module implements the common evaluation metrics required by fastNLP, generally used as arguments to :class:`~fastNLP.Trainer`.
"""
__all__ = [
    "MetricBase",
    "AccuracyMetric",
    "SpanFPreRecMetric",
    "SQuADMetric"
]
import inspect
from collections import defaultdict
@@ -16,13 +23,6 @@ from .utils import _get_func_signature
from .utils import seq_len_to_mask
from .vocabulary import Vocabulary
__all__ = [
    "MetricBase",
    "AccuracyMetric",
    "SpanFPreRecMetric",
    "SQuADMetric"
]
class MetricBase(object):
    """
@@ -2,14 +2,14 @@
The optimizer module defines the various optimizers required by fastNLP, generally used as arguments to :class:`~fastNLP.Trainer`.
"""
import torch
__all__ = [
    "Optimizer",
    "SGD",
    "Adam"
]
import torch
class Optimizer(object):
    """
@@ -1,10 +1,6 @@
"""
The sampler subclasses implement the various samplers required by fastNLP.
"""
from itertools import chain
import numpy as np
__all__ = [
    "Sampler",
    "BucketSampler",
@@ -12,6 +8,10 @@ __all__ = [
    "RandomSampler"
]
from itertools import chain
import numpy as np
class Sampler(object):
    """
@@ -295,6 +295,9 @@ Example2.3
fastNLP ships with many ready-made callbacks; see :doc:`fastNLP.core.callback`.
"""
__all__ = [
    "Trainer"
]
import os
import time
@@ -1,6 +1,11 @@
"""
The utils module implements many utilities used inside and outside fastNLP. The one intended for users is the :func:`cache_results` decorator.
"""
__all__ = [
    "cache_results",
    "seq_len_to_mask"
]
import _pickle
import inspect
import os
@@ -11,10 +16,6 @@ import numpy as np
import torch
import torch.nn as nn
__all__ = [
    "cache_results",
    "seq_len_to_mask"
]
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                     'varargs'])
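# Hedged sketch for the two public utilities exported above. cache_results
# pickles a preprocessing function's return value so later runs skip the
# computation; seq_len_to_mask expands lengths into a [batch, max_len] mask.
# The decorator's positional cache-file argument is an assumption.
import torch
from fastNLP import cache_results, seq_len_to_mask

@cache_results('prepared_data.pkl')
def prepare_data():
    return {"train": "...expensive preprocessing result..."}

mask = seq_len_to_mask(torch.LongTensor([3, 5]))  # True inside each length, False on padding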
@@ -1,12 +1,12 @@
__all__ = [
    "Vocabulary"
]
from functools import wraps
from collections import Counter
from .dataset import DataSet
__all__ = [
    "Vocabulary"
]
def _check_build_vocab(func):
    """A decorator to make sure the indexing is built before used.
@@ -322,7 +322,7 @@ class Vocabulary(object):
        :return str word: the word
        """
        return self.idx2word[idx]
    def clear(self):
        """
        Removes the Vocabulary's word-table data; equivalent to re-initializing it.
@@ -333,7 +333,7 @@ class Vocabulary(object):
        self.word2idx = None
        self.idx2word = None
        self.rebuild = True
    def __getstate__(self):
        """Use to prepare data for pickle.
@@ -9,11 +9,6 @@
These classes are used as follows:
"""
from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
    PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver
__all__ = [
    'EmbedLoader',
@@ -29,3 +24,8 @@ __all__ = [
    'ModelLoader',
    'ModelSaver',
]
from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
    PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver
@@ -1,10 +1,10 @@
import _pickle as pickle
import os
__all__ = [
    "BaseLoader"
]
import _pickle as pickle
import os
class BaseLoader(object):
    """
@@ -3,18 +3,18 @@
.. todo::
    The classes in this module may be dropped?
"""
import configparser
import json
import os
from .base_loader import BaseLoader
__all__ = [
    "ConfigLoader",
    "ConfigSection",
    "ConfigSaver"
]
import configparser
import json
import os
from .base_loader import BaseLoader
class ConfigLoader(BaseLoader):
    """
@@ -10,12 +10,6 @@ The dataset_loader module implements many DataSetLoaders, used for reading different formats of
    # ... do stuff
"""
from nltk.tree import Tree
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
__all__ = [
    'DataSetLoader',
    'CSVLoader',
@@ -27,6 +21,12 @@ __all__ = [
    'Conll2003Loader',
]
from nltk.tree import Tree
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
def _download_from_url(url, path):
    try:
@@ -1,3 +1,7 @@
__all__ = [
    "EmbedLoader"
]
import os
import warnings
@@ -6,10 +10,6 @@ import numpy as np
from ..core.vocabulary import Vocabulary
from .base_loader import BaseLoader
__all__ = [
    "EmbedLoader"
]
class EmbedLoader(BaseLoader):
    """
@@ -1,15 +1,15 @@
"""
For loading and saving models.
"""
import torch
from .base_loader import BaseLoader
__all__ = [
    "ModelLoader",
    "ModelSaver"
]
import torch
from .base_loader import BaseLoader
class ModelLoader(BaseLoader):
    """
@@ -7,15 +7,6 @@ fastNLP provides built-in models in the :mod:`~fastNLP.models` module, such as :class:`~fastNLP.models
"""
from .base_model import BaseModel
from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
    BertForTokenClassification
from .biaffine_parser import BiaffineParser, GraphParser
from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel
from .snli import ESIM
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
__all__ = [
    "CNNText",
@@ -32,3 +23,12 @@ __all__ = [
    "BiaffineParser",
    "GraphParser"
]
from .base_model import BaseModel
from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
    BertForTokenClassification
from .biaffine_parser import BiaffineParser, GraphParser
from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel
from .snli import ESIM
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
@@ -1,6 +1,11 @@
"""
PyTorch implementation of the Biaffine Dependency Parser.
"""
__all__ = [
    "BiaffineParser",
    "GraphParser"
]
import numpy as np
import torch
import torch.nn as nn
@@ -19,11 +24,6 @@ from ..modules.utils import get_embeddings
from .base_model import BaseModel
from ..core.utils import seq_len_to_mask
__all__ = [
    "BiaffineParser",
    "GraphParser"
]
def _mst(scores):
    """
@@ -1,13 +1,13 @@
__all__ = [
    "CNNText"
]
import torch
import torch.nn as nn
from ..core.const import Const as C
from ..modules import encoder
__all__ = [
    "CNNText"
]
class CNNText(torch.nn.Module):
    """
@@ -1,6 +1,5 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch
from __future__ import print_function
from collections import defaultdict
import collections
@@ -1,6 +1,11 @@
"""
This module implements two sequence labeling models.
"""
__all__ = [
    "SeqLabeling",
    "AdvSeqLabel"
]
import torch
import torch.nn as nn
@@ -10,11 +15,6 @@ from ..modules.decoder.crf import allowed_transitions
from ..core.utils import seq_len_to_mask
from ..core.const import Const as C
__all__ = [
    "SeqLabeling",
    "AdvSeqLabel"
]
class SeqLabeling(BaseModel):
    """
@@ -1,3 +1,7 @@
__all__ = [
    "ESIM"
]
import torch
import torch.nn as nn
@@ -8,10 +12,6 @@ from ..modules import encoder as Encoder
from ..modules import aggregator as Aggregator
from ..core.utils import seq_len_to_mask
__all__ = [
    "ESIM"
]
my_inf = 10e12
@@ -1,6 +1,13 @@
"""
PyTorch implementation of the Star-Transformer.
"""
__all__ = [
    "StarTransEnc",
    "STNLICls",
    "STSeqCls",
    "STSeqLabel",
]
import torch
from torch import nn
@@ -9,13 +16,6 @@ from ..core.utils import seq_len_to_mask
from ..modules.utils import get_embeddings
from ..core.const import Const
__all__ = [
    "StarTransEnc",
    "STNLICls",
    "STSeqCls",
    "STSeqLabel",
]
class StarTransEnc(nn.Module):
    """
@@ -22,15 +22,6 @@
+-----------------------+-----------------------+-----------------------+
"""
from . import aggregator
from . import decoder
from . import encoder
from .aggregator import *
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
from .utils import get_embeddings
__all__ = [
    # "BertModel",
    "ConvolutionCharEncoder",
@@ -54,3 +45,12 @@ __all__ = [
    "viterbi_decode",
    "allowed_transitions",
]
from . import aggregator
from . import decoder
from . import encoder
from .aggregator import *
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
from .utils import get_embeddings
@@ -1,10 +1,3 @@
from .pooling import MaxPool
from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask
from .attention import MultiHeadAttention
__all__ = [
    "MaxPool",
    "MaxPoolWithMask",
@@ -12,3 +5,10 @@ __all__ = [
    "MultiHeadAttention",
]
from .pooling import MaxPool
from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask
from .attention import MultiHeadAttention
@@ -1,3 +1,7 @@
__all__ = [
    "MultiHeadAttention"
]
import math
import torch
@@ -8,10 +12,6 @@ from ..dropout import TimestepDropout
from ..utils import initial_parameter
__all__ = [
    "MultiHeadAttention"
]
class DotAttention(nn.Module):
    """
@@ -1,4 +1,8 @@
__all__ = ["MaxPool", "MaxPoolWithMask", "AvgPool"]
__all__ = [
    "MaxPool",
    "MaxPoolWithMask",
    "AvgPool"
]
import torch
import torch.nn as nn
@@ -16,6 +20,7 @@ class MaxPool(nn.Module):
    :param kernel_size: the max-pooling window size; defaults to the tensor's last k dimensions, where k is `dimension`
    :param ceil_mode:
    """
    def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False):
        super(MaxPool, self).__init__()
@@ -125,7 +130,7 @@ class AvgPoolWithMask(nn.Module):
    Given an input of shape [batch_size, max_len, hidden_size], performs avg pooling over the last dimension. The output has shape
    [batch_size, hidden_size], and only positions where the mask is 1 are taken into account during pooling.
    """
    def __init__(self):
        super(AvgPoolWithMask, self).__init__()
        self.inf = 10e12
@@ -1,11 +1,11 @@
from .crf import ConditionalRandomField
from .mlp import MLP
from .utils import viterbi_decode
from .crf import allowed_transitions
__all__ = [
    "MLP",
    "ConditionalRandomField",
    "viterbi_decode",
    "allowed_transitions"
]
from .crf import ConditionalRandomField
from .mlp import MLP
from .utils import viterbi_decode
from .crf import allowed_transitions
@@ -1,13 +1,13 @@
import torch
from torch import nn
from ..utils import initial_parameter
__all__ = [
    "ConditionalRandomField",
    "allowed_transitions"
]
import torch
from torch import nn
from ..utils import initial_parameter
def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
    """
@@ -1,12 +1,12 @@
__all__ = [
    "MLP"
]
import torch
import torch.nn as nn
from ..utils import initial_parameter
__all__ = [
    "MLP"
]
class MLP(nn.Module):
    """
@@ -1,8 +1,7 @@
import torch
__all__ = [
    "viterbi_decode"
]
import torch
def viterbi_decode(logits, transitions, mask=None, unpad=False):
@@ -1,6 +1,8 @@
import torch
__all__ = []
import torch
class TimestepDropout(torch.nn.Dropout):
    """
    Alias: :class:`fastNLP.modules.TimestepDropout`
@@ -8,7 +10,7 @@ class TimestepDropout(torch.nn.Dropout):
    Accepts input of shape ``[batch_size, num_timesteps, embedding_dim]`` and applies dropout at every timestep
    using a single shared mask of shape ``(batch_size, embedding_dim)``.
    """
    def forward(self, x):
        dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
        torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)
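# Sketch of the behavior documented above: TimestepDropout samples one mask of
# shape (batch_size, embedding_dim) and reuses it at every timestep, so the
# same embedding channels are dropped across the whole sequence.
import torch
from fastNLP.modules import TimestepDropout

drop = TimestepDropout(p=0.5)
drop.train()             # dropout is only active in training mode
x = torch.ones(2, 4, 8)  # [batch_size, num_timesteps, embedding_dim]
y = drop(x)
# All timesteps share the mask, so y[:, 0] equals y[:, 1], y[:, 2], ...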
@@ -1,12 +1,3 @@
from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .lstm import LSTM
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU
__all__ = [
    # "BertModel",
@@ -27,3 +18,11 @@ __all__ = [
    "VarLSTM",
    "VarGRU"
]
from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .lstm import LSTM
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU
@@ -1,12 +1,11 @@
import torch
import torch.nn as nn
from ..utils import initial_parameter
__all__ = [
    "ConvolutionCharEncoder",
    "LSTMCharEncoder"
]
import torch
import torch.nn as nn
from ..utils import initial_parameter
# from torch.nn.init import xavier_uniform
@@ -1,13 +1,12 @@
__all__ = [
    "ConvMaxpool"
]
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import initial_parameter
__all__ = [
    "ConvMaxpool"
]
class ConvMaxpool(nn.Module):
    """
@@ -1,9 +1,8 @@
import torch.nn as nn
from ..utils import get_embeddings
__all__ = [
    "Embedding"
]
import torch.nn as nn
from ..utils import get_embeddings
class Embedding(nn.Embedding):
@@ -2,16 +2,16 @@
A lightly wrapped PyTorch LSTM module.
Sequence lengths can be passed in at forward time, so that padding is handled appropriately and automatically.
"""
__all__ = [
    "LSTM"
]
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from ..utils import initial_parameter
__all__ = [
    "LSTM"
]
class LSTM(nn.Module):
    """
@@ -1,15 +1,15 @@
"""
PyTorch implementation of the encoder part of the Star-Transformer.
"""
__all__ = [
    "StarTransformer"
]
import numpy as NP
import torch
from torch import nn
from torch.nn import functional as F
__all__ = [
    "StarTransformer"
]
class StarTransformer(nn.Module):
    """
| @@ -1,12 +1,11 @@ | |||
| __all__ = [ | |||
| "TransformerEncoder" | |||
| ] | |||
| from torch import nn | |||
| from ..aggregator.attention import MultiHeadAttention | |||
| from ..dropout import TimestepDropout | |||
| __all__ = [ | |||
| "TransformerEncoder" | |||
| ] | |||
| class TransformerEncoder(nn.Module): | |||
| """ | |||
@@ -1,6 +1,12 @@
"""
PyTorch implementation of Variational RNN.
"""
__all__ = [
    "VarRNN",
    "VarLSTM",
    "VarGRU"
]
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
@@ -17,25 +23,19 @@ except ImportError:
from ..utils import initial_parameter
__all__ = [
    "VarRNN",
    "VarLSTM",
    "VarGRU"
]
class VarRnnCellWrapper(nn.Module):
    """
    Wrapper for normal RNN cells; makes them support variational dropout.
    """
    def __init__(self, cell, hidden_size, input_p, hidden_p):
        super(VarRnnCellWrapper, self).__init__()
        self.cell = cell
        self.hidden_size = hidden_size
        self.input_p = input_p
        self.hidden_p = hidden_p
    def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False):
        """
        :param PackedSequence input_x: [seq_len, batch_size, input_size]
@@ -47,13 +47,13 @@ class VarRnnCellWrapper(nn.Module):
        hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
                for other RNN, h_n, [batch_size, hidden_size]
        """
        def get_hi(hi, h0, size):
            h0_size = size - hi.size(0)
            if h0_size > 0:
                return torch.cat([hi, h0[:h0_size]], dim=0)
            return hi[:size]
        is_lstm = isinstance(hidden, tuple)
        input, batch_sizes = input_x.data, input_x.batch_sizes
        output = []
@@ -64,7 +64,7 @@ class VarRnnCellWrapper(nn.Module):
        else:
            batch_iter = batch_sizes
        idx = 0
        if is_lstm:
            hn = (hidden[0].clone(), hidden[1].clone())
        else:
@@ -91,7 +91,7 @@ class VarRnnCellWrapper(nn.Module):
            hi = cell(input_i, hi)
            hn[:size] = hi
            output.append(hi)
        if is_reversed:
            output = list(reversed(output))
        output = torch.cat(output, dim=0)
@@ -117,7 +117,7 @@ class VarRNNBase(nn.Module):
    :param hidden_dropout: dropout probability for each hidden state. Default: 0
    :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
    """
    def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
                 bias=True, batch_first=False,
                 input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -141,7 +141,7 @@ class VarRNNBase(nn.Module):
                cell, self.hidden_size, input_dropout, hidden_dropout))
        initial_parameter(self)
        self.is_lstm = (self.mode == "LSTM")
    def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h):
        is_lstm = self.is_lstm
        idx = self.num_directions * n_layer + n_direction
@@ -150,7 +150,7 @@ class VarRNNBase(nn.Module):
        output_x, hidden_x = cell(
            input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
        return output_x, hidden_x
    def forward(self, x, hx=None):
        """
@@ -170,13 +170,13 @@ class VarRNNBase(nn.Module):
        else:
            max_batch_size = int(x.batch_sizes[0])
        x, batch_sizes = x.data, x.batch_sizes
        if hx is None:
            hx = x.new_zeros(self.num_layers * self.num_directions,
                             max_batch_size, self.hidden_size, requires_grad=True)
            if is_lstm:
                hx = (hx, hx.new_zeros(hx.size(), requires_grad=True))
        mask_x = x.new_ones((max_batch_size, self.input_size))
        mask_out = x.new_ones(
            (max_batch_size, self.hidden_size * self.num_directions))
@@ -185,7 +185,7 @@ class VarRNNBase(nn.Module):
                              training=self.training, inplace=True)
        nn.functional.dropout(mask_out, p=self.hidden_dropout,
                              training=self.training, inplace=True)
        hidden = x.new_zeros(
            (self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
        if is_lstm:
@@ -207,16 +207,16 @@ class VarRNNBase(nn.Module):
            else:
                hidden[idx] = hidden_x
        x = torch.cat(output_list, dim=-1)
        if is_lstm:
            hidden = (hidden, cellstate)
        if is_packed:
            output = PackedSequence(x, batch_sizes)
        else:
            x = PackedSequence(x, batch_sizes)
            output, _ = pad_packed_sequence(x, batch_first=self.batch_first)
        return output, hidden
@@ -236,11 +236,11 @@ class VarLSTM(VarRNNBase):
    :param hidden_dropout: dropout probability for each hidden state. Default: 0
    :param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False``
    """
    def __init__(self, *args, **kwargs):
        super(VarLSTM, self).__init__(
            mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
    def forward(self, x, hx=None):
        return super(VarLSTM, self).forward(x, hx)
@@ -261,11 +261,11 @@ class VarRNN(VarRNNBase):
    :param hidden_dropout: dropout probability for each hidden state. Default: 0
    :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
    """
    def __init__(self, *args, **kwargs):
        super(VarRNN, self).__init__(
            mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
    def forward(self, x, hx=None):
        return super(VarRNN, self).forward(x, hx)
@@ -286,10 +286,10 @@ class VarGRU(VarRNNBase):
    :param hidden_dropout: dropout probability for each hidden state. Default: 0
    :param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False``
    """
    def __init__(self, *args, **kwargs):
        super(VarGRU, self).__init__(
            mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
    def forward(self, x, hx=None):
        return super(VarGRU, self).forward(x, hx)
@@ -1,5 +1,5 @@
from functools import reduce
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn