add biaffine dependency parser & some modules (tag: v0.2.0)
| @@ -93,5 +93,35 @@ class LabelField(Field): | |||||
| return torch.LongTensor([self._index]) | return torch.LongTensor([self._index]) | ||||
| class SeqLabelField(Field): | |||||
| def __init__(self, label_seq, is_target=True): | |||||
| super(SeqLabelField, self).__init__(is_target) | |||||
| self.label_seq = label_seq | |||||
| self._index = None | |||||
| def get_length(self): | |||||
| return len(self.label_seq) | |||||
| def index(self, vocab): | |||||
| if self._index is None: | |||||
| self._index = [vocab[c] for c in self.label_seq] | |||||
| return self._index | |||||
| def to_tensor(self, padding_length): | |||||
| pads = [0] * (padding_length - self.get_length()) | |||||
| if self._index is None: | |||||
| if self.get_length() == 0: | |||||
| return torch.LongTensor(pads) | |||||
| elif isinstance(self.label_seq[0], int): | |||||
| return torch.LongTensor(self.label_seq + pads) | |||||
| elif isinstance(self.label_seq[0], str): | |||||
| raise RuntimeError("Field {} not indexed. Call index method.".format(self.label)) | |||||
| else: | |||||
| raise RuntimeError( | |||||
| "Not support type for SeqLabelField. Expect str or int, got {}.".format(type(self.label))) | |||||
| else: | |||||
| return torch.LongTensor(self._index + pads) | |||||
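Usage note: a minimal sketch of how the new SeqLabelField is expected to be used, assuming the class is importable as elsewhere in this PR; the label strings and the toy label-to-index dict are illustrative only.

```python
from fastNLP.core.field import SeqLabelField  # class added in this PR

labels = ["B", "I", "O", "O"]
field = SeqLabelField(labels, is_target=True)

# a plain dict stands in for a real Vocabulary here
toy_vocab = {"B": 0, "I": 1, "O": 2}
field.index(toy_vocab)                     # -> [0, 1, 2, 2]

# pad to a common length before batching
print(field.to_tensor(padding_length=6))   # tensor([0, 1, 2, 2, 0, 0])
```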
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| tf = TextField("test the code".split(), is_target=False) | tf = TextField("test the code".split(), is_target=False) | ||||
| @@ -18,6 +18,15 @@ def isiterable(p_object): | |||||
| return False | return False | ||||
| return True | return True | ||||
| def check_build_vocab(func): | |||||
| def _wrapper(self, *args, **kwargs): | |||||
| if self.word2idx is None: | |||||
| self.build_vocab() | |||||
| self.build_reverse_vocab() | |||||
| elif self.idx2word is None: | |||||
| self.build_reverse_vocab() | |||||
| return func(self, *args, **kwargs) | |||||
| return _wrapper | |||||
| class Vocabulary(object): | class Vocabulary(object): | ||||
| """Use for word and index one to one mapping | """Use for word and index one to one mapping | ||||
| @@ -30,30 +39,23 @@ class Vocabulary(object): | |||||
| vocab["word"] | vocab["word"] | ||||
| vocab.to_word(5) | vocab.to_word(5) | ||||
| """ | """ | ||||
| def __init__(self, need_default=True): | |||||
| def __init__(self, need_default=True, max_size=None, min_freq=None): | |||||
| """ | """ | ||||
| :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | ||||
| :param int max_size: set the max number of words in Vocabulary. Default: None | |||||
| :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | |||||
| """ | """ | ||||
| if need_default: | |||||
| self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||||
| self.padding_label = DEFAULT_PADDING_LABEL | |||||
| self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||||
| else: | |||||
| self.word2idx = {} | |||||
| self.padding_label = None | |||||
| self.unknown_label = None | |||||
| self.max_size = max_size | |||||
| self.min_freq = min_freq | |||||
| self.word_count = {} | |||||
| self.has_default = need_default | self.has_default = need_default | ||||
| self.word2idx = None | |||||
| self.idx2word = None | self.idx2word = None | ||||
| def __len__(self): | |||||
| return len(self.word2idx) | |||||
| def update(self, word): | def update(self, word): | ||||
| """add word or list of words into Vocabulary | """add word or list of words into Vocabulary | ||||
| :param word: a list of string or a single string | :param word: a list of string or a single string | ||||
| """ | """ | ||||
| if not isinstance(word, str) and isiterable(word): | if not isinstance(word, str) and isiterable(word): | ||||
| @@ -61,12 +63,48 @@ class Vocabulary(object): | |||||
| for w in word: | for w in word: | ||||
| self.update(w) | self.update(w) | ||||
| else: | else: | ||||
| # it's a word to be added | |||||
| if word not in self.word2idx: | |||||
| self.word2idx[word] = len(self) | |||||
| if self.idx2word is not None: | |||||
| self.idx2word = None | |||||
| # it's a word to be added | |||||
| if word not in self.word_count: | |||||
| self.word_count[word] = 1 | |||||
| else: | |||||
| self.word_count[word] += 1 | |||||
| self.word2idx = None | |||||
| def build_vocab(self): | |||||
| """build 'word to index' dict, and filter the word using `max_size` and `min_freq` | |||||
| """ | |||||
| if self.has_default: | |||||
| self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||||
| self.padding_label = DEFAULT_PADDING_LABEL | |||||
| self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||||
| else: | |||||
| self.word2idx = {} | |||||
| self.padding_label = None | |||||
| self.unknown_label = None | |||||
| words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True) | |||||
| if self.min_freq is not None: | |||||
| words = list(filter(lambda kv: kv[1] >= self.min_freq, words)) | |||||
| if self.max_size is not None and len(words) > self.max_size: | |||||
| words = words[:self.max_size] | |||||
| for w, _ in words: | |||||
| self.word2idx[w] = len(self.word2idx) | |||||
| def build_reverse_vocab(self): | |||||
| """build 'index to word' dict based on 'word to index' dict | |||||
| """ | |||||
| self.idx2word = {self.word2idx[w] : w for w in self.word2idx} | |||||
| @check_build_vocab | |||||
| def __len__(self): | |||||
| return len(self.word2idx) | |||||
| @check_build_vocab | |||||
| def has_word(self, w): | |||||
| return w in self.word2idx | |||||
| @check_build_vocab | |||||
| def __getitem__(self, w): | def __getitem__(self, w): | ||||
| """To support usage like:: | """To support usage like:: | ||||
| @@ -74,32 +112,35 @@ class Vocabulary(object): | |||||
| """ | """ | ||||
| if w in self.word2idx: | if w in self.word2idx: | ||||
| return self.word2idx[w] | return self.word2idx[w] | ||||
| else: | |||||
| elif self.has_default: | |||||
| return self.word2idx[DEFAULT_UNKNOWN_LABEL] | return self.word2idx[DEFAULT_UNKNOWN_LABEL] | ||||
| else: | |||||
| raise ValueError("word {} not in vocabulary".format(w)) | |||||
| @check_build_vocab | |||||
| def to_index(self, w): | def to_index(self, w): | ||||
| """ like to_index(w) function, turn a word to the index | """ like to_index(w) function, turn a word to the index | ||||
| if w is not in Vocabulary, return the unknown label | if w is not in Vocabulary, return the unknown label | ||||
| :param str w: | :param str w: | ||||
| """ | """ | ||||
| return self[w] | return self[w] | ||||
| @property | |||||
| @check_build_vocab | |||||
| def unknown_idx(self): | def unknown_idx(self): | ||||
| if self.unknown_label is None: | if self.unknown_label is None: | ||||
| return None | return None | ||||
| return self.word2idx[self.unknown_label] | return self.word2idx[self.unknown_label] | ||||
| @property | |||||
| @check_build_vocab | |||||
| def padding_idx(self): | def padding_idx(self): | ||||
| if self.padding_label is None: | if self.padding_label is None: | ||||
| return None | return None | ||||
| return self.word2idx[self.padding_label] | return self.word2idx[self.padding_label] | ||||
| def build_reverse_vocab(self): | |||||
| """build 'index to word' dict based on 'word to index' dict | |||||
| """ | |||||
| self.idx2word = {self.word2idx[w]: w for w in self.word2idx} | |||||
| @check_build_vocab | |||||
| def to_word(self, idx): | def to_word(self, idx): | ||||
| """given a word's index, return the word itself | """given a word's index, return the word itself | ||||
| @@ -8,9 +8,10 @@ from fastNLP.loader.base_loader import BaseLoader | |||||
| class ConfigLoader(BaseLoader): | class ConfigLoader(BaseLoader): | ||||
| """loader for configuration files""" | """loader for configuration files""" | ||||
| def __int__(self, data_path): | |||||
| def __init__(self, data_path=None): | |||||
| super(ConfigLoader, self).__init__() | super(ConfigLoader, self).__init__() | ||||
| self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||||
| if data_path is not None: | |||||
| self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||||
| @staticmethod | @staticmethod | ||||
| def parse(string): | def parse(string): | ||||
| @@ -1,10 +1,10 @@ | |||||
| import _pickle | import _pickle | ||||
| import os | import os | ||||
| import numpy as np | |||||
| import torch | |||||
| from fastNLP.loader.base_loader import BaseLoader | from fastNLP.loader.base_loader import BaseLoader | ||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| class EmbedLoader(BaseLoader): | class EmbedLoader(BaseLoader): | ||||
| """docstring for EmbedLoader""" | """docstring for EmbedLoader""" | ||||
| @@ -13,38 +13,72 @@ class EmbedLoader(BaseLoader): | |||||
| super(EmbedLoader, self).__init__(data_path) | super(EmbedLoader, self).__init__(data_path) | ||||
| @staticmethod | @staticmethod | ||||
| def load_embedding(emb_dim, emb_file, word_dict, emb_pkl): | |||||
| def _load_glove(emb_file): | |||||
| """Read file as a glove embedding | |||||
| file format: | |||||
| one embedding per line, | |||||
| with the word and its float values separated by spaces | |||||
| Example:: | |||||
| word_1 float_1 float_2 ... float_emb_dim | |||||
| word_2 float_1 float_2 ... float_emb_dim | |||||
| ... | |||||
| """ | |||||
| emb = {} | |||||
| with open(emb_file, 'r', encoding='utf-8') as f: | |||||
| for line in f: | |||||
| line = list(filter(lambda w: len(w)>0, line.strip().split(' '))) | |||||
| if len(line) > 0: | |||||
| emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) | |||||
| return emb | |||||
| @staticmethod | |||||
| def _load_pretrain(emb_file, emb_type): | |||||
| """Read txt data from embedding file and convert to np.array as pre-trained embedding | |||||
| :param emb_file: str, the pre-trained embedding file path | |||||
| :param emb_type: str, the pre-trained embedding data format | |||||
| :return dict: {str: torch.Tensor} | |||||
| """ | |||||
| if emb_type == 'glove': | |||||
| return EmbedLoader._load_glove(emb_file) | |||||
| else: | |||||
| raise Exception("embedding type {} not support yet".format(emb_type)) | |||||
| @staticmethod | |||||
| def load_embedding(emb_dim, emb_file, emb_type, vocab, emb_pkl): | |||||
| """Load the pre-trained embedding and combine with the given dictionary. | """Load the pre-trained embedding and combine with the given dictionary. | ||||
| :param emb_file: str, the pre-trained embedding. | |||||
| The embedding file should have the following format: | |||||
| Each line is a word embedding, where a word string is followed by multiple floats. | |||||
| Floats are separated by space. The word and the first float are separated by space. | |||||
| :param word_dict: dict, a mapping from word to index. | |||||
| :param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding. | :param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding. | ||||
| :param emb_file: str, the pre-trained embedding file path. | |||||
| :param emb_type: str, the pre-trained embedding format, support glove now | |||||
| :param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding | |||||
| :param emb_pkl: str, the embedding pickle file. | :param emb_pkl: str, the embedding pickle file. | ||||
| :return embedding_np: numpy array of shape (len(word_dict), emb_dim) | :return embedding_np: numpy array of shape (len(word_dict), emb_dim) | ||||
| vocab: the input Vocabulary, or one built from the pre-trained embedding | |||||
| TODO: fragile code | TODO: fragile code | ||||
| """ | """ | ||||
| # If the embedding pickle exists, load it and return. | # If the embedding pickle exists, load it and return. | ||||
| if os.path.exists(emb_pkl): | if os.path.exists(emb_pkl): | ||||
| with open(emb_pkl, "rb") as f: | with open(emb_pkl, "rb") as f: | ||||
| embedding_np = _pickle.load(f) | |||||
| return embedding_np | |||||
| embedding_np, vocab = _pickle.load(f) | |||||
| return embedding_np, vocab | |||||
| # Otherwise, load the pre-trained embedding. | # Otherwise, load the pre-trained embedding. | ||||
| with open(emb_file, "r", encoding="utf-8") as f: | |||||
| # begin with a random embedding | |||||
| embedding_np = np.random.uniform(-1, 1, size=(len(word_dict), emb_dim)) | |||||
| for line in f: | |||||
| line = line.strip().split() | |||||
| if len(line) != emb_dim + 1: | |||||
| # skip this line if two embedding dimension not match | |||||
| continue | |||||
| if line[0] in word_dict: | |||||
| # find the word and replace its embedding with a pre-trained one | |||||
| embedding_np[word_dict[line[0]]] = [float(i) for i in line[1:]] | |||||
| pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) | |||||
| if vocab is None: | |||||
| # build vocabulary from pre-trained embedding | |||||
| vocab = Vocabulary() | |||||
| for w in pretrain.keys(): | |||||
| vocab.update(w) | |||||
| embedding_np = torch.randn(len(vocab), emb_dim) | |||||
| for w, v in pretrain.items(): | |||||
| if len(v.shape) > 1 or emb_dim != v.shape[0]: | |||||
| raise ValueError('pretrained embedding dim is {}, which does not match the required {}'.format(v.shape, (emb_dim,))) | |||||
| if vocab.has_word(w): | |||||
| embedding_np[vocab[w]] = v | |||||
| # save and return the result | # save and return the result | ||||
| with open(emb_pkl, "wb") as f: | with open(emb_pkl, "wb") as f: | ||||
| _pickle.dump(embedding_np, f) | |||||
| return embedding_np | |||||
| _pickle.dump((embedding_np, vocab), f) | |||||
| return embedding_np, vocab | |||||
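With the changed signature, `load_embedding` now returns both the embedding matrix and the (possibly newly built) `Vocabulary`. A hedged usage sketch; the file paths are placeholders and `'glove'` is currently the only supported `emb_type`:

```python
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.embed_loader import EmbedLoader

vocab = Vocabulary()            # pass None instead to build the vocab from the embedding file
vocab.update(["the", "fox", "jumps"])

embed, vocab = EmbedLoader.load_embedding(
    emb_dim=50,
    emb_file="glove.6B.50d.txt",   # placeholder path
    emb_type="glove",
    vocab=vocab,
    emb_pkl="word_emb.pkl",        # placeholder path; created if absent
)
print(embed.shape)                 # torch.Size([len(vocab), 50])
```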
| @@ -0,0 +1,364 @@ | |||||
| import sys, os | |||||
| sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
| import copy | |||||
| import numpy as np | |||||
| import torch | |||||
| from collections import defaultdict | |||||
| from torch import nn | |||||
| from torch.nn import functional as F | |||||
| from fastNLP.modules.utils import initial_parameter | |||||
| from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||||
| from fastNLP.modules.dropout import TimestepDropout | |||||
| def mst(scores): | |||||
| """ | |||||
| MST (maximum spanning tree) decoding with some modifications to support the parser's output; adapted from | |||||
| https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | |||||
| """ | |||||
| length = scores.shape[0] | |||||
| min_score = -np.inf | |||||
| mask = np.zeros((length, length)) | |||||
| np.fill_diagonal(mask, -np.inf) | |||||
| scores = scores + mask | |||||
| heads = np.argmax(scores, axis=1) | |||||
| heads[0] = 0 | |||||
| tokens = np.arange(1, length) | |||||
| roots = np.where(heads[tokens] == 0)[0] + 1 | |||||
| if len(roots) < 1: | |||||
| root_scores = scores[tokens, 0] | |||||
| head_scores = scores[tokens, heads[tokens]] | |||||
| new_root = tokens[np.argmax(root_scores / head_scores)] | |||||
| heads[new_root] = 0 | |||||
| elif len(roots) > 1: | |||||
| root_scores = scores[roots, 0] | |||||
| scores[roots, 0] = 0 | |||||
| new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1 | |||||
| new_root = roots[np.argmin( | |||||
| scores[roots, new_heads] / root_scores)] | |||||
| heads[roots] = new_heads | |||||
| heads[new_root] = 0 | |||||
| edges = defaultdict(set) | |||||
| vertices = set((0,)) | |||||
| for dep, head in enumerate(heads[tokens]): | |||||
| vertices.add(dep + 1) | |||||
| edges[head].add(dep + 1) | |||||
| for cycle in _find_cycle(vertices, edges): | |||||
| dependents = set() | |||||
| to_visit = set(cycle) | |||||
| while len(to_visit) > 0: | |||||
| node = to_visit.pop() | |||||
| if node not in dependents: | |||||
| dependents.add(node) | |||||
| to_visit.update(edges[node]) | |||||
| cycle = np.array(list(cycle)) | |||||
| old_heads = heads[cycle] | |||||
| old_scores = scores[cycle, old_heads] | |||||
| non_heads = np.array(list(dependents)) | |||||
| scores[np.repeat(cycle, len(non_heads)), | |||||
| np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score | |||||
| new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1 | |||||
| new_scores = scores[cycle, new_heads] / old_scores | |||||
| change = np.argmax(new_scores) | |||||
| changed_cycle = cycle[change] | |||||
| old_head = old_heads[change] | |||||
| new_head = new_heads[change] | |||||
| heads[changed_cycle] = new_head | |||||
| edges[new_head].add(changed_cycle) | |||||
| edges[old_head].remove(changed_cycle) | |||||
| return heads | |||||
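A toy check of the decoder above on a 4-token score matrix (rows are dependents, columns are candidate heads, index 0 is the artificial root); the scores are made up, and `mst` is assumed to be in scope from this module.

```python
import numpy as np

# scores[dep, head]; higher is better, index 0 is <root>
scores = np.array([
    [0.0, 0.0, 0.0, 0.0],
    [9.0, 0.0, 1.0, 1.0],   # token 1 prefers the root as head
    [1.0, 8.0, 0.0, 1.0],   # token 2 prefers token 1
    [1.0, 7.0, 2.0, 0.0],   # token 3 prefers token 1
])
print(mst(scores))           # -> array([0, 0, 1, 1])
```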
| def _find_cycle(vertices, edges): | |||||
| """ | |||||
| https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm | |||||
| https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py | |||||
| """ | |||||
| _index = 0 | |||||
| _stack = [] | |||||
| _indices = {} | |||||
| _lowlinks = {} | |||||
| _onstack = defaultdict(lambda: False) | |||||
| _SCCs = [] | |||||
| def _strongconnect(v): | |||||
| nonlocal _index | |||||
| _indices[v] = _index | |||||
| _lowlinks[v] = _index | |||||
| _index += 1 | |||||
| _stack.append(v) | |||||
| _onstack[v] = True | |||||
| for w in edges[v]: | |||||
| if w not in _indices: | |||||
| _strongconnect(w) | |||||
| _lowlinks[v] = min(_lowlinks[v], _lowlinks[w]) | |||||
| elif _onstack[w]: | |||||
| _lowlinks[v] = min(_lowlinks[v], _indices[w]) | |||||
| if _lowlinks[v] == _indices[v]: | |||||
| SCC = set() | |||||
| while True: | |||||
| w = _stack.pop() | |||||
| _onstack[w] = False | |||||
| SCC.add(w) | |||||
| if w == v: | |||||
| break | |||||
| _SCCs.append(SCC) | |||||
| for v in vertices: | |||||
| if v not in _indices: | |||||
| _strongconnect(v) | |||||
| return [SCC for SCC in _SCCs if len(SCC) > 1] | |||||
| class GraphParser(nn.Module): | |||||
| """Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding | |||||
| """ | |||||
| def __init__(self): | |||||
| super(GraphParser, self).__init__() | |||||
| def forward(self, x): | |||||
| raise NotImplementedError | |||||
| def _greedy_decoder(self, arc_matrix, seq_mask=None): | |||||
| _, seq_len, _ = arc_matrix.shape | |||||
| matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | |||||
| _, heads = torch.max(matrix, dim=2) | |||||
| if seq_mask is not None: | |||||
| heads *= seq_mask.long() | |||||
| return heads | |||||
| def _mst_decoder(self, arc_matrix, seq_mask=None): | |||||
| batch_size, seq_len, _ = arc_matrix.shape | |||||
| matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | |||||
| ans = matrix.new_zeros(batch_size, seq_len).long() | |||||
| for i, graph in enumerate(matrix): | |||||
| ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||||
| if seq_mask is not None: | |||||
| ans *= seq_mask.long() | |||||
| return ans | |||||
| class ArcBiaffine(nn.Module): | |||||
| """helper module for Biaffine Dependency Parser predicting arc | |||||
| """ | |||||
| def __init__(self, hidden_size, bias=True): | |||||
| super(ArcBiaffine, self).__init__() | |||||
| self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True) | |||||
| self.has_bias = bias | |||||
| if self.has_bias: | |||||
| self.bias = nn.Parameter(torch.Tensor(hidden_size), requires_grad=True) | |||||
| else: | |||||
| self.register_parameter("bias", None) | |||||
| initial_parameter(self) | |||||
| def forward(self, head, dep): | |||||
| """ | |||||
| :param head: arc-head tensor, shape [batch, length, emb_dim] | |||||
| :param dep: arc-dependent tensor, shape [batch, length, emb_dim] | |||||
| :return output: tensor of shape [batch, length, length] | |||||
| """ | |||||
| output = dep.matmul(self.U) | |||||
| output = output.bmm(head.transpose(-1, -2)) | |||||
| if self.has_bias: | |||||
| output += head.matmul(self.bias).unsqueeze(1) | |||||
| return output | |||||
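A shape-only sketch of the biaffine arc scorer: for dependent i and head j the score is dep_i·U·head_j (plus bias·head_j when enabled), so the output holds one score per (dependent, head) pair. The hidden size below is illustrative.

```python
import torch

arc = ArcBiaffine(hidden_size=8, bias=True)   # class defined above
head = torch.randn(2, 5, 8)                   # [batch, length, hidden]
dep = torch.randn(2, 5, 8)
scores = arc(head, dep)
print(scores.shape)   # torch.Size([2, 5, 5]); scores[b, i, j] scores head j for dependent i
```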
| class LabelBilinear(nn.Module): | |||||
| """helper module for Biaffine Dependency Parser predicting label | |||||
| """ | |||||
| def __init__(self, in1_features, in2_features, num_label, bias=True): | |||||
| super(LabelBilinear, self).__init__() | |||||
| self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | |||||
| self.lin1 = nn.Linear(in1_features, num_label, bias=False) | |||||
| self.lin2 = nn.Linear(in2_features, num_label, bias=False) | |||||
| def forward(self, x1, x2): | |||||
| output = self.bilinear(x1, x2) | |||||
| output += self.lin1(x1) + self.lin2(x2) | |||||
| return output | |||||
| class BiaffineParser(GraphParser): | |||||
| """Biaffine Dependency Parser implemantation. | |||||
| refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) | |||||
| <https://arxiv.org/abs/1611.01734>`_ . | |||||
| """ | |||||
| def __init__(self, | |||||
| word_vocab_size, | |||||
| word_emb_dim, | |||||
| pos_vocab_size, | |||||
| pos_emb_dim, | |||||
| rnn_layers, | |||||
| rnn_hidden_size, | |||||
| arc_mlp_size, | |||||
| label_mlp_size, | |||||
| num_label, | |||||
| dropout, | |||||
| use_var_lstm=False, | |||||
| use_greedy_infer=False): | |||||
| super(BiaffineParser, self).__init__() | |||||
| self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | |||||
| self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | |||||
| if use_var_lstm: | |||||
| self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim, | |||||
| hidden_size=rnn_hidden_size, | |||||
| num_layers=rnn_layers, | |||||
| bias=True, | |||||
| batch_first=True, | |||||
| input_dropout=dropout, | |||||
| hidden_dropout=dropout, | |||||
| bidirectional=True) | |||||
| else: | |||||
| self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim, | |||||
| hidden_size=rnn_hidden_size, | |||||
| num_layers=rnn_layers, | |||||
| bias=True, | |||||
| batch_first=True, | |||||
| dropout=dropout, | |||||
| bidirectional=True) | |||||
| rnn_out_size = 2 * rnn_hidden_size | |||||
| self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | |||||
| nn.ELU()) | |||||
| self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | |||||
| self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | |||||
| nn.ELU()) | |||||
| self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | |||||
| self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | |||||
| self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | |||||
| self.normal_dropout = nn.Dropout(p=dropout) | |||||
| self.timestep_dropout = TimestepDropout(p=dropout) | |||||
| self.use_greedy_infer = use_greedy_infer | |||||
| initial_parameter(self) | |||||
| def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_): | |||||
| """ | |||||
| :param word_seq: [batch_size, seq_len] sequence of word's indices | |||||
| :param pos_seq: [batch_size, seq_len] sequence of POS tag indices | |||||
| :param seq_mask: [batch_size, seq_len] sequence of length masks | |||||
| :param gold_heads: [batch_size, seq_len] sequence of golden heads | |||||
| :return dict: parsing results | |||||
| arc_pred: [batch_size, seq_len, seq_len] | |||||
| label_pred: [batch_size, seq_len, num_label] | |||||
| seq_mask: [batch_size, seq_len] | |||||
| head_pred: [batch_size, seq_len], the predicted heads (only returned when gold_heads is not provided) | |||||
| """ | |||||
| # prepare embeddings | |||||
| batch_size, seq_len = word_seq.shape | |||||
| # print('forward {} {}'.format(batch_size, seq_len)) | |||||
| batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) | |||||
| # get sequence mask | |||||
| seq_mask = seq_mask.long() | |||||
| word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | |||||
| pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | |||||
| x = torch.cat([word, pos], dim=2) # -> [N,L,C] | |||||
| # lstm, extract features | |||||
| feat, _ = self.lstm(x) # -> [N,L,C] | |||||
| # for arc biaffine | |||||
| # mlp, reduce dim | |||||
| arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat)) | |||||
| arc_head = self.timestep_dropout(self.arc_head_mlp(feat)) | |||||
| label_dep = self.timestep_dropout(self.label_dep_mlp(feat)) | |||||
| label_head = self.timestep_dropout(self.label_head_mlp(feat)) | |||||
| # biaffine arc classifier | |||||
| arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | |||||
| flip_mask = (seq_mask == 0) | |||||
| arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||||
| # use gold or predicted arc to predict label | |||||
| if gold_heads is None: | |||||
| # use greedy decoding in training | |||||
| if self.training or self.use_greedy_infer: | |||||
| heads = self._greedy_decoder(arc_pred, seq_mask) | |||||
| else: | |||||
| heads = self._mst_decoder(arc_pred, seq_mask) | |||||
| head_pred = heads | |||||
| else: | |||||
| head_pred = None | |||||
| heads = gold_heads | |||||
| label_head = label_head[batch_range, heads].contiguous() | |||||
| label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | |||||
| res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} | |||||
| if head_pred is not None: | |||||
| res_dict['head_pred'] = head_pred | |||||
| return res_dict | |||||
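An end-to-end forward sketch with tiny, made-up sizes (the real values come from the cfg file later in this PR), just to show the expected input tensors and output shapes:

```python
import torch

model = BiaffineParser(word_vocab_size=20, word_emb_dim=16,
                       pos_vocab_size=10, pos_emb_dim=8,
                       rnn_layers=1, rnn_hidden_size=32,
                       arc_mlp_size=24, label_mlp_size=24,
                       num_label=5, dropout=0.1,
                       use_var_lstm=False, use_greedy_infer=True)
model.eval()

word_seq = torch.randint(1, 20, (2, 7))          # [batch, seq_len]; position 0 is <root>
pos_seq = torch.randint(1, 10, (2, 7))
seq_mask = torch.ones(2, 7, dtype=torch.long)

with torch.no_grad():
    out = model(word_seq, pos_seq, seq_mask)     # no gold_heads -> heads are decoded
print(out['arc_pred'].shape)     # torch.Size([2, 7, 7])
print(out['label_pred'].shape)   # torch.Size([2, 7, 5])
print(out['head_pred'].shape)    # torch.Size([2, 7])
```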
| def loss(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||||
| """ | |||||
| Compute loss. | |||||
| :param arc_pred: [batch_size, seq_len, seq_len] | |||||
| :param label_pred: [batch_size, seq_len, num_label] | |||||
| :param head_indices: [batch_size, seq_len] | |||||
| :param head_labels: [batch_size, seq_len] | |||||
| :param seq_mask: [batch_size, seq_len] | |||||
| :return: loss value | |||||
| """ | |||||
| batch_size, seq_len, _ = arc_pred.shape | |||||
| arc_logits = F.log_softmax(arc_pred, dim=2) | |||||
| label_logits = F.log_softmax(label_pred, dim=2) | |||||
| batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1) | |||||
| child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0) | |||||
| arc_loss = arc_logits[batch_index, child_index, head_indices] | |||||
| label_loss = label_logits[batch_index, child_index, head_labels] | |||||
| arc_loss = arc_loss[:, 1:] | |||||
| label_loss = label_loss[:, 1:] | |||||
| float_mask = seq_mask[:, 1:].float() | |||||
| length = (seq_mask.sum() - batch_size).float() | |||||
| arc_nll = -(arc_loss*float_mask).sum() / length | |||||
| label_nll = -(label_loss*float_mask).sum() / length | |||||
| return arc_nll + label_nll | |||||
| def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs): | |||||
| """ | |||||
| Evaluate the performance of prediction. | |||||
| :return dict: performance results. | |||||
| head_pred_correct: number of correctly predicted heads. | |||||
| label_pred_correct: number of correctly predicted labels. | |||||
| total_tokens: number of predicted tokens | |||||
| """ | |||||
| if 'head_pred' in kwargs: | |||||
| head_pred = kwargs['head_pred'] | |||||
| elif self.use_greedy_infer: | |||||
| head_pred = self._greedy_decoder(arc_pred, seq_mask) | |||||
| else: | |||||
| head_pred = self._mst_decoder(arc_pred, seq_mask) | |||||
| head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
| _, label_preds = torch.max(label_pred, dim=2) | |||||
| label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
| return {"head_pred_correct": head_pred_correct.sum(dim=1), | |||||
| "label_pred_correct": label_pred_correct.sum(dim=1), | |||||
| "total_tokens": seq_mask.sum(dim=1)} | |||||
| def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_): | |||||
| """ | |||||
| Compute the metrics of model | |||||
| :param head_pred_correct: number of correctly predicted heads. | |||||
| :param label_pred_correct: number of correctly predicted labels. | |||||
| :param total_tokens: number of predicted tokens | |||||
| :return dict: the metrics results | |||||
| UAS: unlabeled attachment score (head prediction accuracy) | |||||
| LAS: labeled attachment score (head and label both correct) | |||||
| """ | |||||
| return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100, | |||||
| "LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100} | |||||
| @@ -0,0 +1,15 @@ | |||||
| import torch | |||||
| class TimestepDropout(torch.nn.Dropout): | |||||
| """This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single | |||||
| dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step. | |||||
| """ | |||||
| def forward(self, x): | |||||
| dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | |||||
| torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | |||||
| dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] | |||||
| if self.inplace: | |||||
| x *= dropout_mask | |||||
| return x | |||||
| else: | |||||
| return x * dropout_mask | |||||
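A quick demonstration that, unlike ordinary `nn.Dropout`, the mask is shared across time steps, which is the behaviour the biaffine parser relies on:

```python
import torch
from fastNLP.modules.dropout import TimestepDropout  # module added in this PR

drop = TimestepDropout(p=0.5)
drop.train()
x = torch.ones(1, 4, 6)                     # [batch, timesteps, embedding_dim]
y = drop(x)
# every time step is zeroed (and rescaled) in the same embedding positions
print((y[0, 0] == y[0, 1]).all().item())    # True
```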
| @@ -2,384 +2,153 @@ import math | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| import torch.nn.functional as F | |||||
| from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend | |||||
| from torch.nn.parameter import Parameter | |||||
| from torch.nn.utils.rnn import PackedSequence | |||||
| from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
| def default_initializer(hidden_size): | |||||
| stdv = 1.0 / math.sqrt(hidden_size) | |||||
| def forward(tensor): | |||||
| nn.init.uniform_(tensor, -stdv, stdv) | |||||
| return forward | |||||
| def VarMaskedRecurrent(reverse=False): | |||||
| def forward(input, hidden, cell, mask): | |||||
| output = [] | |||||
| steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) | |||||
| for i in steps: | |||||
| if mask is None or mask[i].data.min() > 0.5: | |||||
| hidden = cell(input[i], hidden) | |||||
| elif mask[i].data.max() > 0.5: | |||||
| hidden_next = cell(input[i], hidden) | |||||
| # hack to handle LSTM | |||||
| if isinstance(hidden, tuple): | |||||
| hx, cx = hidden | |||||
| hp1, cp1 = hidden_next | |||||
| hidden = (hx + (hp1 - hx) * mask[i], cx + (cp1 - cx) * mask[i]) | |||||
| else: | |||||
| hidden = hidden + (hidden_next - hidden) * mask[i] | |||||
| # hack to handle LSTM | |||||
| output.append(hidden[0] if isinstance(hidden, tuple) else hidden) | |||||
| if reverse: | |||||
| output.reverse() | |||||
| output = torch.cat(output, 0).view(input.size(0), *output[0].size()) | |||||
| return hidden, output | |||||
| return forward | |||||
| def StackedRNN(inners, num_layers, lstm=False): | |||||
| num_directions = len(inners) | |||||
| total_layers = num_layers * num_directions | |||||
| def forward(input, hidden, cells, mask): | |||||
| assert (len(cells) == total_layers) | |||||
| next_hidden = [] | |||||
| if lstm: | |||||
| hidden = list(zip(*hidden)) | |||||
| for i in range(num_layers): | |||||
| all_output = [] | |||||
| for j, inner in enumerate(inners): | |||||
| l = i * num_directions + j | |||||
| hy, output = inner(input, hidden[l], cells[l], mask) | |||||
| next_hidden.append(hy) | |||||
| all_output.append(output) | |||||
| input = torch.cat(all_output, input.dim() - 1) | |||||
| if lstm: | |||||
| next_h, next_c = zip(*next_hidden) | |||||
| next_hidden = ( | |||||
| torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), | |||||
| torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) | |||||
| ) | |||||
| else: | |||||
| next_hidden = torch.cat(next_hidden, 0).view(total_layers, *next_hidden[0].size()) | |||||
| return next_hidden, input | |||||
| return forward | |||||
| def AutogradVarMaskedRNN(num_layers=1, batch_first=False, bidirectional=False, lstm=False): | |||||
| rec_factory = VarMaskedRecurrent | |||||
| if bidirectional: | |||||
| layer = (rec_factory(), rec_factory(reverse=True)) | |||||
| else: | |||||
| layer = (rec_factory(),) | |||||
| func = StackedRNN(layer, | |||||
| num_layers, | |||||
| lstm=lstm) | |||||
| def forward(input, cells, hidden, mask): | |||||
| if batch_first: | |||||
| input = input.transpose(0, 1) | |||||
| if mask is not None: | |||||
| mask = mask.transpose(0, 1) | |||||
| nexth, output = func(input, hidden, cells, mask) | |||||
| if batch_first: | |||||
| output = output.transpose(0, 1) | |||||
| return output, nexth | |||||
| return forward | |||||
| try: | |||||
| from torch import flip | |||||
| except ImportError: | |||||
| def flip(x, dims): | |||||
| indices = [slice(None)] * x.dim() | |||||
| for dim in dims: | |||||
| indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) | |||||
| return x[tuple(indices)] | |||||
| class VarRnnCellWrapper(nn.Module): | |||||
| """Wrapper for normal RNN Cells, make it support variational dropout | |||||
| """ | |||||
| def __init__(self, cell, hidden_size, input_p, hidden_p): | |||||
| super(VarRnnCellWrapper, self).__init__() | |||||
| self.cell = cell | |||||
| self.hidden_size = hidden_size | |||||
| self.input_p = input_p | |||||
| self.hidden_p = hidden_p | |||||
| def VarMaskedStep(): | |||||
| def forward(input, hidden, cell, mask): | |||||
| if mask is None or mask.data.min() > 0.5: | |||||
| hidden = cell(input, hidden) | |||||
| elif mask.data.max() > 0.5: | |||||
| hidden_next = cell(input, hidden) | |||||
| # hack to handle LSTM | |||||
| if isinstance(hidden, tuple): | |||||
| def forward(self, input, hidden, mask_x=None, mask_h=None): | |||||
| """ | |||||
| :param input: [seq_len, batch_size, input_size] | |||||
| :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | |||||
| for other RNN, h_0, [batch_size, hidden_size] | |||||
| :param mask_x: [batch_size, input_size] dropout mask for input | |||||
| :param mask_h: [batch_size, hidden_size] dropout mask for hidden | |||||
| :return output: [seq_len, batch_size, hidden_size] | |||||
| hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | |||||
| for other RNN, h_n, [batch_size, hidden_size] | |||||
| """ | |||||
| is_lstm = isinstance(hidden, tuple) | |||||
| input = input * mask_x.unsqueeze(0) if mask_x is not None else input | |||||
| output_list = [] | |||||
| for x in input: | |||||
| if is_lstm: | |||||
| hx, cx = hidden | hx, cx = hidden | ||||
| hp1, cp1 = hidden_next | |||||
| hidden = (hx + (hp1 - hx) * mask, cx + (cp1 - cx) * mask) | |||||
| hidden = (hx * mask_h, cx) if mask_h is not None else (hx, cx) | |||||
| else: | else: | ||||
| hidden = hidden + (hidden_next - hidden) * mask | |||||
| # hack to handle LSTM | |||||
| output = hidden[0] if isinstance(hidden, tuple) else hidden | |||||
| return hidden, output | |||||
| return forward | |||||
| def StackedStep(layer, num_layers, lstm=False): | |||||
| def forward(input, hidden, cells, mask): | |||||
| assert (len(cells) == num_layers) | |||||
| next_hidden = [] | |||||
| if lstm: | |||||
| hidden = list(zip(*hidden)) | |||||
| for l in range(num_layers): | |||||
| hy, output = layer(input, hidden[l], cells[l], mask) | |||||
| next_hidden.append(hy) | |||||
| input = output | |||||
| if lstm: | |||||
| next_h, next_c = zip(*next_hidden) | |||||
| next_hidden = ( | |||||
| torch.cat(next_h, 0).view(num_layers, *next_h[0].size()), | |||||
| torch.cat(next_c, 0).view(num_layers, *next_c[0].size()) | |||||
| ) | |||||
| else: | |||||
| next_hidden = torch.cat(next_hidden, 0).view(num_layers, *next_hidden[0].size()) | |||||
| return next_hidden, input | |||||
| return forward | |||||
| def AutogradVarMaskedStep(num_layers=1, lstm=False): | |||||
| layer = VarMaskedStep() | |||||
| func = StackedStep(layer, | |||||
| num_layers, | |||||
| lstm=lstm) | |||||
| def forward(input, cells, hidden, mask): | |||||
| nexth, output = func(input, hidden, cells, mask) | |||||
| return output, nexth | |||||
| return forward | |||||
| hidden = hidden * mask_h if mask_h is not None else hidden | |||||
| hidden = self.cell(x, hidden) | |||||
| output_list.append(hidden[0] if is_lstm else hidden) | |||||
| output = torch.stack(output_list, dim=0) | |||||
| return output, hidden | |||||
| class VarMaskedRNNBase(nn.Module): | |||||
| def __init__(self, Cell, input_size, hidden_size, | |||||
| num_layers=1, bias=True, batch_first=False, | |||||
| dropout=(0, 0), bidirectional=False, initializer=None,initial_method = None, **kwargs): | |||||
| super(VarMaskedRNNBase, self).__init__() | |||||
| self.Cell = Cell | |||||
| class VarRNNBase(nn.Module): | |||||
| """Implementation of Variational Dropout RNN network. | |||||
| refer to `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | |||||
| https://arxiv.org/abs/1512.05287`. | |||||
| """ | |||||
| def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | |||||
| bias=True, batch_first=False, | |||||
| input_dropout=0, hidden_dropout=0, bidirectional=False): | |||||
| super(VarRNNBase, self).__init__() | |||||
| self.mode = mode | |||||
| self.input_size = input_size | self.input_size = input_size | ||||
| self.hidden_size = hidden_size | self.hidden_size = hidden_size | ||||
| self.num_layers = num_layers | self.num_layers = num_layers | ||||
| self.bias = bias | self.bias = bias | ||||
| self.batch_first = batch_first | self.batch_first = batch_first | ||||
| self.input_dropout = input_dropout | |||||
| self.hidden_dropout = hidden_dropout | |||||
| self.bidirectional = bidirectional | self.bidirectional = bidirectional | ||||
| self.lstm = False | |||||
| num_directions = 2 if bidirectional else 1 | |||||
| self.all_cells = [] | |||||
| for layer in range(num_layers): | |||||
| for direction in range(num_directions): | |||||
| layer_input_size = input_size if layer == 0 else hidden_size * num_directions | |||||
| cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs) | |||||
| self.all_cells.append(cell) | |||||
| self.add_module('cell%d' % (layer * num_directions + direction), cell) | |||||
| initial_parameter(self, initial_method) | |||||
| def reset_parameters(self): | |||||
| for cell in self.all_cells: | |||||
| cell.reset_parameters() | |||||
| def reset_noise(self, batch_size): | |||||
| for cell in self.all_cells: | |||||
| cell.reset_noise(batch_size) | |||||
| self.num_directions = 2 if bidirectional else 1 | |||||
| self._all_cells = nn.ModuleList() | |||||
| for layer in range(self.num_layers): | |||||
| for direction in range(self.num_directions): | |||||
| input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions | |||||
| cell = Cell(input_size, self.hidden_size, bias) | |||||
| self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) | |||||
| initial_parameter(self) | |||||
| def forward(self, input, hx=None): | |||||
| is_packed = isinstance(input, PackedSequence) | |||||
| is_lstm = (self.mode == "LSTM") | |||||
| if is_packed: | |||||
| input, batch_sizes = input | |||||
| max_batch_size = int(batch_sizes[0]) | |||||
| else: | |||||
| batch_sizes = None | |||||
| max_batch_size = input.size(0) if self.batch_first else input.size(1) | |||||
| def forward(self, input, mask=None, hx=None): | |||||
| batch_size = input.size(0) if self.batch_first else input.size(1) | |||||
| if hx is None: | if hx is None: | ||||
| num_directions = 2 if self.bidirectional else 1 | |||||
| hx = torch.tensor(input.data.new(self.num_layers * num_directions, batch_size, self.hidden_size).zero_(), | |||||
| requires_grad=True) | |||||
| if self.lstm: | |||||
| hx = input.new_zeros(self.num_layers * self.num_directions, | |||||
| max_batch_size, self.hidden_size, | |||||
| requires_grad=False) | |||||
| if is_lstm: | |||||
| hx = (hx, hx) | hx = (hx, hx) | ||||
| func = AutogradVarMaskedRNN(num_layers=self.num_layers, | |||||
| batch_first=self.batch_first, | |||||
| bidirectional=self.bidirectional, | |||||
| lstm=self.lstm) | |||||
| self.reset_noise(batch_size) | |||||
| output, hidden = func(input, self.all_cells, hx, None if mask is None else mask.view(mask.size() + (1,))) | |||||
| return output, hidden | |||||
| def step(self, input, hx=None, mask=None): | |||||
| ''' | |||||
| execute one step forward (only for one-directional RNN). | |||||
| Args: | |||||
| input (batch, input_size): input tensor of this step. | |||||
| hx (num_layers, batch, hidden_size): the hidden state of last step. | |||||
| mask (batch): the mask tensor of this step. | |||||
| Returns: | |||||
| output (batch, hidden_size): tensor containing the output of this step from the last layer of RNN. | |||||
| hn (num_layers, batch, hidden_size): tensor containing the hidden state of this step | |||||
| ''' | |||||
| assert not self.bidirectional, "step only cannot be applied to bidirectional RNN." | |||||
| batch_size = input.size(0) | |||||
| if hx is None: | |||||
| hx = torch.tensor(input.data.new(self.num_layers, batch_size, self.hidden_size).zero_(), requires_grad=True) | |||||
| if self.lstm: | |||||
| hx = (hx, hx) | |||||
| if self.batch_first: | |||||
| input = input.transpose(0, 1) | |||||
| batch_size = input.shape[1] | |||||
| mask_x = input.new_ones((batch_size, self.input_size)) | |||||
| mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | |||||
| mask_h = input.new_ones((batch_size, self.hidden_size)) | |||||
| nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | |||||
| nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | |||||
| nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True) | |||||
| hidden_list = [] | |||||
| for layer in range(self.num_layers): | |||||
| output_list = [] | |||||
| for direction in range(self.num_directions): | |||||
| input_x = input if direction == 0 else flip(input, [0]) | |||||
| idx = self.num_directions * layer + direction | |||||
| cell = self._all_cells[idx] | |||||
| hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | |||||
| mask_xi = mask_x if layer == 0 else mask_out | |||||
| output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h) | |||||
| output_list.append(output_x if direction == 0 else flip(output_x, [0])) | |||||
| hidden_list.append(hidden_x) | |||||
| input = torch.cat(output_list, dim=-1) | |||||
| output = input.transpose(0, 1) if self.batch_first else input | |||||
| if is_lstm: | |||||
| h_list, c_list = zip(*hidden_list) | |||||
| hn = torch.stack(h_list, dim=0) | |||||
| cn = torch.stack(c_list, dim=0) | |||||
| hidden = (hn, cn) | |||||
| else: | |||||
| hidden = torch.stack(hidden_list, dim=0) | |||||
| func = AutogradVarMaskedStep(num_layers=self.num_layers, lstm=self.lstm) | |||||
| if is_packed: | |||||
| output = PackedSequence(output, batch_sizes) | |||||
| output, hidden = func(input, self.all_cells, hx, mask) | |||||
| return output, hidden | return output, hidden | ||||
| class VarMaskedFastLSTM(VarMaskedRNNBase): | |||||
| class VarLSTM(VarRNNBase): | |||||
| """Variational Dropout LSTM. | |||||
| """ | |||||
| def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
| super(VarMaskedFastLSTM, self).__init__(VarFastLSTMCell, *args, **kwargs) | |||||
| self.lstm = True | |||||
| class VarRNNCellBase(nn.Module): | |||||
| def __repr__(self): | |||||
| s = '{name}({input_size}, {hidden_size}' | |||||
| if 'bias' in self.__dict__ and self.bias is not True: | |||||
| s += ', bias={bias}' | |||||
| if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh": | |||||
| s += ', nonlinearity={nonlinearity}' | |||||
| s += ')' | |||||
| return s.format(name=self.__class__.__name__, **self.__dict__) | |||||
| super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | |||||
| def reset_noise(self, batch_size): | |||||
| """ | |||||
| Should be overriden by all subclasses. | |||||
| Args: | |||||
| batch_size: (int) batch size of input. | |||||
| """ | |||||
| raise NotImplementedError | |||||
| class VarFastLSTMCell(VarRNNCellBase): | |||||
| """ | |||||
| A long short-term memory (LSTM) cell with variational dropout. | |||||
| .. math:: | |||||
| \begin{array}{ll} | |||||
| i = \mathrm{sigmoid}(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ | |||||
| f = \mathrm{sigmoid}(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ | |||||
| g = \tanh(W_{ig} x + b_{ig} + W_{hc} h + b_{hg}) \\ | |||||
| o = \mathrm{sigmoid}(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ | |||||
| c' = f * c + i * g \\ | |||||
| h' = o * \tanh(c') \\ | |||||
| \end{array} | |||||
| class VarRNN(VarRNNBase): | |||||
| """Variational Dropout RNN. | |||||
| """ | """ | ||||
| def __init__(self, *args, **kwargs): | |||||
| super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs) | |||||
| def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None,initial_method =None): | |||||
| super(VarFastLSTMCell, self).__init__() | |||||
| self.input_size = input_size | |||||
| self.hidden_size = hidden_size | |||||
| self.bias = bias | |||||
| self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size)) | |||||
| self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size)) | |||||
| if bias: | |||||
| self.bias_ih = Parameter(torch.Tensor(4 * hidden_size)) | |||||
| self.bias_hh = Parameter(torch.Tensor(4 * hidden_size)) | |||||
| else: | |||||
| self.register_parameter('bias_ih', None) | |||||
| self.register_parameter('bias_hh', None) | |||||
| self.initializer = default_initializer(self.hidden_size) if initializer is None else initializer | |||||
| self.reset_parameters() | |||||
| p_in, p_hidden = p | |||||
| if p_in < 0 or p_in > 1: | |||||
| raise ValueError("input dropout probability has to be between 0 and 1, " | |||||
| "but got {}".format(p_in)) | |||||
| if p_hidden < 0 or p_hidden > 1: | |||||
| raise ValueError("hidden state dropout probability has to be between 0 and 1, " | |||||
| "but got {}".format(p_hidden)) | |||||
| self.p_in = p_in | |||||
| self.p_hidden = p_hidden | |||||
| self.noise_in = None | |||||
| self.noise_hidden = None | |||||
| initial_parameter(self, initial_method) | |||||
| def reset_parameters(self): | |||||
| for weight in self.parameters(): | |||||
| if weight.dim() == 1: | |||||
| weight.data.zero_() | |||||
| else: | |||||
| self.initializer(weight.data) | |||||
| def reset_noise(self, batch_size): | |||||
| if self.training: | |||||
| if self.p_in: | |||||
| noise = self.weight_ih.data.new(batch_size, self.input_size) | |||||
| self.noise_in = torch.tensor(noise.bernoulli_(1.0 - self.p_in) / (1.0 - self.p_in)) | |||||
| else: | |||||
| self.noise_in = None | |||||
| if self.p_hidden: | |||||
| noise = self.weight_hh.data.new(batch_size, self.hidden_size) | |||||
| self.noise_hidden = torch.tensor(noise.bernoulli_(1.0 - self.p_hidden) / (1.0 - self.p_hidden)) | |||||
| else: | |||||
| self.noise_hidden = None | |||||
| else: | |||||
| self.noise_in = None | |||||
| self.noise_hidden = None | |||||
| def forward(self, input, hx): | |||||
| return self.__forward( | |||||
| input, hx, | |||||
| self.weight_ih, self.weight_hh, | |||||
| self.bias_ih, self.bias_hh, | |||||
| self.noise_in, self.noise_hidden, | |||||
| ) | |||||
| @staticmethod | |||||
| def __forward(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None): | |||||
| if noise_in is not None: | |||||
| if input.is_cuda: | |||||
| input = input * noise_in.cuda(input.get_device()) | |||||
| else: | |||||
| input = input * noise_in | |||||
| if input.is_cuda: | |||||
| w_ih = w_ih.cuda(input.get_device()) | |||||
| w_hh = w_hh.cuda(input.get_device()) | |||||
| hidden = [h.cuda(input.get_device()) for h in hidden] | |||||
| b_ih = b_ih.cuda(input.get_device()) | |||||
| b_hh = b_hh.cuda(input.get_device()) | |||||
| igates = F.linear(input, w_ih.cuda(input.get_device())) | |||||
| hgates = F.linear(hidden[0], w_hh) if noise_hidden is None \ | |||||
| else F.linear(hidden[0] * noise_hidden.cuda(input.get_device()), w_hh) | |||||
| state = fusedBackend.LSTMFused.apply | |||||
| # print("use backend") | |||||
| # use some magic function | |||||
| return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh) | |||||
| hx, cx = hidden | |||||
| if noise_hidden is not None: | |||||
| hx = hx * noise_hidden | |||||
| gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh) | |||||
| ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) | |||||
| ingate = F.sigmoid(ingate) | |||||
| forgetgate = F.sigmoid(forgetgate) | |||||
| cellgate = F.tanh(cellgate) | |||||
| outgate = F.sigmoid(outgate) | |||||
| cy = (forgetgate * cx) + (ingate * cellgate) | |||||
| hy = outgate * F.tanh(cy) | |||||
| return hy, cy | |||||
| class VarGRU(VarRNNBase): | |||||
| """Variational Dropout GRU. | |||||
| """ | |||||
| def __init__(self, *args, **kwargs): | |||||
| super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | |||||
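The `Var*` classes are intended as drop-in replacements for the corresponding `torch.nn` modules, with separate `input_dropout` and `hidden_dropout`; a smoke-test sketch with illustrative sizes:

```python
import torch
from fastNLP.modules.encoder.variational_rnn import VarLSTM  # path as imported by the parser above

lstm = VarLSTM(input_size=16, hidden_size=32, num_layers=2,
               batch_first=True, bidirectional=True,
               input_dropout=0.33, hidden_dropout=0.33)
x = torch.randn(4, 9, 16)            # [batch, seq_len, input_size]
output, (h_n, c_n) = lstm(x)
print(output.shape)                  # torch.Size([4, 9, 64]); 2 directions * hidden_size
print(h_n.shape)                     # torch.Size([4, 4, 32]); [num_layers * num_directions, batch, hidden]
```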
| @@ -0,0 +1,37 @@ | |||||
| [train] | |||||
| epochs = 50 | |||||
| batch_size = 16 | |||||
| pickle_path = "./save/" | |||||
| validate = true | |||||
| save_best_dev = false | |||||
| use_cuda = true | |||||
| model_saved_path = "./save/" | |||||
| task = "parse" | |||||
| [test] | |||||
| save_output = true | |||||
| validate_in_training = true | |||||
| save_dev_input = false | |||||
| save_loss = true | |||||
| batch_size = 16 | |||||
| pickle_path = "./save/" | |||||
| use_cuda = true | |||||
| task = "parse" | |||||
| [model] | |||||
| word_vocab_size = -1 | |||||
| word_emb_dim = 100 | |||||
| pos_vocab_size = -1 | |||||
| pos_emb_dim = 100 | |||||
| rnn_layers = 3 | |||||
| rnn_hidden_size = 400 | |||||
| arc_mlp_size = 500 | |||||
| label_mlp_size = 100 | |||||
| num_label = -1 | |||||
| dropout = 0.33 | |||||
| use_var_lstm = true | |||||
| use_greedy_infer = false | |||||
| [optim] | |||||
| lr = 2e-3 | |||||
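For reference, the sections above are read into `ConfigSection` objects through `ConfigLoader.load_config`, exactly as the training script below does; a minimal sketch (the printed value assumes the file above):

```python
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

model_args = ConfigSection()
optim_args = ConfigSection()
ConfigLoader.load_config('cfg.cfg', {"model": model_args, "optim": optim_args})
print(model_args['rnn_hidden_size'])   # 400
print(optim_args.data)                 # the parsed key/value pairs of the [optim] section
```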
| @@ -0,0 +1,260 @@ | |||||
| import os | |||||
| import sys | |||||
| sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
| from collections import defaultdict | |||||
| import math | |||||
| import torch | |||||
| from fastNLP.core.trainer import Trainer | |||||
| from fastNLP.core.instance import Instance | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.batch import Batch | |||||
| from fastNLP.core.sampler import SequentialSampler | |||||
| from fastNLP.core.field import TextField, SeqLabelField | |||||
| from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||||
| from fastNLP.core.tester import Tester | |||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.loader.model_loader import ModelLoader | |||||
| from fastNLP.loader.embed_loader import EmbedLoader | |||||
| from fastNLP.models.biaffine_parser import BiaffineParser | |||||
| from fastNLP.saver.model_saver import ModelSaver | |||||
| # cd to this file's directory when launched from elsewhere | |||||
| if len(os.path.dirname(__file__)) != 0: | |||||
| os.chdir(os.path.dirname(__file__)) | |||||
| class MyDataLoader(object): | |||||
| def __init__(self, pickle_path): | |||||
| self.pickle_path = pickle_path | |||||
| def load(self, path, word_v=None, pos_v=None, headtag_v=None): | |||||
| datalist = [] | |||||
| with open(path, 'r', encoding='utf-8') as f: | |||||
| sample = [] | |||||
| for line in f: | |||||
| if line.startswith('\n'): | |||||
| datalist.append(sample) | |||||
| sample = [] | |||||
| elif line.startswith('#'): | |||||
| continue | |||||
| else: | |||||
| sample.append(line.split('\t')) | |||||
| if len(sample) > 0: | |||||
| datalist.append(sample) | |||||
| ds = DataSet(name='conll') | |||||
| for sample in datalist: | |||||
| # print(sample) | |||||
| res = self.get_one(sample) | |||||
| if word_v is not None: | |||||
| word_v.update(res[0]) | |||||
| pos_v.update(res[1]) | |||||
| headtag_v.update(res[3]) | |||||
| ds.append(Instance(word_seq=TextField(res[0], is_target=False), | |||||
| pos_seq=TextField(res[1], is_target=False), | |||||
| head_indices=SeqLabelField(res[2], is_target=True), | |||||
| head_labels=TextField(res[3], is_target=True), | |||||
| seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False))) | |||||
| return ds | |||||
| def get_one(self, sample): | |||||
| text = ['<root>'] | |||||
| pos_tags = ['<root>'] | |||||
| heads = [0] | |||||
| head_tags = ['root'] | |||||
| for w in sample: | |||||
| t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
| if t3 == '_': | |||||
| continue | |||||
| text.append(t1) | |||||
| pos_tags.append(t2) | |||||
| heads.append(int(t3)) | |||||
| head_tags.append(t4) | |||||
| return (text, pos_tags, heads, head_tags) | |||||
| def index_data(self, dataset, word_v, pos_v, tag_v): | |||||
| dataset.index_field('word_seq', word_v) | |||||
| dataset.index_field('pos_seq', pos_v) | |||||
| dataset.index_field('head_labels', tag_v) | |||||
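`get_one` keeps columns 1 (FORM), 3 (UPOS), 6 (HEAD) and 7 (DEPREL) of each CoNLL-U token row and prepends an artificial `<root>` token; a toy illustration with a fabricated two-token sentence:

```python
# fabricated CoNLL-U rows: ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC
sample = [
    ["1", "Dogs", "dog", "NOUN", "NNS", "_", "2", "nsubj", "_", "_"],
    ["2", "bark", "bark", "VERB", "VBP", "_", "0", "root", "_", "_"],
]
text, pos_tags, heads, head_tags = MyDataLoader('').get_one(sample)
print(text)        # ['<root>', 'Dogs', 'bark']
print(pos_tags)    # ['<root>', 'NOUN', 'VERB']
print(heads)       # [0, 2, 0]
print(head_tags)   # ['root', 'nsubj', 'root']
```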
| # datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | |||||
| datadir = "/home/yfshao/UD_English-EWT" | |||||
| cfgfile = './cfg.cfg' | |||||
| train_data_name = "en_ewt-ud-train.conllu" | |||||
| dev_data_name = "en_ewt-ud-dev.conllu" | |||||
| emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||||
| processed_datadir = './save' | |||||
| # Config Loader | |||||
| train_args = ConfigSection() | |||||
| test_args = ConfigSection() | |||||
| model_args = ConfigSection() | |||||
| optim_args = ConfigSection() | |||||
| ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | |||||
| # Data Loader | |||||
| def save_data(dirpath, **kwargs): | |||||
| import _pickle | |||||
| if not os.path.exists(dirpath): | |||||
| os.mkdir(dirpath) | |||||
| for name, data in kwargs.items(): | |||||
| with open(os.path.join(dirpath, name+'.pkl'), 'wb') as f: | |||||
| _pickle.dump(data, f) | |||||
| def load_data(dirpath): | |||||
| import _pickle | |||||
| datas = {} | |||||
| for f_name in os.listdir(dirpath): | |||||
| if not f_name.endswith('.pkl'): | |||||
| continue | |||||
| name = f_name[:-4] | |||||
| with open(os.path.join(dirpath, f_name), 'rb') as f: | |||||
| datas[name] = _pickle.load(f) | |||||
| return datas | |||||
| class MyTester(object): | |||||
| def __init__(self, batch_size, use_cuda=False, **kwargs): | |||||
| self.batch_size = batch_size | |||||
| self.use_cuda = use_cuda | |||||
| def test(self, model, dataset): | |||||
| self.model = model.cuda() if self.use_cuda else model | |||||
| self.model.eval() | |||||
| batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda) | |||||
| eval_res = defaultdict(list) | |||||
| i = 0 | |||||
| for batch_x, batch_y in batchiter: | |||||
| with torch.no_grad(): | |||||
| pred_y = self.model(**batch_x) | |||||
| eval_one = self.model.evaluate(**pred_y, **batch_y) | |||||
| i += self.batch_size | |||||
| for eval_name, tensor in eval_one.items(): | |||||
| eval_res[eval_name].append(tensor) | |||||
| tmp = {} | |||||
| for eval_name, tensorlist in eval_res.items(): | |||||
| tmp[eval_name] = torch.cat(tensorlist, dim=0) | |||||
| self.res = self.model.metrics(**tmp) | |||||
| def show_metrics(self): | |||||
| s = "" | |||||
| for name, val in self.res.items(): | |||||
| s += '{}: {:.2f}\t'.format(name, val) | |||||
| return s | |||||
| loader = MyDataLoader('') | |||||
| try: | |||||
| data_dict = load_data(processed_datadir) | |||||
| word_v = data_dict['word_v'] | |||||
| pos_v = data_dict['pos_v'] | |||||
| tag_v = data_dict['tag_v'] | |||||
| train_data = data_dict['train_data'] | |||||
| dev_data = data_dict['dev_data'] | |||||
| print('use saved pickles') | |||||
| except Exception as _: | |||||
| print('load raw data and preprocess') | |||||
| word_v = Vocabulary(need_default=True, min_freq=2) | |||||
| pos_v = Vocabulary(need_default=True) | |||||
| tag_v = Vocabulary(need_default=False) | |||||
| train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v) | |||||
| dev_data = loader.load(os.path.join(datadir, dev_data_name)) | |||||
| save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) | |||||
| loader.index_data(train_data, word_v, pos_v, tag_v) | |||||
| loader.index_data(dev_data, word_v, pos_v, tag_v) | |||||
| print(len(train_data)) | |||||
| print(len(dev_data)) | |||||
| ep = train_args['epochs'] | |||||
| train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep | |||||
| model_args['word_vocab_size'] = len(word_v) | |||||
| model_args['pos_vocab_size'] = len(pos_v) | |||||
| model_args['num_label'] = len(tag_v) | |||||
| def train(): | |||||
| # Trainer | |||||
| trainer = Trainer(**train_args.data) | |||||
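| # plug a custom optimizer, scheduler, loss call and validator into the Trainer via its hooks: | |||||
| # Adam with an exponentially decaying LambdaLR schedule stepped before every optimizer update | |||||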
| def _define_optim(obj): | |||||
| obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data) | |||||
| obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4)) | |||||
| def _update(obj): | |||||
| obj._scheduler.step() | |||||
| obj._optimizer.step() | |||||
| trainer.define_optimizer = lambda: _define_optim(trainer) | |||||
| trainer.update = lambda: _update(trainer) | |||||
| trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth) | |||||
| trainer._create_validator = lambda x: MyTester(**test_args.data) | |||||
| # Model | |||||
| model = BiaffineParser(**model_args.data) | |||||
| # use pretrain embedding | |||||
| embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||||
| model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | |||||
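| # set padding_idx and zero those embedding rows so padding tokens contribute nothing | |||||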
| model.word_embedding.padding_idx = word_v.padding_idx | |||||
| model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | |||||
| model.pos_embedding.padding_idx = pos_v.padding_idx | |||||
| model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | |||||
| try: | |||||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
| print('model parameter loaded!') | |||||
| except Exception as _: | |||||
| print("No saved model. Continue.") | |||||
| # Start training | |||||
| trainer.train(model, train_data, dev_data) | |||||
| print("Training finished!") | |||||
| # Saver | |||||
| saver = ModelSaver("./save/saved_model.pkl") | |||||
| saver.save_pytorch(model) | |||||
| print("Model saved!") | |||||
| def test(): | |||||
| # Tester | |||||
| tester = MyTester(**test_args.data) | |||||
| # Model | |||||
| model = BiaffineParser(**model_args.data) | |||||
| try: | |||||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
| print('model parameter loaded!') | |||||
| except Exception as _: | |||||
| print("No saved model. Abort test.") | |||||
| raise | |||||
| # Start testing | |||||
| tester.test(model, dev_data) | |||||
| print(tester.show_metrics()) | |||||
| print("Testing finished!") | |||||
| if __name__ == "__main__": | |||||
| import argparse | |||||
| parser = argparse.ArgumentParser(description='Run the biaffine dependency parser') | |||||
| parser.add_argument('--mode', help='set the running mode', choices=['train', 'test', 'infer']) | |||||
| args = parser.parse_args() | |||||
| if args.mode == 'train': | |||||
| train() | |||||
| elif args.mode == 'test': | |||||
| test() | |||||
| elif args.mode == 'infer': | |||||
| infer() | |||||
| else: | |||||
| print('no mode specified!') | |||||
| parser.print_help() | |||||
| @@ -0,0 +1,12 @@ | |||||
| the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581 | |||||
| , 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -0.42852 -0.55641 -0.364 -0.23938 0.13001 -0.063734 -0.39575 -0.48162 0.23291 0.090201 -0.13324 0.078639 -0.41634 -0.15428 0.10068 0.48891 0.31226 -0.1252 -0.037512 -1.5179 0.12612 -0.02442 -0.042961 -0.28351 3.5416 -0.11956 -0.014533 -0.1499 0.21864 -0.33412 -0.13872 0.31806 0.70358 0.44858 -0.080262 0.63003 0.32111 -0.46765 0.22786 0.36034 -0.37818 -0.56657 0.044691 0.30392 | |||||
| . 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -0.43478 -0.31086 -0.44999 -0.29486 0.16608 0.11963 -0.41328 -0.42353 0.59868 0.28825 -0.11547 -0.041848 -0.67989 -0.25063 0.18472 0.086876 0.46582 0.015035 0.043474 -1.4671 -0.30384 -0.023441 0.30589 -0.21785 3.746 0.0042284 -0.18436 -0.46209 0.098329 -0.11907 0.23919 0.1161 0.41705 0.056763 -6.3681e-05 0.068987 0.087939 -0.10285 -0.13931 0.22314 -0.080803 -0.35652 0.016413 0.10216 | |||||
| of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375 | |||||
| to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044 | |||||
| and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097 | |||||
| in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285 | |||||
| a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796 | |||||
| " 0.25769 0.45629 -0.76974 -0.37679 0.59272 -0.063527 0.20545 -0.57385 -0.29009 -0.13662 0.32728 1.4719 -0.73681 -0.12036 0.71354 -0.46098 0.65248 0.48887 -0.51558 0.039951 -0.34307 -0.014087 0.86488 0.3546 0.7999 -1.4995 -1.8153 0.41128 0.23921 -0.43139 3.6623 -0.79834 -0.54538 0.16943 -0.82017 -0.3461 0.69495 -1.2256 -0.17992 -0.057474 0.030498 -0.39543 -0.38515 -1.0002 0.087599 -0.31009 -0.34677 -0.31438 0.75004 0.97065 | |||||
| 's 0.23727 0.40478 -0.20547 0.58805 0.65533 0.32867 -0.81964 -0.23236 0.27428 0.24265 0.054992 0.16296 -1.2555 -0.086437 0.44536 0.096561 -0.16519 0.058378 -0.38598 0.086977 0.0033869 0.55095 -0.77697 -0.62096 0.092948 -2.5685 -0.67739 0.10151 -0.48643 -0.057805 3.1859 -0.017554 -0.16138 0.055486 -0.25885 -0.33938 -0.19928 0.26049 0.10478 -0.55934 -0.12342 0.65961 -0.51802 -0.82995 -0.082739 0.28155 -0.423 -0.27378 -0.007901 -0.030231 | |||||
| @@ -0,0 +1,33 @@ | |||||
| import unittest | |||||
| import os | |||||
| import torch | |||||
| from fastNLP.loader.embed_loader import EmbedLoader | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
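| # EmbedLoader should return a (len(vocab), emb_dim) matrix and reject a mismatched dimension | |||||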
| class TestEmbedLoader(unittest.TestCase): | |||||
| glove_path = './test/data_for_tests/glove.6B.50d_test.txt' | |||||
| pkl_path = './save' | |||||
| raw_texts = ["i am a cat", | |||||
| "this is a test of new batch", | |||||
| "ha ha", | |||||
| "I am a good boy .", | |||||
| "This is the most beautiful girl ." | |||||
| ] | |||||
| texts = [text.strip().split() for text in raw_texts] | |||||
| vocab = Vocabulary() | |||||
| vocab.update(texts) | |||||
| def test1(self): | |||||
| emb, _ = EmbedLoader.load_embedding(50, self.glove_path, 'glove', self.vocab, self.pkl_path) | |||||
| self.assertEqual(emb.shape[0], len(self.vocab)) | |||||
| self.assertEqual(emb.shape[1], 50) | |||||
| os.remove(self.pkl_path) | |||||
| def test2(self): | |||||
| # a mismatched embedding dimension should raise ValueError | |||||
| with self.assertRaises(ValueError): | |||||
| _ = EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path) | |||||
| @@ -3,35 +3,23 @@ import unittest | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM | |||||
| from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||||
| class TestMaskedRnn(unittest.TestCase): | class TestMaskedRnn(unittest.TestCase): | ||||
| def test_case_1(self): | def test_case_1(self): | ||||
| masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) | |||||
| masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) | |||||
| x = torch.tensor([[[1.0], [2.0]]]) | x = torch.tensor([[[1.0], [2.0]]]) | ||||
| print(x.size()) | print(x.size()) | ||||
| y = masked_rnn(x) | y = masked_rnn(x) | ||||
| mask = torch.tensor([[[1], [1]]]) | |||||
| y = masked_rnn(x, mask=mask) | |||||
| mask = torch.tensor([[[1], [0]]]) | |||||
| y = masked_rnn(x, mask=mask) | |||||
| def test_case_2(self): | def test_case_2(self): | ||||
| input_size = 12 | input_size = 12 | ||||
| batch = 16 | batch = 16 | ||||
| hidden = 10 | hidden = 10 | ||||
| masked_rnn = VarMaskedFastLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True) | |||||
| x = torch.randn((batch, input_size)) | |||||
| output, _ = masked_rnn.step(x) | |||||
| self.assertEqual(tuple(output.shape), (batch, hidden)) | |||||
| masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True) | |||||
| xx = torch.randn((batch, 32, input_size)) | xx = torch.randn((batch, 32, input_size)) | ||||
| y, _ = masked_rnn(xx) | y, _ = masked_rnn(xx) | ||||
| self.assertEqual(tuple(y.shape), (batch, 32, hidden)) | self.assertEqual(tuple(y.shape), (batch, 32, hidden)) | ||||
| xx = torch.randn((batch, 32, input_size)) | |||||
| mask = torch.from_numpy(np.random.randint(0, 2, size=(batch, 32))).to(xx) | |||||
| y, _ = masked_rnn(xx, mask=mask) | |||||
| self.assertEqual(tuple(y.shape), (batch, 32, hidden)) | |||||