diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 3becce2c..1c7d33c5 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -16,6 +16,7 @@ from ._logger import logger from .dataset import DataSet from .utils import Option from .utils import _is_iterable +import io class VocabularyOption(Option): @@ -487,76 +488,99 @@ class Vocabulary(object): def save(self, filepath): r""" - :param str filepath: Vocabulary的储存路径 + :param str,io.StringIO filepath: Vocabulary的储存路径 :return: """ - with open(filepath, 'w', encoding='utf-8') as f: - f.write(f'max_size\t{self.max_size}\n') - f.write(f'min_freq\t{self.min_freq}\n') - f.write(f'unknown\t{self.unknown}\n') - f.write(f'padding\t{self.padding}\n') - f.write(f'rebuild\t{self.rebuild}\n') - f.write('\n') - # idx: 如果idx为-2, 说明还没有进行build; 如果idx为-1,说明该词未编入 - # no_create_entry: 如果为1,说明该词是no_create_entry; 0 otherwise - # word \t count \t idx \t no_create_entry \n - idx = -2 - for word, count in self.word_count.items(): - if self._word2idx is not None: - idx = self._word2idx.get(word, -1) - is_no_create_entry = int(self._is_word_no_create_entry(word)) - f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n') + if isinstance(filepath, io.IOBase): + assert filepath.writable() + f = filepath + elif isinstance(filepath, str): + try: + f = open(filepath, 'w', encoding='utf-8') + except Exception as e: + raise e + else: + raise TypeError("Illegal `filepath`.") + + f.write(f'max_size\t{self.max_size}\n') + f.write(f'min_freq\t{self.min_freq}\n') + f.write(f'unknown\t{self.unknown}\n') + f.write(f'padding\t{self.padding}\n') + f.write(f'rebuild\t{self.rebuild}\n') + f.write('\n') + # idx: 如果idx为-2, 说明还没有进行build; 如果idx为-1,说明该词未编入 + # no_create_entry: 如果为1,说明该词是no_create_entry; 0 otherwise + # word \t count \t idx \t no_create_entry \n + idx = -2 + for word, count in self.word_count.items(): + if self._word2idx is not None: + idx = self._word2idx.get(word, -1) + is_no_create_entry = int(self._is_word_no_create_entry(word)) + f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n') + if isinstance(filepath, str): # 如果是file的话就关闭 + f.close() @staticmethod def load(filepath): r""" - :param str filepath: Vocabulary的读取路径 + :param str,io.StringIO filepath: Vocabulary的读取路径 :return: Vocabulary """ - with open(filepath, 'r', encoding='utf-8') as f: - vocab = Vocabulary() - for line in f: - line = line.strip() - if line: - name, value = line.split() - if name in ('max_size', 'min_freq'): - value = int(value) if value!='None' else None - setattr(vocab, name, value) - elif name in ('unknown', 'padding'): - value = value if value!='None' else None - setattr(vocab, name, value) - elif name == 'rebuild': - vocab.rebuild = True if value=='True' else False - else: - break - word_counter = {} - no_create_entry_counter = {} - word2idx = {} - for line in f: - line = line.strip() - if line: - parts = line.split('\t') - word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3]) - if idx >= 0: - word2idx[word] = idx - word_counter[word] = count - if no_create_entry: - no_create_entry_counter[word] = count - - word_counter = Counter(word_counter) - no_create_entry_counter = Counter(no_create_entry_counter) - if len(word2idx)>0: - if vocab.padding: - word2idx[vocab.padding] = 0 - if vocab.unknown: - word2idx[vocab.unknown] = 1 if vocab.padding else 0 - idx2word = {value:key for key,value in word2idx.items()} - - vocab.word_count = word_counter - vocab._no_create_word = no_create_entry_counter - if word2idx: - 
vocab._word2idx = word2idx - vocab._idx2word = idx2word + if isinstance(filepath, io.IOBase): + assert filepath.writable() + f = filepath + elif isinstance(filepath, str): + try: + f = open(filepath, 'r', encoding='utf-8') + except Exception as e: + raise e + else: + raise TypeError("Illegal `filepath`.") + vocab = Vocabulary() + for line in f: + line = line.strip() + if line: + name, value = line.split() + if name in ('max_size', 'min_freq'): + value = int(value) if value!='None' else None + setattr(vocab, name, value) + elif name in ('unknown', 'padding'): + value = value if value!='None' else None + setattr(vocab, name, value) + elif name == 'rebuild': + vocab.rebuild = True if value=='True' else False + else: + break + word_counter = {} + no_create_entry_counter = {} + word2idx = {} + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3]) + if idx >= 0: + word2idx[word] = idx + word_counter[word] = count + if no_create_entry: + no_create_entry_counter[word] = count + + word_counter = Counter(word_counter) + no_create_entry_counter = Counter(no_create_entry_counter) + if len(word2idx)>0: + if vocab.padding: + word2idx[vocab.padding] = 0 + if vocab.unknown: + word2idx[vocab.unknown] = 1 if vocab.padding else 0 + idx2word = {value:key for key,value in word2idx.items()} + + vocab.word_count = word_counter + vocab._no_create_word = no_create_entry_counter + if word2idx: + vocab._word2idx = word2idx + vocab._idx2word = idx2word + if isinstance(filepath, str): # 如果是file的话就关闭 + f.close() return vocab diff --git a/fastNLP/embeddings/__init__.py b/fastNLP/embeddings/__init__.py index bf35b7d4..1b3b6e83 100644 --- a/fastNLP/embeddings/__init__.py +++ b/fastNLP/embeddings/__init__.py @@ -22,8 +22,9 @@ __all__ = [ "StackEmbedding", "LSTMCharEmbedding", "CNNCharEmbedding", - "get_embeddings", + "get_embeddings", + "get_sinusoid_encoding_table" ] from .embedding import Embedding, TokenEmbedding @@ -34,7 +35,7 @@ from .roberta_embedding import RobertaEmbedding, RobertaWordPieceEncoder from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding from .stack_embedding import StackEmbedding -from .utils import get_embeddings +from .utils import get_embeddings, get_sinusoid_encoding_table import sys from ..doc_utils import doc_process diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index 69bc3108..6ec5df99 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -8,11 +8,11 @@ __all__ = [ "BertWordPieceEncoder" ] -import collections +import os import warnings from itertools import chain from functools import partial - +import json import numpy as np import torch from torch import nn @@ -24,6 +24,13 @@ from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR from ..modules.encoder.bert import BertModel from ..modules.tokenizer import BertTokenizer +# TODO 需要重新修改,使得encoder可以直接读取embedding的权重 +VOCAB_NAME = 'vocab.txt' +BERT_EMBED_HYPER = 'bert_hyper.json' +BERT_EMBED_FOLDER = 'bert' +BERT_ENCODER_HYPER = 'bert_hyper.json' +BERT_ENCODER_FOLDER = 'bert' + class BertEmbedding(ContextualEmbedding): r""" @@ -82,10 +89,7 @@ class BertEmbedding(ContextualEmbedding): word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] 来进行分类的任务将auto_truncate置为True。 :param kwargs: - bool only_use_pretrain_bpe: 
仅使用出现在pretrain词表中的bpe,如果该词没法tokenize则使用unk。如果embedding不需要更新
-            建议设置为True。
-            int min_freq: 仅在only_use_pretrain_bpe为False有效,大于等于该次数的词会被新加入BERT的BPE词表中
-            bool truncate_embed: 是否仅保留用到的bpe(这样会减内存占用和加快速度)
+            int min_freq: 小于该次数的词会被unk代替
         """
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
@@ -106,14 +110,11 @@ class BertEmbedding(ContextualEmbedding):
         if '[CLS]' in vocab:
             self._word_cls_index = vocab['[CLS]']
 
-        only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
-        truncate_embed = kwargs.get('truncate_embed', True)
         min_freq = kwargs.get('min_freq', 2)
-
+        self._min_freq = min_freq
         self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
-                                    only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed)
+                                    pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate)
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -160,6 +161,57 @@ class BertEmbedding(ContextualEmbedding):
             words = words.masked_fill(mask, self._word_unk_index)
         return words
 
+    def save(self, folder):
+        """
+        将embedding保存到folder这个目录下,将会保存三个内容vocab.txt, bert_hyper.json, bert/, 其中bert下包含
+        config.json,pytorch_model.bin,vocab.txt三个文件(该folder下的数据也可以直接被BERTModel读取)
+
+        :param str folder:
+        :return:
+        """
+        os.makedirs(folder, exist_ok=True)
+
+        self.get_word_vocab().save(os.path.join(folder, VOCAB_NAME))
+
+        hyper = {}
+        hyper['min_freq'] = self._min_freq
+        hyper['layers'] = ','.join(map(str, self.model.layers))
+        hyper['pool_method'] = self.model.pool_method
+        hyper['dropout'] = self.dropout_layer.p
+        hyper['word_dropout'] = self.word_dropout
+        hyper['include_cls_sep'] = self.model.include_cls_sep
+        hyper['pooled_cls'] = self.model.pooled_cls
+        hyper['auto_truncate'] = self.model.auto_truncate
+        hyper['requires_grad'] = bool(self.requires_grad)
+
+        with open(os.path.join(folder, BERT_EMBED_HYPER), 'w', encoding='utf-8') as f:
+            json.dump(hyper, f, indent=2)
+
+        os.makedirs(os.path.join(folder, BERT_EMBED_FOLDER), exist_ok=True)
+        self.model.save(os.path.join(folder, BERT_EMBED_FOLDER))
+        logger.debug(f"BERTEmbedding has been saved in {folder}")
+
+    @classmethod
+    def load(cls, folder):
+        """
+        给定一个folder, 需要包含以下三个内容vocab.txt, bert_hyper.json, bert/
+
+        :param str folder:
+        :return:
+        """
+        for name in [VOCAB_NAME, BERT_EMBED_FOLDER, BERT_EMBED_HYPER]:
+            assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."
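+        # the three names checked above are exactly what save() writes: the word-level
+        # vocabulary, the hyper-parameter json, and the sub-folder holding the raw
+        # BERT config/weights/vocab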
+ + vocab = Vocabulary.load(os.path.join(folder, VOCAB_NAME)) + + with open(os.path.join(folder, BERT_EMBED_HYPER), 'r', encoding='utf-8') as f: + hyper = json.load(f) + + model_dir_or_name = os.path.join(os.path.join(folder, BERT_EMBED_FOLDER)) + + bert_embed = cls(vocab=vocab, model_dir_or_name=model_dir_or_name, **hyper) + return bert_embed + class BertWordPieceEncoder(nn.Module): r""" @@ -180,7 +232,7 @@ class BertWordPieceEncoder(nn.Module): """ def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, - word_dropout=0, dropout=0, requires_grad: bool = True): + word_dropout=0, dropout=0, requires_grad: bool = True, **kwargs): r""" :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` @@ -270,11 +322,53 @@ class BertWordPieceEncoder(nn.Module): words = words.masked_fill(mask, self._wordpiece_unk_index) return words + def save(self, folder): + """ + 会在folder下创建两个文件bert_encoder_hyper.json与bert_encoder/, bert_encoder下包含三个文件config.json, + pytorch_model.bin,vocab.txt三个文件(该folder下的数据也可以直接被BERTModel读取) + + :param str folder: + :return: + """ + os.makedirs(folder, exist_ok=True) + + hyper = {} + hyper['layers'] = ','.join(map(str, self.model.layers)) + hyper['dropout'] = self.dropout_layer.p + hyper['word_dropout'] = self.word_dropout + hyper['pooled_cls'] = self.model.pooled_cls + hyper['requires_grad'] = bool(self.requires_grad) + + with open(os.path.join(folder, BERT_ENCODER_HYPER), 'w', encoding='utf-8') as f: + json.dump(hyper, f, indent=2) + + os.makedirs(os.path.join(folder, BERT_ENCODER_FOLDER), exist_ok=True) + self.model.save(os.path.join(folder, BERT_ENCODER_FOLDER)) + logger.debug(f"BertWordPieceEncoder has been saved in {folder}") + + @classmethod + def load(cls, folder): + """ + 会在folder下创建两个文件bert_encoder_hyper.json与bert_encoder/, bert_encoder下包含三个文件 + + :param folder: + :return: + """ + for name in [BERT_ENCODER_HYPER, BERT_ENCODER_FOLDER]: + assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." + + with open(os.path.join(folder, BERT_ENCODER_HYPER), 'r', encoding='utf-8') as f: + hyper = json.load(f) + + model_dir_or_name = os.path.join(os.path.join(folder, BERT_ENCODER_FOLDER)) + + bert_encoder = cls(model_dir_or_name=model_dir_or_name, **hyper) + return bert_encoder + class _BertWordModel(nn.Module): def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', - include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2, - only_use_pretrain_bpe=False, truncate_embed=True): + include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2): super().__init__() self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name) @@ -303,73 +397,8 @@ class _BertWordModel(nn.Module): self.auto_truncate = auto_truncate # 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] - logger.info("Start to generate word pieces for word.") self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids - # 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 - word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # 用到的word_piece以及新增的 - new_add_to_bpe_vocab = 0 - - unsegment_count = 0 - if '[sep]' in vocab: - warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.") - if "[CLS]" in vocab: - warnings.warn("[CLS] detected in your vocabulary. 
BertEmbedding will add [CLS] and [SEP] to the begin " - "and end of the input automatically, make sure you don't add [CLS] and [SEP] at the begin" - " and end.") - for word, index in vocab: - if index == vocab.padding_idx: # pad是个特殊的符号 - word = '[PAD]' - elif index == vocab.unknown_idx: - word = '[UNK]' - _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split() - word_pieces = [] - for w in _words: - word_pieces.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w)) - if len(word_pieces) == 1: - if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 - if index != vocab.unknown_idx and word_pieces[0] == '[UNK]': # 说明这个词不在原始的word里面 - if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( - word) and not only_use_pretrain_bpe: # 出现次数大于这个次数才新增 - word_piece_dict[word] = 1 # 新增一个值 - new_add_to_bpe_vocab += 1 - unsegment_count += 1 - - continue - for word_piece in word_pieces: - word_piece_dict[word_piece] = 1 - original_embed = self.encoder.embeddings.word_embeddings.weight.data - - # 特殊词汇要特殊处理 - if not truncate_embed:# 如果不删除的话需要将已有的加上 - word_piece_dict.update(self.tokenzier.vocab) - embed = nn.Embedding(len(word_piece_dict), original_embed.size(1)) # 新的embed - new_word_piece_vocab = collections.OrderedDict() - - for index, token in enumerate(['[PAD]', '[UNK]']): - index = word_piece_dict.pop(token, None) - if index is not None: - new_word_piece_vocab[token] = len(new_word_piece_vocab) - embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.vocab[token]] - for token in word_piece_dict.keys(): - if token not in new_word_piece_vocab: - new_word_piece_vocab[token] = len(new_word_piece_vocab) - index = new_word_piece_vocab[token] - if token in self.tokenzier.vocab: - embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]] - else: - embed.weight.data[index] = original_embed[self.tokenzier.vocab['[UNK]']] - - self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) - self.encoder.embeddings.word_embeddings = embed - - self.encoder.config.vocab_size = len(new_word_piece_vocab) - if unsegment_count>0: - if only_use_pretrain_bpe or new_add_to_bpe_vocab==0: - logger.info(f"{unsegment_count} words are unsegmented.") - else: - logger.info(f"{unsegment_count} words are unsegmented. 
Among them, {new_add_to_bpe_vocab} added to the BPE vocab.") - word_to_wordpieces = [] word_pieces_lengths = [] for word, index in vocab: @@ -377,6 +406,8 @@ class _BertWordModel(nn.Module): word = '[PAD]' elif index == vocab.unknown_idx: word = '[UNK]' + elif vocab.word_count[word] int: + def num_embeddings(self) -> int: r""" 这个值可能会大于实际的embedding矩阵的大小。 :return: @@ -205,7 +205,7 @@ class TokenEmbedding(nn.Module): @property def size(self): - return torch.Size(self.num_embedding, self._embed_size) + return torch.Size(self.num_embeddings, self._embed_size) @abstractmethod def forward(self, words): diff --git a/fastNLP/embeddings/roberta_embedding.py b/fastNLP/embeddings/roberta_embedding.py index 1479a383..8e16c055 100644 --- a/fastNLP/embeddings/roberta_embedding.py +++ b/fastNLP/embeddings/roberta_embedding.py @@ -10,8 +10,8 @@ __all__ = [ from functools import partial -import collections -import warnings +import os +import json from itertools import chain import numpy as np @@ -24,6 +24,13 @@ from ..modules.encoder.roberta import RobertaModel from ..modules.tokenizer import RobertaTokenizer +VOCAB_NAME = 'vocab.txt' +ROBERTA_EMBED_HYPER = 'roberta_hyper.json' +ROBERTA_ENCODER_HYPER = 'roberta_hyper.json' +ROBERTA_EMBED_FOLDER = 'roberta' +ROBERTA_ENCODER_FOLDER = 'roberta' + + class RobertaEmbedding(ContextualEmbedding): r""" 使用RoBERTa对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 @@ -71,10 +78,7 @@ class RobertaEmbedding(ContextualEmbedding): word pieces后的内容,并将第512个word piece置为。超过长度的部分的encode结果直接全部置零。一般仅有只使用 来进行分类的任务将auto_truncate置为True。 :param kwargs: - bool only_use_pretrain_bpe: 仅使用出现在pretrain词表中的bpe,如果该词没法tokenize则使用unk。如果embedding不需要更新 - 建议设置为True。 - int min_freq: 仅在only_use_pretrain_bpe为False有效,大于等于该次数的词会被新加入BERT的BPE词表中 - bool truncate_embed: 是否仅保留用到的bpe(这样会减内存占用和加快速度) + int min_freq: 小于该次数的词会被unk代替 """ super().__init__(vocab, word_dropout=word_dropout, dropout=dropout) @@ -89,14 +93,12 @@ class RobertaEmbedding(ContextualEmbedding): if '' in vocab: self._word_cls_index = vocab[''] - only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) - truncate_embed = kwargs.get('truncate_embed', True) min_freq = kwargs.get('min_freq', 2) + self._min_freq = min_freq self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, pool_method=pool_method, include_cls_sep=include_cls_sep, - pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq, - only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed) + pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq) self.requires_grad = requires_grad self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size @@ -142,11 +144,56 @@ class RobertaEmbedding(ContextualEmbedding): words = words.masked_fill(mask, self._word_unk_index) return words + def save(self, folder): + """ + 将roberta embedding保存到folder,保存之后包含三个文件vocab.txt, roberta_embed_hyper.txt, roberta_embed/, + + :param str folder: 保存地址 + :return: + """ + os.makedirs(folder, exist_ok=True) + self.get_word_vocab().save(os.path.join(folder, VOCAB_NAME)) + + hyper = {} + hyper['min_freq'] = self._min_freq + hyper['layers'] = ','.join(map(str, self.model.layers)) + hyper['pool_method'] = self.model.pool_method + hyper['dropout'] = self.dropout_layer.p + hyper['word_dropout'] = self.word_dropout + hyper['include_cls_sep'] = self.model.include_cls_sep + hyper['pooled_cls'] = self.model.pooled_cls + hyper['auto_truncate'] = self.model.auto_truncate + hyper['requires_grad'] 
= bool(self.requires_grad) + + with open(os.path.join(folder, ROBERTA_EMBED_HYPER), 'w', encoding='utf-8') as f: + json.dump(hyper, f, indent=2) + + os.makedirs(os.path.join(folder, ROBERTA_EMBED_FOLDER), exist_ok=True) + self.model.save(os.path.join(folder, ROBERTA_EMBED_FOLDER)) + + @classmethod + def load(cls, folder): + """ + 从folder中读取数据初始化RobertaEmbedding + + :param folder: + :return: + """ + for name in [VOCAB_NAME, ROBERTA_EMBED_HYPER, ROBERTA_EMBED_FOLDER]: + assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." + + vocab = Vocabulary.load(os.path.join(folder, VOCAB_NAME)) + with open(os.path.join(folder, ROBERTA_EMBED_HYPER), 'r', encoding='utf-8') as f: + hyper = json.load(f) + model_name_or_path = os.path.join(folder, ROBERTA_EMBED_FOLDER) + + roberta = cls(vocab=vocab, model_dir_or_name=model_name_or_path, **hyper) + return roberta + class _RobertaWordModel(nn.Module): def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', - include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2, - only_use_pretrain_bpe=False, truncate_embed=True): + include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2): super().__init__() self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name) @@ -177,72 +224,6 @@ class _RobertaWordModel(nn.Module): self.pooled_cls = pooled_cls self.auto_truncate = auto_truncate - # 将所有vocab中word的wordpiece计算出来, 需要额外考虑 - logger.info("Start to generate word pieces for word.") - # 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 - word_piece_dict = {'': 1, '': 1} # 用到的word_piece以及新增的 - found_count = 0 - new_add_to_bpe_vocab = 0 - unsegment_count = 0 - if "" in vocab: - warnings.warn(" detected in your vocabulary. 
RobertaEmbedding will add and to the begin " - "and end of the input automatically, make sure you don't add and at the begin" - " and end.") - for word, index in vocab: - if index == vocab.padding_idx: # pad是个特殊的符号 - word = '' - elif index == vocab.unknown_idx: - word = '' - # _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split() # 这里暂时不考虑中文内容 - word_pieces = [] - # 如果这个word不是在句子开头 - word_pieces.extend(self.tokenzier.tokenize(word, add_prefix_space=True)) - if len(word_pieces) == 1: - if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 - if index != vocab.unknown_idx and word_pieces[0] == '': # 说明这个词不在原始的word里面 - if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( - word) and not only_use_pretrain_bpe: # 出现次数大于这个次数才新增 - word_piece_dict[word] = 1 # 新增一个值 - new_add_to_bpe_vocab += 1 - unsegment_count += 1 - continue - found_count += 1 - for word_piece in word_pieces: - word_piece_dict[word_piece] = 1 - # 如果这个word是在句子开头 - - original_embed = self.encoder.embeddings.word_embeddings.weight.data - # 特殊词汇要特殊处理 - if not truncate_embed: # 如果不删除的话需要将已有的加上 - word_piece_dict.update(self.tokenzier.encoder) - - embed = nn.Embedding(len(word_piece_dict), original_embed.size(1)) # 新的embed - new_word_piece_vocab = collections.OrderedDict() - - for index, token in enumerate(['', '', '', '']): - index = word_piece_dict.pop(token, None) - if index is not None: - new_word_piece_vocab[token] = len(new_word_piece_vocab) - embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.encoder[token]] - for token in word_piece_dict.keys(): - if token not in new_word_piece_vocab: - new_word_piece_vocab[token] = len(new_word_piece_vocab) - index = new_word_piece_vocab[token] - if token in self.tokenzier.encoder: - embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]] - else: - embed.weight.data[index] = original_embed[self.tokenzier.encoder['']] - - self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) - self.encoder.embeddings.word_embeddings = embed - self.encoder.config.vocab_size = len(new_word_piece_vocab) - - if unsegment_count>0: - if only_use_pretrain_bpe or new_add_to_bpe_vocab==0: - logger.info(f"{unsegment_count} words are unsegmented.") - else: - logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.") - word_to_wordpieces = [] word_pieces_lengths = [] for word, index in vocab: @@ -250,6 +231,8 @@ class _RobertaWordModel(nn.Module): word = '' elif index == vocab.unknown_idx: word = '' + elif vocab.word_count[word] 0: - if model_dir_or_name is not None: - warnings.warn(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with" + if model_dir_or_name: + logger.info(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with" f" dimension {embedding_dim}. 
If you want to use pre-trained embedding, " f"set `embedding_dim` to 0.") model_dir_or_name = None @@ -116,7 +123,9 @@ class StaticEmbedding(TokenEmbedding): model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt') else: raise ValueError(f"Cannot recognize {model_dir_or_name}.") - + + kwargs['min_freq'] = min_freq + kwargs['lower'] = lower # 根据min_freq缩小vocab truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq) if truncate_vocab: @@ -143,7 +152,7 @@ class StaticEmbedding(TokenEmbedding): truncated_words_to_words = torch.arange(len(vocab)).long() for word, index in vocab: truncated_words_to_words[index] = truncated_vocab.to_index(word) - logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") + logger.info(f"{len(vocab) - len(truncated_vocab)} words have frequency less than {min_freq}.") vocab = truncated_vocab self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False) @@ -198,6 +207,7 @@ class StaticEmbedding(TokenEmbedding): sparse=False, _weight=embedding) self._embed_size = self.embedding.weight.size(1) self.requires_grad = requires_grad + self.kwargs = kwargs @property def weight(self): @@ -321,3 +331,71 @@ class StaticEmbedding(TokenEmbedding): words = self.embedding(words) words = self.dropout(words) return words + + def save(self, folder): + """ + 将embedding存储到folder下,之后可以通过使用load方法读取 + + :param str folder: 会在该folder下生成三个文件, vocab.txt, static_embed_hyper.txt, static_embed_hyper.json. + 其中vocab.txt可以用Vocabulary通过load读取; embedding.txt按照word2vec的方式存储,以空格的方式隔开元素, + 第一行只有两个元素,剩下的行首先是word然后是各个维度的值; static_embed_hyper.json是StaticEmbedding的超参数 + :return: + """ + os.makedirs(folder, exist_ok=True) + + vocab = self.get_word_vocab() + vocab_fp = os.path.join(folder, VOCAB_FILENAME) + vocab.save(vocab_fp) + kwargs = self.kwargs.copy() + kwargs['dropout'] = self.dropout_layer.p + kwargs['word_dropout'] = self.word_dropout + kwargs['requires_grad'] = self.requires_grad + kwargs['only_norm_found_vector'] = False + kwargs['only_use_pretrain_word'] = True + + with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'w', encoding='utf-8') as f: + json.dump(kwargs, f, indent=2) + + with open(os.path.join(folder, STATIC_EMBED_FILENAME), 'w', encoding='utf-8') as f: + f.write('{}\n'.format(' '*30)) # 留白之后再来填写 + word_count = 0 + saved_word = {} + valid_word_count = 0 + for i in range(len(self.words_to_words)): + word = vocab.to_word(i) + if not vocab._is_word_no_create_entry(word): + word_count += 1 + if kwargs['lower']: + word = word.lower() + if word in saved_word: + continue + saved_word[word] = 1 + vec_i = self.words_to_words[i] + if vec_i==vocab.unknown_idx and i!=vocab.unknown_idx: + continue + vec = self.embedding.weight.data[vec_i].tolist() + vec_str = ' '.join(map(str, vec)) + f.write(f'{word} {vec_str}\n') + valid_word_count += 1 + f.seek(0) + f.write('{} {}'.format(valid_word_count, self.embedding_dim)) + logger.debug(f"StaticEmbedding has been saved to {folder}.") + + @classmethod + def load(cls, folder): + """ + + :param str folder: 该folder下应该有以下三个文件vocab.txt, static_embed.txt, static_hyper.json + :return: + """ + for name in [VOCAB_FILENAME, STATIC_EMBED_FILENAME, STATIC_HYPER_FILENAME]: + assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." 
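+        # expected layout is what save() produces: vocab.txt (the Vocabulary),
+        # static_hyper.json (the constructor kwargs) and static_embed.txt in word2vec
+        # text format -- a "num_words dim" header line, then one "word v1 v2 ... v_dim"
+        # line per word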
+ + vocab = Vocabulary.load(os.path.join(folder, VOCAB_FILENAME)) + with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'r', encoding='utf-8') as f: + hyper = json.load(f) + + logger.info(f"Load StaticEmbedding from {folder}.") + embed = cls(vocab=vocab, model_dir_or_name=os.path.join(folder, STATIC_EMBED_FILENAME), **hyper) + return embed + diff --git a/fastNLP/embeddings/utils.py b/fastNLP/embeddings/utils.py index 7f6ba3b1..cec015e0 100644 --- a/fastNLP/embeddings/utils.py +++ b/fastNLP/embeddings/utils.py @@ -9,7 +9,8 @@ from torch import nn as nn from ..core.vocabulary import Vocabulary __all__ = [ - 'get_embeddings' + 'get_embeddings', + 'get_sinusoid_encoding_table' ] @@ -31,7 +32,7 @@ def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, inclu return char_vocab -def get_embeddings(init_embed): +def get_embeddings(init_embed, padding_idx=None): r""" 根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray 的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理 @@ -40,11 +41,12 @@ def get_embeddings(init_embed): :param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始化; 传入torch.Tensor, 将使用传入的值作为Embedding初始化。 + :param padding_idx: 当传入tuple时,padding_idx有效 :return nn.Embedding: embeddings """ if isinstance(init_embed, tuple): res = nn.Embedding( - num_embeddings=init_embed[0], embedding_dim=init_embed[1]) + num_embeddings=init_embed[0], embedding_dim=init_embed[1], padding_idx=padding_idx) nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)), b=np.sqrt(3 / res.weight.data.size(1))) elif isinstance(init_embed, nn.Module): @@ -58,3 +60,32 @@ def get_embeddings(init_embed): raise TypeError( 'invalid init_embed type: {}'.format((type(init_embed)))) return res + + +def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): + """ + sinusoid的embedding,其中position的表示中,偶数维(0,2,4,...)是sin, 奇数(1,3,5...)是cos + + :param int n_position: 一共多少个position + :param int d_hid: 多少维度,需要为偶数 + :param padding_idx: + :return: torch.FloatTensor, shape为n_position x d_hid + """ + + def cal_angle(position, hid_idx): + return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) + + def get_posi_angle_vec(position): + return [cal_angle(position, hid_j) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) + + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + if padding_idx is not None: + # zero vector for padding dimension + sinusoid_table[padding_idx] = 0. 
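+    # the table above implements PE(pos, 2i) = sin(pos / 10000^(2i/d_hid)) and
+    # PE(pos, 2i+1) = cos(pos / 10000^(2i/d_hid)); zeroing row padding_idx keeps
+    # padding positions free of positional signal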
+ + return torch.FloatTensor(sinusoid_table) + diff --git a/fastNLP/models/__init__.py b/fastNLP/models/__init__.py index 0c0be9cf..f00687b3 100644 --- a/fastNLP/models/__init__.py +++ b/fastNLP/models/__init__.py @@ -9,18 +9,18 @@ fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models """ __all__ = [ "CNNText", - + "SeqLabeling", "AdvSeqLabel", "BiLSTMCRF", - + "ESIM", - + "StarTransEnc", "STSeqLabel", "STNLICls", "STSeqCls", - + "BiaffineParser", "GraphParser", @@ -28,7 +28,13 @@ __all__ = [ "BertForSentenceMatching", "BertForMultipleChoice", "BertForTokenClassification", - "BertForQuestionAnswering" + "BertForQuestionAnswering", + + "TransformerSeq2SeqModel", + "LSTMSeq2SeqModel", + "Seq2SeqModel", + + 'SequenceGeneratorModel' ] from .base_model import BaseModel @@ -39,7 +45,9 @@ from .cnn_text_classification import CNNText from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF from .snli import ESIM from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel - +from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, Seq2SeqModel +from .seq2seq_generator import SequenceGeneratorModel import sys from ..doc_utils import doc_process -doc_process(sys.modules[__name__]) \ No newline at end of file + +doc_process(sys.modules[__name__]) diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py index 8a51aede..5851f8c8 100644 --- a/fastNLP/models/bert.py +++ b/fastNLP/models/bert.py @@ -39,7 +39,7 @@ from torch import nn from .base_model import BaseModel from ..core._logger import logger from ..core.const import Const -from ..embeddings import BertEmbedding +from ..embeddings.bert_embedding import BertEmbedding class BertForSequenceClassification(BaseModel): diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index 343b46ae..dff4809c 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -314,13 +314,8 @@ class BiaffineParser(GraphParser): raise ValueError('unsupported rnn_out_size: {} for transformer'.format(rnn_out_size)) self.position_emb = nn.Embedding(num_embeddings=self.max_len, embedding_dim=rnn_out_size, ) - self.encoder = TransformerEncoder(num_layers=rnn_layers, - model_size=rnn_out_size, - inner_size=1024, - key_size=d_k, - value_size=d_v, - num_head=n_head, - dropout=dropout, ) + self.encoder = TransformerEncoder( num_layers=rnn_layers, d_model=rnn_out_size, + n_head=n_head, dim_ff=1024, dropout=dropout) else: raise ValueError('unsupported encoder type: {}'.format(encoder)) diff --git a/fastNLP/models/seq2seq_generator.py b/fastNLP/models/seq2seq_generator.py new file mode 100644 index 00000000..aa270b5f --- /dev/null +++ b/fastNLP/models/seq2seq_generator.py @@ -0,0 +1,62 @@ +import torch +from torch import nn +from .seq2seq_model import Seq2SeqModel +from ..modules.generator.seq2seq_generator import SequenceGenerator + + +class SequenceGeneratorModel(nn.Module): + """ + 用于封装Seq2SeqModel使其可以做生成任务 + + """ + + def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, num_beams=1, + do_sample=True, temperature=1.0, top_k=50, top_p=1.0, + repetition_penalty=1, length_penalty=1.0, pad_token_id=0): + """ + + :param Seq2SeqModel seq2seq_model: 序列到序列模型 + :param int,None bos_token_id: 句子开头的token id + :param int,None eos_token_id: 句子结束的token id + :param int max_length: 句子的最大长度 + :param int num_beams: beam search的大小 + :param bool do_sample: 是否通过采样的方式生成 + :param float temperature: 只有在do_sample为True才有意义 + :param int top_k: 只从top_k中采样 + 
:param float top_p: 只从top_p的token中采样,即nucleus sampling
+        :param float repetition_penalty: 多大程度上惩罚重复的token
+        :param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短句
+        :param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充
+        """
+        super().__init__()
+        self.seq2seq_model = seq2seq_model
+        self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, num_beams=num_beams,
+                                           do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p,
+                                           bos_token_id=bos_token_id,
+                                           eos_token_id=eos_token_id,
+                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty,
+                                           pad_token_id=pad_token_id)
+
+    def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
+        """
+        透传调用seq2seq_model的forward
+
+        :param torch.LongTensor src_tokens: bsz x max_len
+        :param torch.LongTensor tgt_tokens: bsz x max_len'
+        :param torch.LongTensor src_seq_len: bsz
+        :param torch.LongTensor tgt_seq_len: bsz
+        :return:
+        """
+        return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)
+
+    def predict(self, src_tokens, src_seq_len=None):
+        """
+        给定source的内容,输出generate的内容
+
+        :param torch.LongTensor src_tokens: bsz x max_len
+        :param torch.LongTensor src_seq_len: bsz
+        :return:
+        """
+        state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len)
+        result = self.generator.generate(state)
+        return {'pred': result}
diff --git a/fastNLP/models/seq2seq_model.py b/fastNLP/models/seq2seq_model.py
new file mode 100644
index 00000000..ce867c0b
--- /dev/null
+++ b/fastNLP/models/seq2seq_model.py
@@ -0,0 +1,176 @@
+r"""
+主要包含组成Sequence-to-Sequence的model
+
+"""
+
+import torch
+from torch import nn
+
+from ..embeddings import get_embeddings
+from ..embeddings.utils import get_sinusoid_encoding_table
+from ..modules.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder
+from ..modules.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
+
+
+class Seq2SeqModel(nn.Module):
+    def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder):
+        """
+        可以用于在Trainer中训练的Seq2Seq模型。正常情况下,继承了该函数之后,只需要实现classmethod build_model即可。
+
+        :param encoder: Encoder
+        :param decoder: Decoder
+        """
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+
+    def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
+        """
+
+        :param torch.LongTensor src_tokens: source的token
+        :param torch.LongTensor tgt_tokens: target的token
+        :param torch.LongTensor src_seq_len: src的长度
+        :param torch.LongTensor tgt_seq_len: target的长度,默认用不上
+        :return: {'pred': torch.Tensor}, 其中pred的shape为bsz x max_len x vocab_size
+        """
+        state = self.prepare_state(src_tokens, src_seq_len)
+        decoder_output = self.decoder(tgt_tokens, state)
+        if isinstance(decoder_output, torch.Tensor):
+            return {'pred': decoder_output}
+        elif isinstance(decoder_output, (tuple, list)):
+            return {'pred': decoder_output[0]}
+        else:
+            raise TypeError(f"Unsupported return type from Decoder:{type(self.decoder)}")
+
+    def prepare_state(self, src_tokens, src_seq_len=None):
+        """
+        调用encoder获取state,会把encoder的encoder_output, encoder_mask直接传入到decoder.init_state中初始化一个state
+
+        :param src_tokens:
+        :param src_seq_len:
+        :return:
+        """
+        encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len)
+        state = self.decoder.init_state(encoder_output, encoder_mask)
+        return state
+
+    @classmethod
+    def build_model(cls, *args, **kwargs):
+        """
+        需要实现本方法来进行Seq2SeqModel的初始化
+
+        :return:
+        """
+        raise NotImplementedError
+
+
+class
TransformerSeq2SeqModel(Seq2SeqModel): + """ + Encoder为TransformerSeq2SeqEncoder, decoder为TransformerSeq2SeqDecoder,通过build_model方法初始化 + + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def build_model(cls, src_embed, tgt_embed=None, + pos_embed='sin', max_position=1024, num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1, + bind_encoder_decoder_embed=False, + bind_decoder_input_output_embed=True): + """ + 初始化一个TransformerSeq2SeqModel + + :param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: source的embedding + :param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: target的embedding,如果bind_encoder_decoder_embed为 + True,则不要输入该值 + :param str pos_embed: 支持sin, learned两种 + :param int max_position: 最大支持长度 + :param int num_layers: encoder和decoder的层数 + :param int d_model: encoder和decoder输入输出的大小 + :param int n_head: encoder和decoder的head的数量 + :param int dim_ff: encoder和decoder中FFN中间映射的维度 + :param float dropout: Attention和FFN dropout的大小 + :param bool bind_encoder_decoder_embed: 是否对encoder和decoder使用相同的embedding + :param bool bind_decoder_input_output_embed: decoder的输出embedding是否与其输入embedding是一样的权重 + :return: TransformerSeq2SeqModel + """ + if bind_encoder_decoder_embed and tgt_embed is not None: + raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") + + src_embed = get_embeddings(src_embed) + + if bind_encoder_decoder_embed: + tgt_embed = src_embed + else: + assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" + tgt_embed = get_embeddings(tgt_embed) + + if pos_embed == 'sin': + encoder_pos_embed = nn.Embedding.from_pretrained( + get_sinusoid_encoding_table(max_position + 1, src_embed.embedding_dim, padding_idx=0), + freeze=True) # 这里规定0是padding + deocder_pos_embed = nn.Embedding.from_pretrained( + get_sinusoid_encoding_table(max_position + 1, tgt_embed.embedding_dim, padding_idx=0), + freeze=True) # 这里规定0是padding + elif pos_embed == 'learned': + encoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=0) + deocder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=1) + else: + raise ValueError("pos_embed only supports sin or learned.") + + encoder = TransformerSeq2SeqEncoder(embed=src_embed, pos_embed=encoder_pos_embed, + num_layers=num_layers, d_model=d_model, n_head=n_head, dim_ff=dim_ff, + dropout=dropout) + decoder = TransformerSeq2SeqDecoder(embed=tgt_embed, pos_embed=deocder_pos_embed, + d_model=d_model, num_layers=num_layers, n_head=n_head, dim_ff=dim_ff, + dropout=dropout, + bind_decoder_input_output_embed=bind_decoder_input_output_embed) + + return cls(encoder, decoder) + + +class LSTMSeq2SeqModel(Seq2SeqModel): + """ + 使用LSTMSeq2SeqEncoder和LSTMSeq2SeqDecoder的model + + """ + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def build_model(cls, src_embed, tgt_embed=None, + num_layers = 3, hidden_size = 400, dropout = 0.3, bidirectional=True, + attention=True, bind_encoder_decoder_embed=False, + bind_decoder_input_output_embed=True): + """ + + :param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: source的embedding + :param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: target的embedding,如果bind_encoder_decoder_embed为 + True,则不要输入该值 + :param int num_layers: Encoder和Decoder的层数 + :param int hidden_size: encoder和decoder的隐藏层大小 + :param float dropout: 每层之间的Dropout的大小 + :param bool bidirectional: 
encoder是否使用双向LSTM + :param bool attention: decoder是否使用attention attend encoder在所有时刻的状态 + :param bool bind_encoder_decoder_embed: 是否对encoder和decoder使用相同的embedding + :param bool bind_decoder_input_output_embed: decoder的输出embedding是否与其输入embedding是一样的权重 + :return: LSTMSeq2SeqModel + """ + if bind_encoder_decoder_embed and tgt_embed is not None: + raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") + + src_embed = get_embeddings(src_embed) + + if bind_encoder_decoder_embed: + tgt_embed = src_embed + else: + assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" + tgt_embed = get_embeddings(tgt_embed) + + encoder = LSTMSeq2SeqEncoder(embed=src_embed, num_layers = num_layers, + hidden_size = hidden_size, dropout = dropout, bidirectional=bidirectional) + decoder = LSTMSeq2SeqDecoder(embed=tgt_embed, num_layers = num_layers, hidden_size = hidden_size, + dropout = dropout, bind_decoder_input_output_embed = bind_decoder_input_output_embed, + attention=attention) + return cls(encoder, decoder) diff --git a/fastNLP/models/sequence_labeling.py b/fastNLP/models/sequence_labeling.py index fa3037a3..de7943c0 100644 --- a/fastNLP/models/sequence_labeling.py +++ b/fastNLP/models/sequence_labeling.py @@ -14,9 +14,9 @@ import torch.nn.functional as F from .base_model import BaseModel from ..core.const import Const as C from ..core.utils import seq_len_to_mask -from ..embeddings import get_embeddings -from ..modules import ConditionalRandomField -from ..modules import LSTM +from ..embeddings.utils import get_embeddings +from ..modules.decoder import ConditionalRandomField +from ..modules.encoder import LSTM from ..modules import decoder, encoder from ..modules.decoder.crf import allowed_transitions diff --git a/fastNLP/modules/__init__.py b/fastNLP/modules/__init__.py index e334eb29..77144660 100644 --- a/fastNLP/modules/__init__.py +++ b/fastNLP/modules/__init__.py @@ -58,7 +58,21 @@ __all__ = [ "RobertaModel", "GPT2Model", - "GPT2Tokenizer" + "GPT2Tokenizer", + + "TransformerSeq2SeqEncoder", + "LSTMSeq2SeqEncoder", + "Seq2SeqEncoder", + + "TransformerSeq2SeqDecoder", + "LSTMSeq2SeqDecoder", + "Seq2SeqDecoder", + + "TransformerState", + "LSTMState", + "State", + + "SequenceGenerator" ] import sys @@ -68,6 +82,7 @@ from . 
import encoder from .decoder import * from .dropout import TimestepDropout from .encoder import * +from .generator import * from .utils import summary from ..doc_utils import doc_process from .tokenizer import * diff --git a/fastNLP/modules/encoder/attention.py b/fastNLP/modules/attention.py similarity index 53% rename from fastNLP/modules/encoder/attention.py rename to fastNLP/modules/attention.py index 2f94ea07..85810670 100644 --- a/fastNLP/modules/encoder/attention.py +++ b/fastNLP/modules/attention.py @@ -12,7 +12,8 @@ import torch import torch.nn.functional as F from torch import nn -from fastNLP.modules.utils import initial_parameter +from .utils import initial_parameter +from .decoder.seq2seq_state import TransformerState class DotAttention(nn.Module): @@ -45,64 +46,153 @@ class DotAttention(nn.Module): class MultiHeadAttention(nn.Module): - r""" - Transformer当中的MultiHeadAttention """ + Attention is all you need中提到的多头注意力 - def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): - r""" - - :param input_size: int, 输入维度的大小。同时也是输出维度的大小。 - :param key_size: int, 每个head的维度大小。 - :param value_size: int,每个head中value的维度。 - :param num_head: int,head的数量。 - :param dropout: float。 - """ + """ + def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): super(MultiHeadAttention, self).__init__() - self.input_size = input_size - self.key_size = key_size - self.value_size = value_size - self.num_head = num_head - - in_size = key_size * num_head - self.q_in = nn.Linear(input_size, in_size) - self.k_in = nn.Linear(input_size, in_size) - self.v_in = nn.Linear(input_size, in_size) - self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout) - self.out = nn.Linear(value_size * num_head, input_size) + self.d_model = d_model + self.n_head = n_head + self.dropout = dropout + self.head_dim = d_model // n_head + self.layer_idx = layer_idx + assert d_model % n_head == 0, "d_model should be divisible by n_head" + self.scaling = self.head_dim ** -0.5 + + self.q_proj = nn.Linear(d_model, d_model) + self.k_proj = nn.Linear(d_model, d_model) + self.v_proj = nn.Linear(d_model, d_model) + self.out_proj = nn.Linear(d_model, d_model) + self.reset_parameters() + def forward(self, query, key, value, key_mask=None, attn_mask=None, state=None): + """ + + :param query: batch x seq x dim + :param key: batch x seq x dim + :param value: batch x seq x dim + :param key_mask: batch x seq 用于指示哪些key不要attend到;注意到mask为1的地方是要attend到的 + :param attn_mask: seq x seq, 用于mask掉attention map。 主要是用在训练时decoder端的self attention,下三角为1 + :param state: 过去的信息,在inference的时候会用到,比如encoder output、decoder的prev kv。这样可以减少计算。 + :return: + """ + assert key.size() == value.size() + if state is not None: + assert self.layer_idx is not None + qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() + + q = self.q_proj(query) # batch x seq x dim + q *= self.scaling + k = v = None + prev_k = prev_v = None + + # 从state中取kv + if isinstance(state, TransformerState): # 说明此时在inference阶段 + if qkv_same: # 此时在decoder self attention + prev_k = state.decoder_prev_key[self.layer_idx] + prev_v = state.decoder_prev_value[self.layer_idx] + else: # 此时在decoder-encoder attention,直接将保存下来的key装载起来即可 + k = state.encoder_key[self.layer_idx] + v = state.encoder_value[self.layer_idx] + + if k is None: + k = self.k_proj(key) + v = self.v_proj(value) + + if prev_k is not None: + k = torch.cat((prev_k, k), dim=1) + v = torch.cat((prev_v, v), dim=1) + + # 更新state + if isinstance(state, 
TransformerState):
+            if qkv_same:
+                state.decoder_prev_key[self.layer_idx] = k
+                state.decoder_prev_value[self.layer_idx] = v
+            else:
+                state.encoder_key[self.layer_idx] = k
+                state.encoder_value[self.layer_idx] = v
+
+        # 开始计算attention
+        batch_size, q_len, d_model = query.size()
+        k_len, v_len = k.size(1), v.size(1)
+        q = q.reshape(batch_size, q_len, self.n_head, self.head_dim)
+        k = k.reshape(batch_size, k_len, self.n_head, self.head_dim)
+        v = v.reshape(batch_size, v_len, self.n_head, self.head_dim)
+
+        attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k)  # bs,q_len,k_len,n_head
+        if key_mask is not None:
+            _key_mask = ~key_mask[:, None, :, None].bool()  # batch,1,k_len,1
+            attn_weights = attn_weights.masked_fill(_key_mask, -float('inf'))
+
+        if attn_mask is not None:
+            _attn_mask = attn_mask[None, :, :, None].eq(0)  # 1,q_len,k_len,n_head
+            attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf'))
+
+        attn_weights = F.softmax(attn_weights, dim=2)
+        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v)  # batch,q_len,n_head,head_dim
+        output = output.reshape(batch_size, q_len, -1)
+        output = self.out_proj(output)  # batch,q_len,dim
+
+        return output, attn_weights
+
     def reset_parameters(self):
-        sqrt = math.sqrt
-        nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
-        nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
-        nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
-        nn.init.normal_(self.out.weight, mean=0, std=sqrt(1.0 / self.input_size))
+        nn.init.xavier_uniform_(self.q_proj.weight)
+        nn.init.xavier_uniform_(self.k_proj.weight)
+        nn.init.xavier_uniform_(self.v_proj.weight)
+        nn.init.xavier_uniform_(self.out_proj.weight)
+
+    def set_layer_idx(self, layer_idx):
+        self.layer_idx = layer_idx
 
-    def forward(self, Q, K, V, atte_mask_out=None):
-        r"""
-        :param Q: [batch, seq_len_q, model_size]
-        :param K: [batch, seq_len_k, model_size]
-        :param V: [batch, seq_len_k, model_size]
-        :param seq_mask: [batch, seq_len]
+class AttentionLayer(nn.Module):
+    def __init__(self, input_size, key_dim, value_dim, bias=False):
         """
-        batch, sq, _ = Q.size()
-        sk = K.size(1)
-        d_k, d_v, n_head = self.key_size, self.value_size, self.num_head
-        # input linear
-        q = self.q_in(Q).view(batch, sq, n_head, d_k).transpose(1, 2)
-        k = self.k_in(K).view(batch, sk, n_head, d_k).transpose(1, 2)
-        v = self.v_in(V).view(batch, sk, n_head, d_v).transpose(1, 2)
-
-        if atte_mask_out is not None:
-            atte_mask_out = atte_mask_out[:,None,:,:]  # [bsz,1,1,len]
-        atte = self.attention(q, k, v, atte_mask_out).view(batch, n_head, sq, d_v)
-
-        # concat all heads, do output linear
-        atte = atte.transpose(1, 2).contiguous().view(batch, sq, -1)
-        output = self.out(atte)
-        return output
+        可用于LSTM2LSTM的序列到序列模型的decode过程中,该attention是在decode过程中根据上一个step的hidden计算对encoder结果的attention
+
+        :param int input_size: 输入的大小
+        :param int key_dim: 一般就是encoder_output输出的维度
+        :param int value_dim: 输出的大小维度, 一般就是decoder hidden的大小
+        :param bias:
+        """
+        super().__init__()
+
+        self.input_proj = nn.Linear(input_size, key_dim, bias=bias)
+        self.output_proj = nn.Linear(input_size + key_dim, value_dim, bias=bias)
+
+    def forward(self, input, encode_outputs, encode_mask):
+        """
+
+        :param input: batch_size x input_size
+        :param encode_outputs: batch_size x max_len x key_dim
+        :param encode_mask: batch_size x max_len, 为0的地方为padding
+        :return: hidden: batch_size x value_dim, scores: batch_size x max_len,
normalized过的 + """ + + # x: bsz x encode_hidden_size + x = self.input_proj(input) + + # compute attention + attn_scores = torch.matmul(encode_outputs, x.unsqueeze(-1)).squeeze(-1) # b x max_len + + # don't attend over padding + if encode_mask is not None: + attn_scores = attn_scores.float().masked_fill_( + encode_mask.eq(0), + float('-inf') + ).type_as(attn_scores) # FP16 support: cast to float and back + + attn_scores = F.softmax(attn_scores, dim=-1) # srclen x bsz + + # sum weighted sources + x = torch.matmul(attn_scores.unsqueeze(1), encode_outputs).squeeze(1) # b x encode_hidden_size + + x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) + return x, attn_scores def _masked_softmax(tensor, mask): diff --git a/fastNLP/modules/decoder/__init__.py b/fastNLP/modules/decoder/__init__.py index 5dae0052..93099be0 100644 --- a/fastNLP/modules/decoder/__init__.py +++ b/fastNLP/modules/decoder/__init__.py @@ -6,10 +6,20 @@ __all__ = [ "MLP", "ConditionalRandomField", "viterbi_decode", - "allowed_transitions" + "allowed_transitions", + + "LSTMState", + "TransformerState", + "State", + + "TransformerSeq2SeqDecoder", + "LSTMSeq2SeqDecoder", + "Seq2SeqDecoder" ] from .crf import ConditionalRandomField from .crf import allowed_transitions from .mlp import MLP from .utils import viterbi_decode +from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder +from .seq2seq_state import State, LSTMState, TransformerState diff --git a/fastNLP/modules/decoder/seq2seq_decoder.py b/fastNLP/modules/decoder/seq2seq_decoder.py old mode 100755 new mode 100644 index 3933867a..987679b3 --- a/fastNLP/modules/decoder/seq2seq_decoder.py +++ b/fastNLP/modules/decoder/seq2seq_decoder.py @@ -1,109 +1,413 @@ -# coding=utf-8 -__all__ = [ - "TransformerPast", - "Past", - "Decoder" -] + +from typing import Union, Tuple +import math + import torch from torch import nn -import abc import torch.nn.functional as F +from ..attention import AttentionLayer, MultiHeadAttention from ...embeddings import StaticEmbedding -import numpy as np -from typing import Union, Tuple from ...embeddings.utils import get_embeddings -from torch.nn import LayerNorm -import math +from .seq2seq_state import State, LSTMState, TransformerState + +class Seq2SeqDecoder(nn.Module): + """ + Sequence-to-Sequence Decoder的基类。一定需要实现forward函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象 + 用来承载该Decoder所需要的Encoder输出、Decoder需要记录的历史信息(例如LSTM的hidden信息)。 -class Past: + """ def __init__(self): - pass + super().__init__() + + def forward(self, tokens, state, **kwargs): + """ - @abc.abstractmethod - def num_samples(self): - pass + :param torch.LongTensor tokens: bsz x max_len + :param State state: state包含了encoder的输出以及decode之前的内容 + :return: 返回值可以为bsz x max_len x vocab_size的Tensor,也可以是一个list,但是第一个元素必须是词的预测分布 + """ + raise NotImplemented - @abc.abstractmethod - def reorder_past(self, indices: torch.LongTensor): + def reorder_states(self, indices, states): """ - 根据indices中的index,将past的中状态置为正确的顺序。inplace改变 + 根据indices重新排列states中的状态,在beam search进行生成时,会用到该函数。 :param torch.LongTensor indices: - :param Past past: + :param State states: :return: """ - raise NotImplemented + assert isinstance(states, State), f"`states` should be of type State instead of {type(states)}" + states.reorder_state(indices) + def init_state(self, encoder_output, encoder_mask): + """ + 初始化一个state对象,用来记录了encoder的输出以及decode已经完成的部分。 + + :param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch + 维度 + :param Union[torch.Tensor, 
list, tuple] encoder_mask: 如果部位None,内部元素需要torch.Tensor, 默认其中第一维是batch + 维度 + :param kwargs: + :return: State, 返回一个State对象,记录了encoder的输出 + """ + state = State(encoder_output, encoder_mask) + return state -class TransformerPast(Past): - def __init__(self, encoder_outputs: torch.Tensor = None, encoder_mask: torch.Tensor = None, - num_decoder_layer: int = 6): + def decode(self, tokens, state): """ + 根据states中的内容,以及tokens中的内容进行之后的生成。 - :param encoder_outputs: (batch,src_seq_len,dim) - :param encoder_mask: (batch,src_seq_len) - :param encoder_key: list of (batch, src_seq_len, dim) - :param encoder_value: - :param decoder_prev_key: - :param decoder_prev_value: + :param torch.LongTensor tokens: bsz x max_len, 上一个时刻的token输出。 + :param State state: 记录了encoder输出与decoder过去状态 + :return: torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布 """ + outputs = self(state=state, tokens=tokens) + if isinstance(outputs, torch.Tensor): + return outputs[:, -1] + else: + raise RuntimeError("Unrecognized output from the `forward()` function. Please override the `decode()` function.") + + +class TiedEmbedding(nn.Module): + """ + 用于将weight和原始weight绑定 + + """ + def __init__(self, weight): super().__init__() - self.encoder_outputs = encoder_outputs - self.encoder_mask = encoder_mask - self.encoder_key = [None] * num_decoder_layer - self.encoder_value = [None] * num_decoder_layer - self.decoder_prev_key = [None] * num_decoder_layer - self.decoder_prev_value = [None] * num_decoder_layer - - def num_samples(self): - if self.encoder_outputs is not None: - return self.encoder_outputs.size(0) - return None - - def _reorder_state(self, state, indices): - if type(state) == torch.Tensor: - state = state.index_select(index=indices, dim=0) - elif type(state) == list: - for i in range(len(state)): - assert state[i] is not None - state[i] = state[i].index_select(index=indices, dim=0) + self.weight = weight # vocab_size x embed_size + + def forward(self, x): + """ + + :param torch.FloatTensor x: bsz x * x embed_size + :return: torch.FloatTensor bsz x * x vocab_size + """ + return torch.matmul(x, self.weight.t()) + + +def get_binded_decoder_output_embed(embed): + """ + 给定一个embedding,输出对应的绑定的embedding,输出对象为TiedEmbedding + + :param embed: + :return: + """ + if isinstance(embed, StaticEmbedding): + for idx, map2idx in enumerate(embed.words_to_words): + assert idx == map2idx, "Invalid StaticEmbedding for Decoder, please check:(1) whether the vocabulary " \ + "include `no_create_entry=True` word; (2) StaticEmbedding should not initialize with " \ + "`lower=True` or `min_freq!=1`." + elif not isinstance(embed, nn.Embedding): + raise TypeError("Only nn.Embedding or StaticEmbedding is allowed for binding.") + + return TiedEmbedding(embed.weight) + + +class LSTMSeq2SeqDecoder(Seq2SeqDecoder): + def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, hidden_size = 300, + dropout = 0.3, bind_decoder_input_output_embed = True, attention=True): + """ + LSTM的Decoder + + :param nn.Module,tuple embed: decoder输入的embedding. + :param int num_layers: 多少层LSTM + :param int hidden_size: 隐藏层大小, 该值也被认为是encoder的输出维度大小 + :param dropout: Dropout的大小 + :param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, + 则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. 
+ :param bool attention: 是否使用attention + """ + super().__init__() + self.embed = get_embeddings(init_embed=embed) + self.embed_dim = self.embed.embedding_dim + + if bind_decoder_input_output_embed: + self.output_layer = get_binded_decoder_output_embed(self.embed) + else: # 不需要bind + self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) + self.output_layer = TiedEmbedding(self.output_embed.weight) + + self.hidden_size = hidden_size + self.num_layers = num_layers + self.lstm = nn.LSTM(input_size=self.embed_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers, + batch_first=True, bidirectional=False, dropout=dropout if num_layers>1 else 0) + + self.attention_layer = AttentionLayer(hidden_size, hidden_size, hidden_size) if attention else None + self.output_proj = nn.Linear(hidden_size, self.embed_dim) + self.dropout_layer = nn.Dropout(dropout) + + def forward(self, tokens, state, return_attention=False): + """ + + :param torch.LongTensor tokens: batch x max_len + :param LSTMState state: 保存encoder输出和decode状态的State对象 + :param bool return_attention: 是否返回attention的score + :return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length + """ + src_output = state.encoder_output + encoder_mask = state.encoder_mask + + assert tokens.size(1)>state.decode_length, "The state does not match the tokens." + tokens = tokens[:, state.decode_length:] + x = self.embed(tokens) + + attn_weights = [] if self.attention_layer is not None else None # 保存attention weight, batch,tgt_seq,src_seq + input_feed = state.input_feed + decoder_out = [] + + cur_hidden = state.hidden + cur_cell = state.cell + + # 开始计算 + for i in range(tokens.size(1)): + input = torch.cat( + (x[:, i:i + 1, :], + input_feed[:, None, :] + ), + dim=2 + ) # batch,1,2*dim + _, (cur_hidden, cur_cell) = self.lstm(input, hx=(cur_hidden, cur_cell)) # hidden/cell保持原来的size + if self.attention_layer is not None: + input_feed, attn_weight = self.attention_layer(cur_hidden[-1], src_output, encoder_mask) + attn_weights.append(attn_weight) + else: + input_feed = cur_hidden[-1] + + state.input_feed = input_feed # batch, hidden + state.hidden = cur_hidden + state.cell = cur_cell + state.decode_length += 1 + decoder_out.append(input_feed) + + decoder_out = torch.stack(decoder_out, dim=1) # batch,seq_len,hidden + decoder_out = self.dropout_layer(decoder_out) + if attn_weights is not None: + attn_weights = torch.cat(attn_weights, dim=1) # batch, tgt_len, src_len + + decoder_out = self.output_proj(decoder_out) + feats = self.output_layer(decoder_out) + + if return_attention: + return feats, attn_weights + return feats + + def init_state(self, encoder_output, encoder_mask) -> LSTMState: + """ + + :param encoder_output: 输入可以有两种情况(1) 输入为一个tuple,包含三个内容(encoder_output, (hidden, cell)),其中encoder_output: + bsz x max_len x hidden_size, hidden: bsz x hidden_size, cell:bsz x hidden_size,一般使用LSTMEncoder的最后一层的 + hidden state和cell state来赋值这两个值 + (2) 只有encoder_output: bsz x max_len x hidden_size, 这种情况下hidden和cell使用0初始化 + :param torch.ByteTensor encoder_mask: bsz x max_len, 为0的位置是padding, 用来指示source中哪些不需要attend + :return: + """ + if not isinstance(encoder_output, torch.Tensor): + encoder_output, (hidden, cell) = encoder_output else: - raise ValueError('State does not support other format') + hidden = cell = None + assert encoder_output.ndim==3 + assert encoder_mask.size()==encoder_output.size()[:2] + assert encoder_output.size(-1)==self.hidden_size, "The dimension of encoder outputs should be the same with " 
\ + "the hidden_size." + + t = [hidden, cell] + for idx in range(2): + v = t[idx] + if v is None: + v = encoder_output.new_zeros(self.num_layers, encoder_output.size(0), self.hidden_size) + else: + assert v.dim()==2 + assert v.size(-1)==self.hidden_size + v = v[None].repeat(self.num_layers, 1, 1) # num_layers x bsz x hidden_size + t[idx] = v + + state = LSTMState(encoder_output, encoder_mask, t[0], t[1]) return state - def reorder_past(self, indices: torch.LongTensor): - self.encoder_outputs = self._reorder_state(self.encoder_outputs, indices) - self.encoder_mask = self._reorder_state(self.encoder_mask, indices) - self.encoder_key = self._reorder_state(self.encoder_key, indices) - self.encoder_value = self._reorder_state(self.encoder_value, indices) - self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) - self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) - return self +class TransformerSeq2SeqDecoderLayer(nn.Module): + def __init__(self, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1, layer_idx = None): + """ -class Decoder(nn.Module): - def __init__(self): + :param int d_model: 输入、输出的维度 + :param int n_head: 多少个head,需要能被d_model整除 + :param int dim_ff: + :param float dropout: + :param int layer_idx: layer的编号 + """ super().__init__() + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + self.layer_idx = layer_idx # 记录layer的层索引,以方便获取state的信息 + + self.self_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) + self.self_attn_layer_norm = nn.LayerNorm(d_model) + + self.encoder_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) + self.encoder_attn_layer_norm = nn.LayerNorm(d_model) + + self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(self.dim_ff, self.d_model), + nn.Dropout(dropout)) + + self.final_layer_norm = nn.LayerNorm(self.d_model) - @abc.abstractmethod - def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]: + def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state=None): """ - 当模型进行解码时,使用这个函数。返回一个batch_size x vocab_size的结果与更新的Past状态。需要考虑一种特殊情况,即tokens长度不是1,即给定了 - 解码句子开头的情况,这种情况需要查看Past中是否正确计算了decode的状态。 - :return: tensor:batch_size x vocab_size, past: Past + :param x: (batch, seq_len, dim), decoder端的输入 + :param encoder_output: (batch,src_seq_len,dim) + :param encoder_mask: batch,src_seq_len + :param self_attn_mask: seq_len, seq_len,下三角的mask矩阵,只在训练时传入 + :param TransformerState state: 只在inference阶段传入 + :return: """ - raise NotImplemented - @abc.abstractmethod - def reorder_past(self, indices: torch.LongTensor, past: Past): + # self attention part + residual = x + x = self.self_attn_layer_norm(x) + x, _ = self.self_attn(query=x, + key=x, + value=x, + attn_mask=self_attn_mask, + state=state) + + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + # encoder attention part + residual = x + x = self.encoder_attn_layer_norm(x) + x, attn_weight = self.encoder_attn(query=x, + key=encoder_output, + value=encoder_output, + key_mask=encoder_mask, + state=state) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + # ffn + residual = x + x = self.final_layer_norm(x) + x = self.ffn(x) + x = residual + x + + return x, attn_weight + + +class TransformerSeq2SeqDecoder(Seq2SeqDecoder): + def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None, + d_model = 512, num_layers=6, n_head 
= 8, dim_ff = 2048, dropout = 0.1, + bind_decoder_input_output_embed = True): """ - 根据indices中的index,将past的中状态置为正确的顺序。inplace改变 - :param torch.LongTensor indices: - :param Past past: - :return: + :param embed: 输入token的embedding + :param nn.Module pos_embed: 位置embedding + :param int d_model: 输入、输出的大小 + :param int num_layers: 多少层 + :param int n_head: 多少个head + :param int dim_ff: FFN 的中间大小 + :param float dropout: Self-Attention和FFN中的dropout的大小 + :param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, + 则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. + """ + super().__init__() + + self.embed = get_embeddings(embed) + self.pos_embed = pos_embed + + if bind_decoder_input_output_embed: + self.output_layer = get_binded_decoder_output_embed(self.embed) + else: # 不需要bind + self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) + self.output_layer = TiedEmbedding(self.output_embed.weight) + + self.num_layers = num_layers + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + + self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) + self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx) + for layer_idx in range(num_layers)]) + + self.embed_scale = math.sqrt(d_model) + self.layer_norm = nn.LayerNorm(d_model) + self.output_fc = nn.Linear(self.d_model, self.embed.embedding_dim) + + def forward(self, tokens, state, return_attention=False): """ - raise NotImplemented \ No newline at end of file + + :param torch.LongTensor tokens: batch x tgt_len,decode的词 + :param TransformerState state: 用于记录encoder的输出以及decode状态的对象,可以通过init_state()获取 + :param bool return_attention: 是否返回对encoder结果的attention score + :return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length + """ + + encoder_output = state.encoder_output + encoder_mask = state.encoder_mask + + assert state.decode_length<tokens.size(1), "The decoded tokens in State should be less than tokens." + tokens = tokens[:, state.decode_length:] + device = tokens.device + + x = self.embed_scale * self.embed(tokens) + if self.pos_embed is not None: + position = torch.arange(state.decode_length, state.decode_length+tokens.size(1)).long().to(device)[None] + x += self.pos_embed(position) + x = self.input_fc(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + batch_size, max_tgt_len = tokens.size() + if max_tgt_len>1: + triangle_mask = self._get_triangle_mask(tokens) + else: + triangle_mask = None + + for layer in self.layer_stacks: + x, attn_weight = layer(x=x, + encoder_output=encoder_output, + encoder_mask=encoder_mask, + self_attn_mask=triangle_mask, + state=state + ) + + x = self.layer_norm(x) # batch, tgt_len, dim + x = self.output_fc(x) + feats = self.output_layer(x) + + if return_attention: + return feats, attn_weight + return feats + + def init_state(self, encoder_output, encoder_mask): + """ + 初始化一个TransformerState用于forward + + :param torch.FloatTensor encoder_output: bsz x max_len x d_model, encoder的输出 + :param torch.ByteTensor encoder_mask: bsz x max_len, 为1的位置需要attend。 + :return: TransformerState + """ + if isinstance(encoder_output, torch.Tensor): + encoder_output = encoder_output + elif isinstance(encoder_output, (list, tuple)): + encoder_output = encoder_output[0] # 防止是LSTMEncoder的输出结果 + else: + raise TypeError("Unsupported `encoder_output` for TransformerSeq2SeqDecoder") + state = TransformerState(encoder_output, encoder_mask, num_decoder_layer=self.num_layers) + return state + + @staticmethod + def _get_triangle_mask(tokens): + tensor = tokens.new_ones(tokens.size(1), tokens.size(1)) + return torch.tril(tensor).byte() + + diff --git a/fastNLP/modules/decoder/seq2seq_state.py b/fastNLP/modules/decoder/seq2seq_state.py new file mode 100644 index 00000000..de200f86 --- /dev/null +++ b/fastNLP/modules/decoder/seq2seq_state.py @@ -0,0 +1,145 @@ +r"""
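`_get_triangle_mask` above produces the usual lower-triangular causal mask that keeps position i from attending to positions > i during training. A standalone sketch of how such a mask gates attention scores (illustrative shapes, not the library's exact code):

```python
import torch

def get_triangle_mask(tokens: torch.LongTensor) -> torch.Tensor:
    # tokens: (bsz, seq_len) -> (seq_len, seq_len); 1 means "may attend"
    seq_len = tokens.size(1)
    return torch.tril(tokens.new_ones(seq_len, seq_len)).byte()

tokens = torch.zeros(2, 4, dtype=torch.long)
mask = get_triangle_mask(tokens)             # (4, 4), lower triangle of ones
scores = torch.randn(2, 4, 4)                # (bsz, tgt_len, tgt_len) attention logits
scores = scores.masked_fill(mask.eq(0), float('-inf'))  # hide future positions
attn = torch.softmax(scores, dim=-1)
# row i of attn only puts weight on positions <= i
```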
+每个Decoder都有对应的State用来记录encoder的输出以及Decode的历史记录 + +""" + +__all__ = [ + 'State', + "LSTMState", + "TransformerState" +] + +from typing import Union +import torch + + +class State: + def __init__(self, encoder_output=None, encoder_mask=None, **kwargs): + """ + 每个Decoder都有对应的State对象用来承载encoder的输出以及当前时刻之前的decode状态。 + + :param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch + 维度 + :param Union[torch.Tensor, list, tuple] encoder_mask: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch + 维度 + :param kwargs: + """ + self.encoder_output = encoder_output + self.encoder_mask = encoder_mask + self._decode_length = 0 + + @property + def num_samples(self): + """ + 返回的State中包含的是多少个sample的encoder状态,主要用于Generate的时候确定batch的大小。 + + :return: + """ + if self.encoder_output is not None: + return self.encoder_output.size(0) + else: + return None + + @property + def decode_length(self): + """ + 当前Decode到哪个token了,decoder只会从decode_length之后的token开始decode, 为0说明还没开始decode。 + + :return: + """ + return self._decode_length + + @decode_length.setter + def decode_length(self, value): + self._decode_length = value + + def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0): + if isinstance(state, torch.Tensor): + state = state.index_select(index=indices, dim=dim) + elif isinstance(state, list): + for i in range(len(state)): + assert state[i] is not None + state[i] = self._reorder_state(state[i], indices, dim) + elif isinstance(state, tuple): + tmp_list = [] + for i in range(len(state)): + assert state[i] is not None + tmp_list.append(self._reorder_state(state[i], indices, dim)) + state = tuple(tmp_list) + else: + raise TypeError(f"Cannot reorder data of type:{type(state)}") + + return state + + def reorder_state(self, indices: torch.LongTensor): + if self.encoder_mask is not None: + self.encoder_mask = self._reorder_state(self.encoder_mask, indices) + if self.encoder_output is not None: + self.encoder_output = self._reorder_state(self.encoder_output, indices) + + +class LSTMState(State): + def __init__(self, encoder_output, encoder_mask, hidden, cell): + """ + LSTMDecoder对应的State,保存encoder的输出以及LSTM解码过程中的一些中间状态 + + :param torch.FloatTensor encoder_output: bsz x src_seq_len x encode_output_size,encoder的输出 + :param torch.BoolTensor encoder_mask: bsz x src_seq_len, 为0的地方是padding + :param torch.FloatTensor hidden: num_layers x bsz x hidden_size, 上个时刻的hidden状态 + :param torch.FloatTensor cell: num_layers x bsz x hidden_size, 上个时刻的cell状态 + """ + super().__init__(encoder_output, encoder_mask) + self.hidden = hidden + self.cell = cell + self._input_feed = hidden[0] # 默认是上一个时刻的输出 + + @property + def input_feed(self): + """ + LSTMDecoder中每个时刻的输入会把上个token的embedding和input_feed拼接起来输入到下个时刻,在LSTMDecoder不使用attention时, + input_feed即上个时刻的hidden state, 否则是attention layer的输出。 + :return: torch.FloatTensor, bsz x hidden_size + """ + return self._input_feed + + @input_feed.setter + def input_feed(self, value): + self._input_feed = value + + def reorder_state(self, indices: torch.LongTensor): + super().reorder_state(indices) + self.hidden = self._reorder_state(self.hidden, indices, dim=1) + self.cell = self._reorder_state(self.cell, indices, dim=1) + if self.input_feed is not None: + self.input_feed = self._reorder_state(self.input_feed, indices, dim=0) + + +class TransformerState(State): + def __init__(self, encoder_output, encoder_mask, num_decoder_layer): + """ + 与TransformerSeq2SeqDecoder对应的State。 + + :param torch.FloatTensor encoder_output: bsz x
encode_max_len x encoder_output_size, encoder的输出 + :param torch.ByteTensor encoder_mask: bsz x encode_max_len, 为1的地方需要attend + :param int num_decoder_layer: decode有多少层 + """ + super().__init__(encoder_output, encoder_mask) + self.encoder_key = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x key_dim + self.encoder_value = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x value_dim + self.decoder_prev_key = [None] * num_decoder_layer # 每一个元素 bsz x decode_length x key_dim + self.decoder_prev_value = [None] * num_decoder_layer # 每一个元素 bsz x decode_length x value_dim + + def reorder_state(self, indices: torch.LongTensor): + super().reorder_state(indices) + self.encoder_key = self._reorder_state(self.encoder_key, indices) + self.encoder_value = self._reorder_state(self.encoder_value, indices) + self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) + self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) + + @property + def decode_length(self): + if self.decoder_prev_key[0] is not None: + return self.decoder_prev_key[0].size(1) + return 0 + + diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py index 71acde33..f9a637a7 100644 --- a/fastNLP/modules/encoder/__init__.py +++ b/fastNLP/modules/encoder/__init__.py @@ -4,8 +4,6 @@ r""" """ __all__ = [ - # "BertModel", - "ConvolutionCharEncoder", "LSTMCharEncoder", @@ -35,10 +33,14 @@ __all__ = [ "RobertaModel", - "GPT2Model" + "GPT2Model", + + "LSTMSeq2SeqEncoder", + "TransformerSeq2SeqEncoder", + "Seq2SeqEncoder" ] -from .attention import MultiHeadAttention, BiAttention, SelfAttention +from fastNLP.modules.attention import MultiHeadAttention, BiAttention, SelfAttention from .bert import BertModel from .roberta import RobertaModel from .gpt2 import GPT2Model @@ -49,3 +51,4 @@ from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPoo from .star_transformer import StarTransformer from .transformer import TransformerEncoder from .variational_rnn import VarRNN, VarLSTM, VarGRU +from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index 62c18d48..7a9ba57e 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -10,6 +10,7 @@ __all__ = [ import copy import json import math +import os import torch from torch import nn @@ -20,7 +21,8 @@ from ...io.file_utils import _get_bert_dir from ...core import logger -CONFIG_FILE = 'bert_config.json' +CONFIG_FILE = 'config.json' +WEIGHTS_NAME = 'pytorch_model.bin' BERT_KEY_RENAME_MAP_1 = { 'gamma': 'weight', @@ -57,7 +59,8 @@ class BertConfig(object): max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, - layer_norm_eps=1e-12): + layer_norm_eps=1e-12, + architectures='bert'): r"""Constructs BertConfig.
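TransformerState above is the incremental-decoding cache: per layer it stores the projected encoder key/value plus the decoder's already-computed key/value, and decode_length is read off the cached key length. A compact sketch of both mechanics, cache growth and beam reordering via index_select (shapes illustrative, not the library's exact code):

```python
import torch

bsz, dim = 2, 8
prev_key = None  # grows to (bsz, decoded_len, dim) during generation

for step in range(3):
    new_key = torch.randn(bsz, 1, dim)  # key of the single freshly decoded token
    prev_key = new_key if prev_key is None else torch.cat([prev_key, new_key], dim=1)
    decode_length = prev_key.size(1)    # same rule as TransformerState.decode_length

# beam search expands and reorders samples; cached states must follow along:
indices = torch.tensor([0, 0, 1, 1])    # expand each sample to 2 beams
prev_key = prev_key.index_select(dim=0, index=indices)
assert prev_key.shape == (4, 3, dim)
```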
Args: @@ -101,6 +104,7 @@ class BertConfig(object): self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps + self.architectures = architectures else: raise ValueError("First argument must be either a vocabulary size (int)" "or the path to a pretrained model config file (str)") @@ -134,9 +138,13 @@ class BertConfig(object): def to_json_file(self, json_file_path): r""" Save this instance to a json file.""" + if os.path.isdir(json_file_path): + json_file_path = os.path.join(json_file_path, CONFIG_FILE) with open(json_file_path, "w", encoding='utf-8') as writer: writer.write(self.to_json_string()) + def save_pretrained(self, save_directory): + self.to_json_file(save_directory) def gelu(x): return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) @@ -149,21 +157,6 @@ def swish(x): ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} -# class BertLayerNorm(nn.Module): -# def __init__(self, hidden_size, eps=1e-12): -# r"""Construct a layernorm module in the TF style (epsilon inside the square root). -# """ -# super(BertLayerNorm, self).__init__() -# self.weight = nn.Parameter(torch.ones(hidden_size)) -# self.bias = nn.Parameter(torch.zeros(hidden_size)) -# self.variance_epsilon = eps -# -# def forward(self, x): -# u = x.mean(-1, keepdim=True) -# s = (x - u).pow(2).mean(-1, keepdim=True) -# x = (x - u) / torch.sqrt(s + self.variance_epsilon) -# return self.weight * x + self.bias - BertLayerNorm = torch.nn.LayerNorm @@ -613,3 +606,24 @@ class BertModel(nn.Module): logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.") return model + + def save_pretrained(self, save_directory): + """ 保存模型到某个folder + """ + assert os.path.isdir( + save_directory + ), "Saving path should be a directory where the model and configuration can be saved" + + # Only save the model itself if we are using distributed training + model_to_save = self.module if hasattr(self, "module") else self + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # Save configuration file + model_to_save.config.save_pretrained(save_directory) + + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(save_directory, WEIGHTS_NAME) + torch.save(model_to_save.state_dict(), output_model_file) + logger.debug("Model weights saved in {}".format(output_model_file)) diff --git a/fastNLP/modules/encoder/gpt2.py b/fastNLP/modules/encoder/gpt2.py index 1bb1bc12..e534fa5c 100644 --- a/fastNLP/modules/encoder/gpt2.py +++ b/fastNLP/modules/encoder/gpt2.py @@ -15,9 +15,8 @@ import math from torch.nn import CrossEntropyLoss from fastNLP.io.file_utils import _get_file_name_base_on_postfix -from ..decoder.seq2seq_decoder import Decoder, Past +from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State from ..generator.seq2seq_generator import SequenceGenerator -from typing import Tuple GELU_CONSTANT = math.sqrt(2 / math.pi) @@ -732,7 +731,7 @@ class GPT2PreTrainedModel(nn.Module): bos_token_id=bos_token_id, eos_token_id=eos_token_ids, repetition_penalty=repetition_penalty, length_penalty=length_penalty, pad_token_id=pad_token_id) - results = generator.generate(input_ids, past=None) + results = generator.generate(tokens=input_ids, state=GPT2State()) return results @@ -788,21 +787,13 @@ class GPT2Model(GPT2PreTrainedModel): for layer, heads in heads_to_prune.items(): self.h[layer].attn.prune_heads(heads) - def forward( - self, - 
input_ids, - past=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - output_attentions=True - ): + def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None, + head_mask=None, output_attentions=True): """ :param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1 - :param GPT2Past past: 之前的状态 - :param torch.ByteTensor attention_mask: batch_size x (pre_len+past_len), 与input_ids与past的concat一样大。 + :param GPT2State state: 之前的状态 + :param torch.ByteTensor attention_mask: batch_size x (pre_len+past_len), 与input_ids与state的concat一样大。 为0的地方为padding。 :param torch.LongTensor token_type_ids: batch_size x max_len。 :param torch.LongTensor position_ids: 与input_ids对应的位置 @@ -818,11 +809,11 @@ class GPT2Model(GPT2PreTrainedModel): if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) - if past is None or len(past)==0: + if state is None or len(state)==0: past_length = 0 - past = [None] * len(self.h) # len(self.h) 是layer的层数 + state = [None] * len(self.h) # len(self.h) 是layer的层数 else: - past_length = past[0][0].size(-2) + past_length = state[0][0].size(-2) if position_ids is None: # 如果没有position id则生成 device = input_ids.device position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) @@ -880,7 +871,7 @@ class GPT2Model(GPT2PreTrainedModel): presents = () all_attentions = [] all_hidden_states = () - for i, (block, layer_past) in enumerate(zip(self.h, past)): + for i, (block, layer_past) in enumerate(zip(self.h, state)): all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) outputs = block( @@ -915,56 +906,63 @@ class GPT2Model(GPT2PreTrainedModel): return outputs # last hidden state, (presents), (all hidden_states), (attentions) -class GPT2Past(Past): +class GPT2State(State): def __init__(self): - super().__init__() - self.past = None # tuple [n_layer, 2 x batch_size x n_head x past_len x head_dim] + super().__init__(None, None) + self.state = None # tuple [n_layer, 2 x batch_size x n_head x past_len x head_dim] + @property def num_samples(self): - if self.past is not None: - return self.past[0].size(1) + if self.state is not None: + return self.state[0].size(1) return None - def reorder_past(self, indices): - for i in range(len(self.past)): - assert self.past[i] is not None - self.past[i] = self.past[i].index_select(index=indices, dim=1) + @property + def decode_length(self): + if self.state is None: + return 0 + return self.state[0].size(-2) + + def reorder_state(self, indices): + if self.state: + for i in range(len(self.state)): + assert self.state[i] is not None + self.state[i] = self.state[i].index_select(index=indices, dim=1) def __iter__(self): - for p in self.past: + for p in self.state: yield p def __getitem__(self, item): assert isinstance(item, int) - return self.past[item] + return self.state[item] def __len__(self): - if self.past is not None: - return len(self.past) + if self.state is not None: + return len(self.state) return 0 -class _GPT2Decoder(Decoder): +class _GPT2Decoder(Seq2SeqDecoder): + """ + 用于wrap GPT2,使得其可以在SequenceGenerator中使用 + """ + def __init__(self, gpt_model): super().__init__() self.gpt_model = gpt_model - def decode(self, tokens, past=None) -> Tuple[torch.Tensor, Past]: - if past is None: - past = GPT2Past() - lm_logits, presents, _ = self.gpt_model(input_ids=tokens, - past=past, + def decode(self, tokens, state=None) -> torch.Tensor: + if state is None: + state = 
GPT2State() + lm_logits, presents, _ = self.gpt_model(input_ids=tokens[:, state.decode_length:], + state=state, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, output_attentions=False) - past.past = list(presents) - return lm_logits[:, -1], past - - def reorder_past(self, indices: torch.LongTensor, past: GPT2Past) -> GPT2Past: - past.reorder_past(indices) - return past + state.state = list(presents) + return lm_logits[:, -1] class GPT2LMHeadModel(GPT2PreTrainedModel): @@ -1008,21 +1006,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): def get_input_embeddings(self): return self.transformer.wte - def forward( - self, - input_ids, - past=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - labels=None, - output_attentions=False - ): + def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None, + head_mask=None, labels=None, output_attentions=False): """ :param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1 - :param tuple past: num_layers x 2 x batch_size x n_head x max_len' x head_dim. 可以将前一个时刻的presents作为输入 + :param tuple state: num_layers x 2 x batch_size x n_head x max_len' x head_dim. 可以将前一个时刻的presents作为输入 :param torch.ByteTensor attention_mask: batch_size x max_len, 与input_ids一样大。为0的地方为padding。 :param torch.LongTensor token_type_ids: batch_size x max_len。 :param torch.LongTensor position_ids: 与input_ids对应的位置 @@ -1034,7 +1023,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): """ transformer_outputs = self.transformer( input_ids, - past=past, + state=state, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/fastNLP/modules/encoder/seq2seq_encoder.py b/fastNLP/modules/encoder/seq2seq_encoder.py new file mode 100644 index 00000000..c38fa896 --- /dev/null +++ b/fastNLP/modules/encoder/seq2seq_encoder.py @@ -0,0 +1,189 @@ +import torch.nn as nn +import torch +from torch.nn import LayerNorm +import torch.nn.functional as F +from typing import Union, Tuple +from ...core.utils import seq_len_to_mask +import math +from ...modules.encoder.lstm import LSTM +from fastNLP.modules.attention import MultiHeadAttention +from ...embeddings import StaticEmbedding +from ...embeddings.utils import get_embeddings + + +class Seq2SeqEncoder(nn.Module): + """ + 所有Sequence2Sequence Encoder的基类。需要实现forward函数 + + """ + def __init__(self): + super().__init__() + + def forward(self, tokens, seq_len): + """ + + :param torch.LongTensor tokens: bsz x max_len, encoder的输入 + :param torch.LongTensor seq_len: bsz + :return: + """ + raise NotImplementedError + + +class TransformerSeq2SeqEncoderLayer(nn.Module): + def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, + dropout: float = 0.1): + """ + Self-Attention的Layer, + + :param int d_model: input和output的输出维度 + :param int n_head: 多少个head,每个head的维度为d_model/n_head + :param int dim_ff: FFN的维度大小 + :param float dropout: Self-attention和FFN的dropout大小,0表示不drop + """ + super(TransformerSeq2SeqEncoderLayer, self).__init__() + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + + self.self_attn = MultiHeadAttention(d_model, n_head, dropout) + self.attn_layer_norm = LayerNorm(d_model) + self.ffn_layer_norm = LayerNorm(d_model) + + self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(self.dim_ff, self.d_model), + nn.Dropout(dropout)) + + def forward(self, x, mask): + """ + + 
:param x: batch x src_seq x d_model + :param mask: batch x src_seq,为0的地方为padding + :return: + """ + # attention + residual = x + x = self.attn_layer_norm(x) + x, _ = self.self_attn(query=x, + key=x, + value=x, + key_mask=mask) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + # ffn + residual = x + x = self.ffn_layer_norm(x) + x = self.ffn(x) + x = residual + x + + return x + + +class TransformerSeq2SeqEncoder(Seq2SeqEncoder): + def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed = None, + num_layers = 6, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1): + """ + 基于Transformer的Encoder + + :param embed: encoder输入token的embedding + :param nn.Module pos_embed: position embedding + :param int num_layers: 多少层的encoder + :param int d_model: 输入输出的维度 + :param int n_head: 多少个head + :param int dim_ff: FFN中间的维度大小 + :param float dropout: Attention和FFN的dropout大小 + """ + super(TransformerSeq2SeqEncoder, self).__init__() + self.embed = get_embeddings(embed) + self.embed_scale = math.sqrt(d_model) + self.pos_embed = pos_embed + self.num_layers = num_layers + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + + self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) + self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout) + for _ in range(num_layers)]) + self.layer_norm = LayerNorm(d_model) + + def forward(self, tokens, seq_len): + """ + + :param tokens: batch x max_len + :param seq_len: [batch] + :return: bsz x max_len x d_model, bsz x max_len(为0的地方为padding) + """ + x = self.embed(tokens) * self.embed_scale # batch, seq, dim + batch_size, max_src_len, _ = x.size() + device = x.device + if self.pos_embed is not None: + position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device) + x += self.pos_embed(position) + + x = self.input_fc(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + encoder_mask = seq_len_to_mask(seq_len) + encoder_mask = encoder_mask.to(device) + + for layer in self.layer_stacks: + x = layer(x, encoder_mask) + + x = self.layer_norm(x) + + return x, encoder_mask + + +class LSTMSeq2SeqEncoder(Seq2SeqEncoder): + def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, + hidden_size = 400, dropout = 0.3, bidirectional=True): + """ + LSTM的Encoder + + :param embed: encoder的token embed + :param int num_layers: 多少层 + :param int hidden_size: LSTM隐藏层、输出的大小 + :param float dropout: LSTM层之间的Dropout是多少 + :param bool bidirectional: 是否使用双向 + """ + super().__init__() + self.embed = get_embeddings(embed) + self.num_layers = num_layers + self.dropout = dropout + self.hidden_size = hidden_size + self.bidirectional = bidirectional + hidden_size = hidden_size//2 if bidirectional else hidden_size + self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional, + batch_first=True, dropout=dropout if num_layers>1 else 0, num_layers=num_layers) + + def forward(self, tokens, seq_len): + """ + + :param torch.LongTensor tokens: bsz x max_len + :param torch.LongTensor seq_len: bsz + :return: (output, (hidden, cell)), encoder_mask + output: bsz x max_len x hidden_size, + hidden,cell: batch_size x hidden_size, 最后一层的隐藏状态或cell状态 + encoder_mask: bsz x max_len, 为0的地方是padding + """ + x = self.embed(tokens) + device = x.device + x, (final_hidden, final_cell) = self.lstm(x, seq_len) + encoder_mask = seq_len_to_mask(seq_len).to(device) + + # x: 
batch,seq_len,dim; h/c: num_layers*2,batch,dim + + if self.bidirectional: + final_hidden = self.concat_bidir(final_hidden) # 将双向的hidden state拼接起来,用于接下来的decoder的input + final_cell = self.concat_bidir(final_cell) + + return (x, (final_hidden[-1], final_cell[-1])), encoder_mask # 为了配合Seq2SeqBaseModel的forward,这边需要分为两个return + + def concat_bidir(self, input): + output = input.view(self.num_layers, 2, input.size(1), -1).transpose(1, 2) + return output.reshape(self.num_layers, input.size(1), -1) diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py index fe8d94cd..3597c1be 100644 --- a/fastNLP/modules/encoder/transformer.py +++ b/fastNLP/modules/encoder/transformer.py @@ -5,7 +5,7 @@ __all__ = [ ] from torch import nn -from .attention import MultiHeadAttention +from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer class TransformerEncoder(nn.Module): @@ -13,66 +13,30 @@ class TransformerEncoder(nn.Module): transformer的encoder模块,不包含embedding层 """ + def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): + """ - class SubLayer(nn.Module): - def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): - super(TransformerEncoder.SubLayer, self).__init__() - self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout) - self.norm1 = nn.LayerNorm(model_size, eps=1e-6) - self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), - nn.ReLU(), - nn.Dropout(dropout), - nn.Linear(inner_size, model_size)) - self.norm2 = nn.LayerNorm(model_size, eps=1e-6) - self.dropout = nn.Dropout(dropout) - - def forward(self, input, seq_mask=None, atte_mask_out=None): - r""" - - :param input: [batch, seq_len, model_size] - :param seq_mask: [batch, seq_len] - :return: [batch, seq_len, model_size] - """ - if seq_mask is None: # 防止后续乘法时出错 - seq_mask = 1 - input = self.norm1(input) - attention = self.atte(input, input, input, atte_mask_out) - input = input + self.dropout(attention) - attention *= seq_mask - input = self.norm2(input) - output = self.ffn(input) - input = input + self.dropout(output) - input *= seq_mask - return input - - def __init__(self, num_layers, **kargs): - r""" - - :param int num_layers: transformer的层数 - :param int model_size: 输入维度的大小。同时也是输出维度的大小。 - :param int inner_size: FFN层的hidden大小 - :param int key_size: 每个head的维度大小。 - :param int value_size: 每个head中value的维度。 - :param int num_head: head的数量。 - :param float dropout: dropout概率. Default: 0.1 + :param int num_layers: 多少层Transformer + :param int d_model: input和output的大小 + :param int n_head: 多少个head + :param int dim_ff: FFN中间hidden大小 + :param float dropout: 多大概率drop attention和ffn中间的表示 """ super(TransformerEncoder, self).__init__() - self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) - self.norm = nn.LayerNorm(kargs['model_size'], eps=1e-6) + self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, + dropout = dropout) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(d_model, eps=1e-6) def forward(self, x, seq_mask=None): r""" :param x: [batch, seq_len, model_size] 输入序列 - :param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. + :param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. 
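Both encoders above derive the padding mask from seq_len via seq_len_to_mask; its effect is a single broadcast comparison (re-implemented here for illustration, True on real tokens):

```python
import torch

def seq_len_to_mask(seq_len: torch.LongTensor, max_len: int = None) -> torch.Tensor:
    # (bsz,) lengths -> (bsz, max_len) bool mask, True where a real token sits
    max_len = int(seq_len.max()) if max_len is None else max_len
    return torch.arange(max_len, device=seq_len.device)[None, :] < seq_len[:, None]

mask = seq_len_to_mask(torch.tensor([3, 1]))
# tensor([[ True,  True,  True],
#         [ True, False, False]])
```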
为1的地方需要attend Default: ``None`` :return: [batch, seq_len, model_size] 输出序列 """ output = x if seq_mask is None: - atte_mask_out = None - else: - atte_mask_out = (seq_mask.eq(False))[:, None, :] - seq_mask = seq_mask[:, :, None] + seq_mask = x.new_ones(x.size(0), x.size(1)).bool() for layer in self.layers: - output = layer(output, seq_mask, atte_mask_out) + output = layer(output, seq_mask) return self.norm(output) diff --git a/fastNLP/modules/generator/__init__.py b/fastNLP/modules/generator/__init__.py index e69de29b..512a95d7 100644 --- a/fastNLP/modules/generator/__init__.py +++ b/fastNLP/modules/generator/__init__.py @@ -0,0 +1,9 @@ +r""" + +""" + +__all__ = [ + "SequenceGenerator" +] + +from .seq2seq_generator import SequenceGenerator \ No newline at end of file diff --git a/fastNLP/modules/generator/seq2seq_generator.py b/fastNLP/modules/generator/seq2seq_generator.py old mode 100755 new mode 100644 index e09de85a..e6115407 --- a/fastNLP/modules/generator/seq2seq_generator.py +++ b/fastNLP/modules/generator/seq2seq_generator.py @@ -7,16 +7,35 @@ __all__ = [ ] import torch -from ..decoder.seq2seq_decoder import Decoder +from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State import torch.nn.functional as F -from fastNLP.core.utils import _get_model_device +from ...core.utils import _get_model_device from functools import partial class SequenceGenerator: - def __init__(self, decoder: Decoder, max_length=20, num_beams=1, + """ + 给定一个Seq2SeqDecoder,decode出句子 + + """ + def __init__(self, decoder: Seq2SeqDecoder, max_length=20, num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, repetition_penalty=1, length_penalty=1.0, pad_token_id=0): + """ + + :param Seq2SeqDecoder decoder: Decoder对象 + :param int max_length: 句子的最大长度 + :param int num_beams: beam search的大小 + :param bool do_sample: 是否通过采样的方式生成 + :param float temperature: 只有在do_sample为True才有意义 + :param int top_k: 只从top_k中采样 + :param float top_p: 只从top_p的token中采样,即nucleus sampling + :param int,None bos_token_id: 句子开头的token id + :param int,None eos_token_id: 句子结束的token id + :param float repetition_penalty: 多大程度上惩罚重复的token + :param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短句 + :param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充 + """ if do_sample: self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, @@ -40,19 +59,19 @@ class SequenceGenerator: self.decoder = decoder @torch.no_grad() - def generate(self, tokens=None, past=None): + def generate(self, state, tokens=None): """ - :param torch.LongTensor tokens: batch_size x length, 开始的token - :param past: - :return: + :param State state: encoder结果的State, 是与Decoder配套使用的 + :param torch.LongTensor,None tokens: batch_size x length, 开始的token + :return: bsz x max_length' 生成的token序列。如果eos_token_id不为None, 每个sequence的结尾一定是eos_token_id """ - # TODO 需要查看如果tokens长度不是1,decode的时候是否还能够直接decode?
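The top_k/top_p parameters documented above implement truncated sampling: only the most probable tokens survive before the draw. A compact nucleus (top-p) filter in plain PyTorch (a sketch of the technique, not the library's exact code):

```python
import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    # logits: (bsz, vocab). Mask tokens outside the smallest set whose
    # cumulative probability exceeds top_p.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove = cum_probs > top_p
    remove[..., 1:] = remove[..., :-1].clone()  # shift right: always keep the top token
    remove[..., 0] = False
    mask = remove.scatter(-1, sorted_idx, remove)  # map back to vocabulary order
    return logits.masked_fill(mask, float('-inf'))

probs = torch.softmax(top_p_filter(torch.randn(2, 10)), dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)  # (2, 1)
```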
- return self.generate_func(tokens=tokens, past=past) + + return self.generate_func(tokens=tokens, state=state) @torch.no_grad() -def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, +def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1, length_penalty=1.0): """ @@ -60,23 +79,23 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, :param Decoder decoder: Decoder对象 :param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 - :param Past past: 应该包好encoder的一些输出。 + :param State state: 应该包含encoder的一些输出。 :param int max_length: 生成句子的最大长度。 :param int num_beams: 使用多大的beam进行解码。 :param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。 :param int eos_token_id: 结束的token,如果为None,则一定会解码到max_length这么长。 - :param int pad_token_id: + :param int pad_token_id: pad的token id :param float repetition_penalty: 对重复出现的token多大的惩罚。 :param float length_penalty: 对每个token(除了eos)按照长度进行一定的惩罚。 :return: """ if num_beams == 1: - token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1, + token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=1, top_k=50, top_p=1, bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, repetition_penalty=repetition_penalty, length_penalty=length_penalty, pad_token_id=pad_token_id) else: - token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, + token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams, temperature=1, top_k=50, top_p=1, bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, repetition_penalty=repetition_penalty, length_penalty=length_penalty, @@ -86,7 +105,7 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, @torch.no_grad() -def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50, +def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0, length_penalty=1.0): """ @@ -94,7 +113,7 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, :param Decoder decoder: Decoder对象 :param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 - :param Past past: 应该包好encoder的一些输出。 + :param State state: 应该包含encoder的一些输出。 :param int max_length: 生成句子的最大长度。 :param int num_beam: 使用多大的beam进行解码。 :param float temperature: 采样时的退火大小 @@ -109,13 +128,13 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, """ # 每个位置在生成的时候会sample生成 if num_beams == 1: - token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature, + token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, repetition_penalty=repetition_penalty, length_penalty=length_penalty, pad_token_id=pad_token_id) else: - token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, + token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams, 
temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, repetition_penalty=repetition_penalty, length_penalty=length_penalty, @@ -123,40 +142,35 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, return token_ids -def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50, +def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0): device = _get_model_device(decoder) if tokens is None: if bos_token_id is None: raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") - if past is None: - raise RuntimeError("You have to specify either `past` or `tokens`.") - batch_size = past.num_samples() + batch_size = state.num_samples if batch_size is None: - raise RuntimeError("Cannot infer the number of samples from `past`.") + raise RuntimeError("Cannot infer the number of samples from `state`.") tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) batch_size = tokens.size(0) - if past is not None: - assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." + if state.num_samples: + assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." if eos_token_id is None: - _eos_token_id = float('nan') + _eos_token_id = -1 else: _eos_token_id = eos_token_id - # for i in range(tokens.size(1)): - # scores, past = decoder.decode_one(tokens[:, :i + 1], past) # batch_size x vocab_size, Past - scores, past = decoder.decode(tokens, past) - - token_ids = tokens.clone() + scores = decoder.decode(tokens=tokens, state=state) # 主要是为了update state + next_tokens = scores.argmax(dim=-1, keepdim=True) + token_ids = torch.cat([tokens, next_tokens], dim=1) cur_len = token_ids.size(1) dones = token_ids.new_zeros(batch_size).eq(1) # tokens = tokens[:, -1:] while cur_len < max_length: - # scores, past = decoder.decode_one(tokens, past) # batch_size x vocab_size, Past - scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past + scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size if repetition_penalty != 1.0: token_scores = scores.gather(dim=1, index=token_ids) @@ -204,7 +218,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt return token_ids -def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0, +def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, num_beams=4, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor: # 进行beam search @@ -212,21 +226,20 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 if tokens is None: if bos_token_id is None: raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") - if past is None: - raise RuntimeError("You have to specify either `past` or `tokens`.") - batch_size = past.num_samples() + batch_size = state.num_samples if batch_size is None: - raise RuntimeError("Cannot infer the number of samples from `past`.") + raise RuntimeError("Cannot infer the number of samples from 
`state`.") tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) batch_size = tokens.size(0) - if past is not None: - assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." - - # for i in range(tokens.size(1) - 1): # 如果输入的长度较长,先decode - # scores, past = decoder.decode_one(tokens[:, :i + 1], - # past) # (batch_size, vocab_size), Past - # scores, past = decoder.decode_one(tokens, past) # 这里要传入的是整个句子的长度 - scores, past = decoder.decode(tokens, past) # 这里要传入的是整个句子的长度 + if state.num_samples: + assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." + + if eos_token_id is None: + _eos_token_id = -1 + else: + _eos_token_id = eos_token_id + + scores = decoder.decode(tokens=tokens, state=state) # 这里要传入的是整个句子的长度 vocab_size = scores.size(1) assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size." @@ -240,15 +253,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 # 得到(batch_size, num_beams), (batch_size, num_beams) next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) + # 根据index来做顺序的调转 indices = torch.arange(batch_size, dtype=torch.long).to(device) indices = indices.repeat_interleave(num_beams) - decoder.reorder_past(indices, past) + state.reorder_state(indices) tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length # 记录生成好的token (batch_size', cur_len) token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) dones = [False] * batch_size - tokens = next_tokens.view(-1, 1) beam_scores = next_scores.view(-1) # batch_size * num_beams @@ -262,8 +275,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) while cur_len < max_length: - # scores, past = decoder.decode_one(tokens, past) # batch_size * num_beams x vocab_size, Past - scores, past = decoder.decode(tokens, past) + scores = decoder.decode(token_ids, state) if repetition_penalty != 1.0: token_scores = scores.gather(dim=1, index=token_ids) lt_zero_mask = token_scores.lt(0).float() @@ -307,7 +319,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 next_tokens = next_tokens.gather(dim=1, index=sorted_inds) from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds) - not_eos_mask = next_tokens.ne(eos_token_id) # 为1的地方不是eos + not_eos_mask = next_tokens.ne(_eos_token_id) # 为1的地方不是eos keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) # 为1的地方需要保留 keep_mask = not_eos_mask.__and__(keep_mask) # 为1的地方是需要进行下一步search的 @@ -316,18 +328,18 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) beam_scores = _next_scores.view(-1) - # 更改past状态, 重组token_ids + # 更改state状态, 重组token_ids reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维 - decoder.reorder_past(reorder_inds, past) + state.reorder_state(reorder_inds) flag = True - if cur_len + 1 == max_length: + if cur_len+1 == max_length: eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0) eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size) # 表示的是indice eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1) # 表示的是从哪个beam获取得到的 else: # 
将每个batch中在num_beam内的序列添加到结束中, 为1的地方需要结束了 - effective_eos_mask = next_tokens[:, :num_beams].eq(eos_token_id) # batch_size x num_beams + effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams if effective_eos_mask.sum().gt(0): eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True) # 是由于from_which_beam是 (batch_size, 2*num_beams)的,所以需要2*num_beams eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx] # 获取真实的从哪个beam获取的eos else: flag = False + + # 重新组织token_ids的状态 + tokens = _next_tokens + token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) + if flag: for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), eos_beam_idx.tolist()): if not dones[batch_idx]: score = next_scores[batch_idx, beam_ind].item() - hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) - - # 重新组织token_ids的状态 - tokens = _next_tokens - token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) + hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len+1].clone(), score) for batch_idx in range(batch_size): dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) @@ -360,15 +373,15 @@ for i, hypotheses in enumerate(hypos): best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] - tgt_len[i] = len(best_hyp) + 1 # +1 for the symbol + tgt_len[i] = len(best_hyp) # eos已经包含在best_hyp中 best.append(best_hyp) # generate target batch decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id) for i, hypo in enumerate(best): - decoded[i, :tgt_len[i] - 1] = hypo + decoded[i, :tgt_len[i]] = hypo if eos_token_id is not None: - decoded[i, tgt_len[i] - 1] = eos_token_id + decoded[i, tgt_len[i] - 1] = _eos_token_id return decoded diff --git a/fastNLP/modules/tokenizer/bert_tokenizer.py b/fastNLP/modules/tokenizer/bert_tokenizer.py index ee541c50..f71f1093 100644 --- a/fastNLP/modules/tokenizer/bert_tokenizer.py +++ b/fastNLP/modules/tokenizer/bert_tokenizer.py @@ -384,6 +384,9 @@ class BertTokenizer(object): index += 1 return vocab_file + def save_pretrained(self, save_directory): + self.save_vocabulary(save_directory) + @classmethod def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): r""" diff --git a/fastNLP/modules/tokenizer/gpt2_tokenizer.py b/fastNLP/modules/tokenizer/gpt2_tokenizer.py index 9cfa8f2c..6bf6ce67 100644 --- a/fastNLP/modules/tokenizer/gpt2_tokenizer.py +++ b/fastNLP/modules/tokenizer/gpt2_tokenizer.py @@ -377,6 +377,9 @@ class GPT2Tokenizer: text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) return text + def save_pretrained(self, save_directory): + return self.save_vocabulary(save_directory) + def save_vocabulary(self, save_directory): """Save the tokenizer vocabulary and merge files to a directory.""" if not os.path.isdir(save_directory): diff --git a/reproduction/Summarization/Baseline/model/TForiginal.py b/reproduction/Summarization/Baseline/model/TForiginal.py index a08a9213..d1444150 100644 --- a/reproduction/Summarization/Baseline/model/TForiginal.py +++ b/reproduction/Summarization/Baseline/model/TForiginal.py @@ -32,7 +32,6 @@ from tools.PositionEmbedding import get_sinusoid_encoding_table from tools.logger import * from
fastNLP.core.const import Const -from fastNLP.modules.encoder.transformer import TransformerEncoder from transformer.Layers import EncoderLayer diff --git a/reproduction/Summarization/Baseline/model/TransformerModel.py b/reproduction/Summarization/Baseline/model/TransformerModel.py index 0a30f36d..4d314f84 100644 --- a/reproduction/Summarization/Baseline/model/TransformerModel.py +++ b/reproduction/Summarization/Baseline/model/TransformerModel.py @@ -30,7 +30,7 @@ from .Encoder import Encoder from tools.PositionEmbedding import get_sinusoid_encoding_table from fastNLP.core.const import Const -from fastNLP.modules.encoder.transformer import TransformerEncoder +from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoderLayer class TransformerModel(nn.Module): def __init__(self, hps, vocab): @@ -68,7 +68,8 @@ class TransformerModel(nn.Module): get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) self.layer_stack = nn.ModuleList([ - TransformerEncoder.SubLayer(model_size=self.hidden_size, inner_size=self.d_inner, key_size=self.d_k, value_size=self.d_v,num_head=self.n_head, dropout=hps.atten_dropout_prob) + TransformerSeq2SeqEncoderLayer(d_model = self.hidden_size, n_head = self.n_head, dim_ff = self.d_inner, + dropout = hps.atten_dropout_prob) for _ in range(self.num_layers)]) self.wh = nn.Linear(self.hidden_size, 2) @@ -109,7 +110,7 @@ class TransformerModel(nn.Module): for enc_layer in self.layer_stack: # enc_output = [batch_size, N, hidden_size = n_head * d_v] # enc_slf_attn = [n_head * batch_size, N, N] - enc_input = enc_layer(enc_input, seq_mask=self.non_pad_mask, atte_mask_out=self.slf_attn_mask) + enc_input = enc_layer(enc_input, encoder_mask=self.slf_attn_mask) enc_input_list += [enc_input] self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4, batch_size, N, hidden_state] diff --git a/reproduction/Summarization/Baseline/train_origin.py b/reproduction/Summarization/Baseline/train_origin.py index 7c4d2f12..e1248025 100644 --- a/reproduction/Summarization/Baseline/train_origin.py +++ b/reproduction/Summarization/Baseline/train_origin.py @@ -265,7 +265,7 @@ def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): label = Variable(label) input_len = Variable(input_len, requires_grad=False) - model_outputs = model.forward(input,input_len) # [batch, N, 2] + model_outputs = model.forward(input, input_len) # [batch, N, 2] outputs = model_outputs["p_sent"] prediction = model_outputs["prediction"] diff --git a/reproduction/Summarization/Baseline/train_transformer.py b/reproduction/Summarization/Baseline/train_transformer.py index 50d05f5c..e838a803 100644 --- a/reproduction/Summarization/Baseline/train_transformer.py +++ b/reproduction/Summarization/Baseline/train_transformer.py @@ -264,7 +264,7 @@ def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): label = Variable(label) input_len = Variable(input_len, requires_grad=False) - model_outputs = model.forward(input,input_len) # [batch, N, 2] + model_outputs = model.forward(input, input_len) # [batch, N, 2] outputs = model_outputs[Const.OUTPUTS] prediction = model_outputs["prediction"] diff --git a/reproduction/Summarization/Baseline/transformer/Models.py b/reproduction/Summarization/Baseline/transformer/Models.py index d323e785..2d928f96 100644 --- a/reproduction/Summarization/Baseline/transformer/Models.py +++ b/reproduction/Summarization/Baseline/transformer/Models.py @@ -7,10 +7,12 @@ from transformer.Layers import EncoderLayer, DecoderLayer 
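get_sinusoid_encoding_table, defined a few lines below in Models.py, is the classic fixed sinusoidal position encoding. An equivalent torch-only sketch (assumes an even d_hid; illustrative, not the repo's exact implementation):

```python
import math
import torch

def sinusoid_table(n_position: int, d_hid: int) -> torch.Tensor:
    # (n_position, d_hid): sin on even dims, cos on odd dims; assumes even d_hid
    position = torch.arange(n_position, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_hid, 2, dtype=torch.float)
                         * (-math.log(10000.0) / d_hid))
    table = torch.zeros(n_position, d_hid)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table

# typical use: a frozen positional embedding
pos_embed = torch.nn.Embedding.from_pretrained(sinusoid_table(512, 256), freeze=True)
```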
__author__ = "Yu-Hsiang Huang" + def get_non_pad_mask(seq): assert seq.dim() == 2 return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) + def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): ''' Sinusoid position encoding table ''' @@ -31,6 +33,7 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): return torch.FloatTensor(sinusoid_table) + def get_attn_key_pad_mask(seq_k, seq_q): ''' For masking out the padding part of key sequence. ''' @@ -41,6 +44,7 @@ def get_attn_key_pad_mask(seq_k, seq_q): return padding_mask + def get_subsequent_mask(seq): ''' For masking out the subsequent info. ''' @@ -51,6 +55,7 @@ def get_subsequent_mask(seq): return subsequent_mask + class Encoder(nn.Module): ''' A encoder model with self attention mechanism. ''' @@ -98,6 +103,7 @@ class Encoder(nn.Module): return enc_output, enc_slf_attn_list return enc_output, + class Decoder(nn.Module): ''' A decoder model with self attention mechanism. ''' @@ -152,6 +158,7 @@ class Decoder(nn.Module): return dec_output, dec_slf_attn_list, dec_enc_attn_list return dec_output, + class Transformer(nn.Module): ''' A sequence to sequence model with attention mechanism. ''' @@ -181,8 +188,8 @@ class Transformer(nn.Module): nn.init.xavier_normal_(self.tgt_word_prj.weight) assert d_model == d_word_vec, \ - 'To facilitate the residual connections, \ - the dimensions of all module outputs shall be the same.' + 'To facilitate the residual connections, \ + the dimensions of all module outputs shall be the same.' if tgt_emb_prj_weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer @@ -194,7 +201,7 @@ class Transformer(nn.Module): if emb_src_tgt_weight_sharing: # Share the weight matrix between source & target word embeddings assert n_src_vocab == n_tgt_vocab, \ - "To share word embedding table, the vocabulary size of src/tgt shall be the same." + "To share word embedding table, the vocabulary size of src/tgt shall be the same." 
             self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight

     def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):
diff --git a/reproduction/multi-criteria-cws/models.py b/reproduction/multi-criteria-cws/models.py
index e83a5375..92c93175 100644
--- a/reproduction/multi-criteria-cws/models.py
+++ b/reproduction/multi-criteria-cws/models.py
@@ -1,7 +1,6 @@
 import fastNLP
 import torch
 import math
-from fastNLP.modules.encoder.transformer import TransformerEncoder
 from fastNLP.modules.decoder.crf import ConditionalRandomField
 from fastNLP import Const
 import copy
@@ -181,7 +180,6 @@ def make_CWS(
     freeze=True,
 ):
     c = copy.deepcopy
-    # encoder=TransformerEncoder(num_layers=N,model_size=d_model,inner_size=d_ff,key_size=d_model//h,value_size=d_model//h,num_head=h,dropout=dropout)
     encoder = transformer.make_encoder(
         N=N, d_model=d_model, h=h, dropout=dropout, d_ff=d_ff
     )
diff --git a/reproduction/text_classification/model/lstm_self_attention.py b/reproduction/text_classification/model/lstm_self_attention.py
index 9a39049d..b79cb1b0 100644
--- a/reproduction/text_classification/model/lstm_self_attention.py
+++ b/reproduction/text_classification/model/lstm_self_attention.py
@@ -1,9 +1,8 @@
-import torch
 import torch.nn as nn
 from fastNLP.core.const import Const as C
 from fastNLP.modules.encoder.lstm import LSTM
 from fastNLP.embeddings.utils import get_embeddings
-from fastNLP.modules.encoder.attention import SelfAttention
+from fastNLP.modules.attention import SelfAttention
 from fastNLP.modules.decoder.mlp import MLP

diff --git a/reproduction/text_classification/model/weight_drop.py b/reproduction/text_classification/model/weight_drop.py
index 60fda179..688c8d54 100644
--- a/reproduction/text_classification/model/weight_drop.py
+++ b/reproduction/text_classification/model/weight_drop.py
@@ -44,7 +44,7 @@ class WeightDrop(torch.nn.Module):

     def forward(self, *args):
         self._setweights()
-        return self.module.forward(*args)
+        return self.module(*args)

 if __name__ == '__main__':
     import torch
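For context on the weight_drop change above: WeightDrop re-applies DropConnect to the wrapped module's weights on every call and must then hand its inputs through unchanged, so forward cannot drop *args. A usage sketch, assuming the WeightDrop(module, weights, dropout) constructor of the awd-lstm implementation this file follows; the import path is hypothetical:

    import torch
    from torch import nn
    from weight_drop import WeightDrop  # hypothetical import path

    lstm = nn.LSTM(10, 20, batch_first=True)
    # apply DropConnect to the hidden-to-hidden weight matrix
    wd_lstm = WeightDrop(lstm, weights=['weight_hh_l0'], dropout=0.5)

    x = torch.randn(4, 7, 10)    # [batch, seq_len, input_size]
    out, (h, c) = wd_lstm(x)     # the input must reach the wrapped LSTM
    print(out.size())            # torch.Size([4, 7, 20])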
diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py
index 26126711..2e619bcb 100644
--- a/test/embeddings/test_bert_embedding.py
+++ b/test/embeddings/test_bert_embedding.py
@@ -40,8 +40,7 @@ class TestBertEmbedding(unittest.TestCase):
             result = embed(words)
             self.assertEqual(result.size(), (1, 4, 16))

-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1,
-                              only_use_pretrain_bpe=True)
+        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
         embed.eval()
         words = torch.LongTensor([[2, 3, 4, 0]])
         result = embed(words)
@@ -49,53 +48,30 @@ class TestBertEmbedding(unittest.TestCase):

         # 自动截断而不报错
         embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1,
-                              only_use_pretrain_bpe=True, auto_truncate=True)
+                              auto_truncate=True)
         words = torch.LongTensor([[2, 3, 4, 1]*10, [2, 3]+[0]*38])
         result = embed(words)
         self.assertEqual(result.size(), (2, 40, 16))

-    def test_bert_embedding_2(self):
-        # 测试only_use_pretrain_vocab与truncate_embed是否正常工作
-        with open('test/data_for_tests/embedding/small_bert/vocab.txt', 'r', encoding='utf-8') as f:
-            num_word = len(f.readlines())
-        Embedding = BertEmbedding
-        vocab = Vocabulary().add_word_lst("this is a texta and [SEP] NotInBERT".split())
-        embed1 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
-                           only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
-        embed_bpe_vocab_size = len(vocab)-1 + 2  # 排除NotInBERT, 额外加##a, [CLS]
-        self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab))
-
-        embed2 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
-                           only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1)
-        embed_bpe_vocab_size = num_word  # 排除NotInBERT
-        self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab))
-
-        embed3 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
-                           only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1)
-        embed_bpe_vocab_size = len(vocab)+2  # 新增##a, [CLS]
-        self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab))
-
-        embed4 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
-                           only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1)
-        embed_bpe_vocab_size = num_word+1  # 新增##a
-        self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab))
-
-        # 测试各种情况下以下tensor的值是相等的
-        embed1.eval()
-        embed2.eval()
-        embed3.eval()
-        embed4.eval()
-        tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]])
-        t1 = embed1(tensor)
-        t2 = embed2(tensor)
-        t3 = embed3(tensor)
-        t4 = embed4(tensor)
-
-        self.assertEqual((t1-t2).sum(), 0)
-        self.assertEqual((t1-t3).sum(), 0)
-        self.assertEqual((t1-t4).sum(), 0)
+    def test_save_load(self):
+        bert_save_test = 'bert_save_test'
+        try:
+            os.makedirs(bert_save_test, exist_ok=True)
+            vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
+            embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1,
+                                  auto_truncate=True)
+
+            embed.save(bert_save_test)
+            load_embed = BertEmbedding.load(bert_save_test)
+            words = torch.randint(len(vocab), size=(2, 20))
+            embed.eval(), load_embed.eval()
+            self.assertEqual((embed(words) - load_embed(words)).sum(), 0)
+
+        finally:
+            import shutil
+            shutil.rmtree(bert_save_test)


 class TestBertWordPieceEncoder(unittest.TestCase):
@@ -120,11 +96,30 @@ class TestBertWordPieceEncoder(unittest.TestCase):
         ds.set_input('words')
         words = torch.LongTensor(ds['words'].get([0, 1]))
         embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
-                              pool_method='first', include_cls_sep=True, pooled_cls=False)
+                              pool_method='first', include_cls_sep=True, pooled_cls=False, min_freq=1)
         embed.eval()
         words_res = embed(words)

         # 检查word piece什么的是正常work的
         self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0)
         self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0)
-        self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0)
\ No newline at end of file
+        self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0)
+
+    def test_save_load(self):
+        bert_save_test = 'bert_save_test'
+        try:
+            os.makedirs(bert_save_test, exist_ok=True)
+            embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.0,
+                                         layers='-2')
+            ds = DataSet({'words': ["this is a test . [SEP]".split()]})
[SEP]".split()]}) + embed.index_datasets(ds, field_name='words') + self.assertTrue(ds.has_field('word_pieces')) + words = torch.LongTensor([[1, 2, 3, 4]]) + embed.save(bert_save_test) + load_embed = BertWordPieceEncoder.load(bert_save_test) + embed.eval(), load_embed.eval() + self.assertEqual((embed(words) - load_embed(words)).sum(), 0) + finally: + import shutil + shutil.rmtree(bert_save_test) + diff --git a/test/embeddings/test_gpt2_embedding.py b/test/embeddings/test_gpt2_embedding.py index 01e00410..d31f20bc 100644 --- a/test/embeddings/test_gpt2_embedding.py +++ b/test/embeddings/test_gpt2_embedding.py @@ -255,14 +255,17 @@ class TestGPT2WordPieceEncoder(unittest.TestCase): result = embed(torch.LongTensor([[1, 2, 3, 4]])) def test_generate(self): - weight_path = 'test/data_for_tests/embedding/small_gpt2' + # weight_path = 'test/data_for_tests/embedding/small_gpt2' + weight_path = 'en' encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, language_model=True) # 测试一下各项东西是否正常work - print(encoder.generate_from_str('this', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, + print(encoder.generate_from_str('This', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, + repetition_penalty=1.0, length_penalty=1.0)) + print(encoder.generate_from_str('This day', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, repetition_penalty=1.0, length_penalty=1.0)) - print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=1, top_k=50, top_p=1.0, + print(encoder.generate_from_str('This', max_len=20, do_sample=True, num_beams=3, temperature=1, top_k=50, top_p=1.0, repetition_penalty=1.0, length_penalty=1.0)) - print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=2, top_k=20, top_p=2.0, + print(encoder.generate_from_str('This', max_len=20, do_sample=True, num_beams=3, temperature=2, top_k=20, top_p=2.0, repetition_penalty=2.0, length_penalty=1.5)) diff --git a/test/embeddings/test_roberta_embedding.py b/test/embeddings/test_roberta_embedding.py index c2e80a8a..4cfc1ca3 100644 --- a/test/embeddings/test_roberta_embedding.py +++ b/test/embeddings/test_roberta_embedding.py @@ -47,7 +47,7 @@ class TestRobertWordPieceEncoder(unittest.TestCase): ds.set_input('words') words = torch.LongTensor(ds['words'].get([0, 1])) embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, - pool_method='first', include_cls_sep=True, pooled_cls=False) + pool_method='first', include_cls_sep=True, pooled_cls=False, min_freq=1) embed.eval() words_res = embed(words) @@ -183,6 +183,24 @@ class TestRobertWordPieceEncoder(unittest.TestCase): torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin') print(model(torch.LongTensor([[0,1,2,3]]))) + def test_save_load(self): + bert_save_test = 'roberta_save_test' + try: + os.makedirs(bert_save_test, exist_ok=True) + embed = RobertaWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_roberta', word_dropout=0.0, + layers='-2') + ds = DataSet({'words': ["this is a test . 
[SEP]".split()]}) + embed.index_datasets(ds, field_name='words') + self.assertTrue(ds.has_field('word_pieces')) + words = torch.LongTensor([[1, 2, 3, 4]]) + embed.save(bert_save_test) + load_embed = RobertaWordPieceEncoder.load(bert_save_test) + embed.eval(), load_embed.eval() + self.assertEqual((embed(words) - load_embed(words)).sum(), 0) + finally: + import shutil + shutil.rmtree(bert_save_test) + class TestRobertaEmbedding(unittest.TestCase): def test_roberta_embedding_1(self): @@ -250,3 +268,20 @@ class TestRobertaEmbedding(unittest.TestCase): self.assertEqual((t1-t2).sum(), 0) self.assertEqual((t1-t3).sum(), 0) self.assertEqual((t1-t4).sum(), 0) + + def test_save_load(self): + bert_save_test = 'roberta_save_test' + try: + os.makedirs(bert_save_test, exist_ok=True) + vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) + embed = RobertaEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_roberta', + word_dropout=0.1, + auto_truncate=True) + embed.save(bert_save_test) + load_embed = RobertaEmbedding.load(bert_save_test) + words = torch.randint(len(vocab), size=(2, 20)) + embed.eval(), load_embed.eval() + self.assertEqual((embed(words) - load_embed(words)).sum(), 0) + finally: + import shutil + shutil.rmtree(bert_save_test) diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py index 755bb5cd..2b10a2d0 100644 --- a/test/embeddings/test_static_embedding.py +++ b/test/embeddings/test_static_embedding.py @@ -108,6 +108,56 @@ class TestLoad(unittest.TestCase): for v1i, v2i in zip(v1, v2): self.assertAlmostEqual(v1i, v2i, places=4) + def test_save_load_static_embed(self): + static_test_folder = 'static_save_test' + try: + # 测试包含no_create_entry + os.makedirs(static_test_folder, exist_ok=True) + + vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A']) + vocab.add_word_lst(['notinfile2', 'notinfile2'], no_create_entry=True) + embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' + 'glove.6B.50d_test.txt') + embed.save(static_test_folder) + load_embed = StaticEmbedding.load(static_test_folder) + words = torch.randint(len(vocab), size=(2, 20)) + self.assertEqual((embed(words) - load_embed(words)).sum(), 0) + + # 测试不包含no_create_entry + vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A']) + embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' + 'glove.6B.50d_test.txt') + embed.save(static_test_folder) + load_embed = StaticEmbedding.load(static_test_folder) + words = torch.randint(len(vocab), size=(2, 20)) + self.assertEqual((embed(words) - load_embed(words)).sum(), 0) + + # 测试lower, min_freq + vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B']) + embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' + 'glove.6B.50d_test.txt', min_freq=2, lower=True) + embed.save(static_test_folder) + load_embed = StaticEmbedding.load(static_test_folder) + words = torch.randint(len(vocab), size=(2, 20)) + self.assertEqual((embed(words) - load_embed(words)).sum(), 0) + + # 测试random的embedding + vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B']) + vocab = vocab.add_word_lst(['b'], no_create_entry=True) + embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=4, min_freq=2, lower=True, + normalize=True) + embed.weight.data += 0.2 # 使得它不是normalize + embed.save(static_test_folder) + load_embed = 
+            words = torch.randint(len(vocab), size=(2, 20))
+            self.assertEqual((embed(words) - load_embed(words)).sum(), 0)
+
+        finally:
+            if os.path.isdir(static_test_folder):
+                import shutil
+                shutil.rmtree(static_test_folder)
+
+
 def read_static_embed(fp):
     """
diff --git a/test/io/loader/test_classification_loader.py b/test/io/loader/test_classification_loader.py
index f4ecd47d..72db136c 100644
--- a/test/io/loader/test_classification_loader.py
+++ b/test/io/loader/test_classification_loader.py
@@ -30,11 +30,11 @@ class TestLoad(unittest.TestCase):
             'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False),
             'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False),
             'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False),
-            'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (7, 6, 6), False),
+            'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 6, 7), False),
         }
         for k, v in data_set_dict.items():
             path, loader, data_set, warns = v
-            with self.subTest(loader=loader):
+            with self.subTest(path=path):
                 if warns:
                     with self.assertWarns(Warning):
                         data_bundle = loader().load(path)
@@ -45,5 +45,6 @@ class TestLoad(unittest.TestCase):
                 self.assertEqual(len(data_set), data_bundle.num_dataset)
                 for x, y in zip(data_set, data_bundle.iter_datasets()):
                     name, dataset = y
-                    self.assertEqual(x, len(dataset))
+                    with self.subTest(split=name):
+                        self.assertEqual(x, len(dataset))

diff --git a/test/io/loader/test_matching_loader.py b/test/io/loader/test_matching_loader.py
index 70367f6d..d2b221c5 100644
--- a/test/io/loader/test_matching_loader.py
+++ b/test/io/loader/test_matching_loader.py
@@ -32,7 +32,7 @@ class TestMatchingLoad(unittest.TestCase):
             'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False),
             'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
             'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False),
-            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False),
+            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 6, 5), False),
         }
         for k, v in data_set_dict.items():
             path, loader, instance, warns = v
@@ -46,5 +46,6 @@ class TestMatchingLoad(unittest.TestCase):
             self.assertEqual(len(instance), data_bundle.num_dataset)
             for x, y in zip(instance, data_bundle.iter_datasets()):
                 name, dataset = y
-                self.assertEqual(x, len(dataset))
+                with self.subTest(path=path, split=name):
+                    self.assertEqual(x, len(dataset))

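The loader and pipe tests here switch to per-split subTest blocks, so a wrong size in one split is reported on its own instead of aborting the loop and hiding the remaining splits. A self-contained sketch of the pattern; the sizes are made up:

    import unittest

    class TestSplitSizes(unittest.TestCase):
        def test_each_split_reported_separately(self):
            expected = {'train': 6, 'dev': 6, 'test': 7}  # made-up sizes
            for name, size in expected.items():
                with self.subTest(split=name):
                    # a failure here is attributed to this split only,
                    # and the loop continues with the next split
                    self.assertGreater(size, 0)

    if __name__ == '__main__':
        unittest.main()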
diff --git a/test/io/pipe/test_classification.py b/test/io/pipe/test_classification.py
index e8081590..c6bd5444 100644
--- a/test/io/pipe/test_classification.py
+++ b/test/io/pipe/test_classification.py
@@ -70,7 +70,7 @@ class TestRunClassificationPipe(unittest.TestCase):
         }
         for k, v in data_set_dict.items():
             path, pipe, data_set, vocab, warns = v
-            with self.subTest(pipe=pipe):
+            with self.subTest(path=path):
                 if 'Chn' not in k:
                     if warns:
                         with self.assertWarns(Warning):
diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py
index bfd65db2..ea687b2e 100644
--- a/test/io/pipe/test_matching.py
+++ b/test/io/pipe/test_matching.py
@@ -39,7 +39,7 @@ class TestRunMatchingPipe(unittest.TestCase):
             'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
             'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
             'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False),
-            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False),
+            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 6, 5), (36, 2), False),
         }
         for k, v in data_set_dict.items():
             path, pipe1, pipe2, data_set, vocab, warns = v
@@ -58,7 +58,8 @@ class TestRunMatchingPipe(unittest.TestCase):
             print(data_bundle2)
             for x, y in zip(data_set, data_bundle1.iter_datasets()):
                 name, dataset = y
-                self.assertEqual(x, len(dataset))
+                with self.subTest(path=path, split=name):
+                    self.assertEqual(x, len(dataset))

             self.assertEqual(len(data_set), data_bundle2.num_dataset)
             for x, y in zip(data_set, data_bundle2.iter_datasets()):
                 name, dataset = y
diff --git a/test/models/test_seq2seq_generator.py b/test/models/test_seq2seq_generator.py
new file mode 100644
index 00000000..ac21281f
--- /dev/null
+++ b/test/models/test_seq2seq_generator.py
@@ -0,0 +1,76 @@
+
+import unittest
+from fastNLP.models import SequenceGeneratorModel
+from fastNLP.models import LSTMSeq2SeqModel, TransformerSeq2SeqModel
+from fastNLP import Vocabulary, DataSet
+import torch
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
+from fastNLP import Callback
+
+
+def prepare_env():
+    vocab = Vocabulary().add_word_lst("This is a test .".split())
+    vocab.add_word_lst("Another test !".split())
+    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)
+
+    src_words_idx = [[3, 1, 2], [1, 2]]
+    # tgt_words_idx = [[1, 2, 3, 4], [2, 3]]
+    src_seq_len = [3, 2]
+    # tgt_seq_len = [4, 2]
+
+    ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len, 'tgt_tokens': src_words_idx,
+                  'tgt_seq_len': src_seq_len})
+
+    ds.set_input('src_tokens', 'tgt_tokens', 'src_seq_len')
+    ds.set_target('tgt_seq_len', 'tgt_tokens')
+
+    return embed, ds
+
+
+class ExitCallback(Callback):
+    def __init__(self):
+        super().__init__()
+
+    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
+        if eval_result['AccuracyMetric']['acc'] == 1:
+            raise KeyboardInterrupt()
+
+
+class TestSeq2SeqGeneratorModel(unittest.TestCase):
+    def test_run(self):
+        # check that SequenceGeneratorModel can be trained; prediction passes the inputs straight through
+        embed, ds = prepare_env()
+        model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
+                                                     pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6,
+                                                     dim_ff=20, dropout=0.1,
+                                                     bind_encoder_decoder_embed=True,
+                                                     bind_decoder_input_output_embed=True)
+        trainer = Trainer(ds, model1, optimizer=None, loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'),
+                          batch_size=32, sampler=None, drop_last=False, update_every=1,
+                          num_workers=0, n_epochs=100, print_every=5,
+                          dev_data=ds, metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'), metric_key=None,
+                          validate_every=-1, save_path=None, use_tqdm=False, device=None,
+                          callbacks=ExitCallback(), check_code_level=0)
+        res = trainer.train()
+        self.assertEqual(res['best_eval']['AccuracyMetric']['acc'], 1)
+
+        embed, ds = prepare_env()
+        model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
+                                              num_layers=1, hidden_size=20, dropout=0.1,
+                                              bind_encoder_decoder_embed=True,
+                                              bind_decoder_input_output_embed=True, attention=True)
+        optimizer = torch.optim.Adam(model2.parameters(), lr=0.01)
+        trainer = Trainer(ds, model2, optimizer=optimizer, loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'),
+                          batch_size=32, sampler=None, drop_last=False, update_every=1,
+                          num_workers=0, n_epochs=200, print_every=1,
+                          dev_data=ds, metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'),
+                          metric_key=None,
+                          validate_every=-1, save_path=None, use_tqdm=False, device=None,
+                          callbacks=ExitCallback(), check_code_level=0)
+        res = trainer.train()
+        self.assertEqual(res['best_eval']['AccuracyMetric']['acc'], 1)
+
diff --git a/test/models/test_seq2seq_model.py b/test/models/test_seq2seq_model.py
new file mode 100644
index 00000000..fc35b02e
--- /dev/null
+++ b/test/models/test_seq2seq_model.py
@@ -0,0 +1,114 @@
+
+import unittest
+
+from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding
+import torch
+from torch import optim
+import torch.nn.functional as F
+from fastNLP import seq_len_to_mask
+
+
+def prepare_env():
+    vocab = Vocabulary().add_word_lst("This is a test .".split())
+    vocab.add_word_lst("Another test !".split())
+    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)
+
+    src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
+    tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
+    src_seq_len = torch.LongTensor([3, 2])
+    tgt_seq_len = torch.LongTensor([4, 2])
+
+    return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len
+
+
+def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len):
+    optimizer = optim.Adam(model.parameters(), lr=1e-2)
+    mask = seq_len_to_mask(tgt_seq_len).eq(0)
+    target = tgt_words_idx.masked_fill(mask, -100)
+
+    for i in range(100):
+        optimizer.zero_grad()
+        pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred']  # bsz x max_len x vocab_size
+        loss = F.cross_entropy(pred.transpose(1, 2), target)
+        loss.backward()
+        optimizer.step()
+
+    right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum()
+    return right_count
+
+
+class TestTransformerSeq2SeqModel(unittest.TestCase):
+    def test_run(self):
+        # check that the model runs end to end
+        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
+        for pos_embed in ['learned', 'sin']:
+            with self.subTest(pos_embed=pos_embed):
+                model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
+                                                            pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1,
+                                                            bind_encoder_decoder_embed=True,
+                                                            bind_decoder_input_output_embed=True)
+
+                output = model(src_words_idx, tgt_words_idx, src_seq_len)
+                self.assertEqual(output['pred'].size(), (2, 4, len(embed)))
+
+        for bind_encoder_decoder_embed in [True, False]:
+            tgt_embed = None
+            for bind_decoder_input_output_embed in [True, False]:
+                if not bind_encoder_decoder_embed:
+                    tgt_embed = embed
+                with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed,
+                                  bind_decoder_input_output_embed=bind_decoder_input_output_embed):
+                    model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed,
+                                                                pos_embed='sin', max_position=20, num_layers=2,
+                                                                d_model=30, n_head=6, dim_ff=20, dropout=0.1,
+                                                                bind_encoder_decoder_embed=bind_encoder_decoder_embed,
+                                                                bind_decoder_input_output_embed=bind_decoder_input_output_embed)
+
+                    output = model(src_words_idx, tgt_words_idx, src_seq_len)
+                    self.assertEqual(output['pred'].size(), (2, 4, len(embed)))
+
+    def test_train(self):
+        # check that the model can overfit the toy data
+        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
+
+        model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
+                                                    pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1,
+                                                    bind_encoder_decoder_embed=True,
+                                                    bind_decoder_input_output_embed=True)
+
+        right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
+        self.assertEqual(right_count, tgt_words_idx.nelement())
+
+
+class TestLSTMSeq2SeqModel(unittest.TestCase):
+    def test_run(self):
+        # check that the model runs end to end
+        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
+
+        for bind_encoder_decoder_embed in [True, False]:
+            tgt_embed = None
+            for bind_decoder_input_output_embed in [True, False]:
+                if not bind_encoder_decoder_embed:
+                    tgt_embed = embed
+                with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed,
+                                  bind_decoder_input_output_embed=bind_decoder_input_output_embed):
+                    model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed,
+                                                         num_layers=2, hidden_size=20, dropout=0.1,
+                                                         bind_encoder_decoder_embed=bind_encoder_decoder_embed,
+                                                         bind_decoder_input_output_embed=bind_decoder_input_output_embed)
+                    output = model(src_words_idx, tgt_words_idx, src_seq_len)
+                    self.assertEqual(output['pred'].size(), (2, 4, len(embed)))
+
+    def test_train(self):
+        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
+
+        model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
+                                             num_layers=1, hidden_size=20, dropout=0.1,
+                                             bind_encoder_decoder_embed=True,
+                                             bind_decoder_input_output_embed=True)
+
+        right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
+        self.assertEqual(right_count, tgt_words_idx.nelement())
+
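The decoder tests below exercise a two-step API: init_state caches the encoder output and mask, and each decoder call scores the full target prefix. A sketch of how that API supports greedy generation; the greedy_decode helper is illustrative and not part of the patch:

    import torch

    def greedy_decode(decoder, encoder_output, encoder_mask, bos_id, max_len=10):
        # cache the encoder output/mask in a State object
        state = decoder.init_state(encoder_output, encoder_mask)
        tokens = torch.full((encoder_output.size(0), 1), bos_id, dtype=torch.long)
        for _ in range(max_len - 1):
            scores = decoder(tokens=tokens, state=state)  # [batch, cur_len, vocab]
            next_token = scores[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens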
diff --git a/test/modules/decoder/test_seq2seq_decoder.py b/test/modules/decoder/test_seq2seq_decoder.py
new file mode 100644
index 00000000..00437edb
--- /dev/null
+++ b/test/modules/decoder/test_seq2seq_decoder.py
@@ -0,0 +1,50 @@
+import unittest
+
+import torch
+
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP.modules import TransformerSeq2SeqDecoder
+from fastNLP.modules import LSTMSeq2SeqDecoder
+from fastNLP import seq_len_to_mask
+
+
+class TestTransformerSeq2SeqDecoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        vocab.add_word_lst("Another test !".split())
+        embed = StaticEmbedding(vocab, embedding_dim=10)
+
+        encoder_output = torch.randn(2, 3, 10)
+        src_seq_len = torch.LongTensor([3, 2])
+        encoder_mask = seq_len_to_mask(src_seq_len)
+
+        for flag in [True, False]:
+            with self.subTest(bind_decoder_input_output_embed=flag):
+                decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=None,
+                                                    d_model=10, num_layers=2, n_head=5, dim_ff=20, dropout=0.1,
+                                                    bind_decoder_input_output_embed=flag)
+                state = decoder.init_state(encoder_output, encoder_mask)
+                output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)), state=state)
+                self.assertEqual(output.size(), (2, 4, len(vocab)))
+
+
+class TestLSTMDecoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        vocab.add_word_lst("Another test !".split())
+        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10)
+
+        encoder_output = torch.randn(2, 3, 10)
+        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
+        src_seq_len = torch.LongTensor([3, 2])
+        encoder_mask = seq_len_to_mask(src_seq_len)
+
+        for flag in [True, False]:
+            for attention in [True, False]:
+                with self.subTest(bind_decoder_input_output_embed=flag, attention=attention):
+                    decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=2, hidden_size=10,
+                                                 dropout=0.3, bind_decoder_input_output_embed=flag, attention=attention)
+                    state = decoder.init_state(encoder_output, encoder_mask)
+                    output = decoder(tgt_words_idx, state)
+                    self.assertEqual(tuple(output.size()), (2, 4, len(vocab)))
diff --git a/test/modules/encoder/test_seq2seq_encoder.py b/test/modules/encoder/test_seq2seq_encoder.py
new file mode 100644
index 00000000..08c03145
--- /dev/null
+++ b/test/modules/encoder/test_seq2seq_encoder.py
@@ -0,0 +1,30 @@
+import unittest
+
+import torch
+
+from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding
+
+
+class TestTransformerSeq2SeqEncoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        embed = StaticEmbedding(vocab, embedding_dim=5)
+        encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2)
+        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
+        seq_len = torch.LongTensor([3])
+        encoder_output, encoder_mask = encoder(words_idx, seq_len)
+        self.assertEqual(encoder_output.size(), (1, 3, 10))
+
+
+class TestBiLSTMEncoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        embed = StaticEmbedding(vocab, embedding_dim=5)
+        encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1)
+        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
+        seq_len = torch.LongTensor([3])
+
+        encoder_output, encoder_mask = encoder(words_idx, seq_len)
+        self.assertEqual(encoder_mask.size(), (1, 3))
diff --git a/test/modules/generator/__init__.py b/test/modules/generator/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/test/modules/generator/__init__.py
@@ -0,0 +1 @@
+
diff --git a/test/modules/generator/test_seq2seq_generator.py b/test/modules/generator/test_seq2seq_generator.py
new file mode 100644
index 00000000..d7c0fbfa
--- /dev/null
+++ b/test/modules/generator/test_seq2seq_generator.py
@@ -0,0 +1,110 @@
+import unittest
+
+import torch
+from fastNLP.modules.generator import SequenceGenerator
+from fastNLP.modules import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding
+from torch import nn
+from fastNLP import seq_len_to_mask
+
+
+def prepare_env():
+    vocab = Vocabulary().add_word_lst("This is a test .".split())
+    vocab.add_word_lst("Another test !".split())
+    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)
+
+    encoder_output = torch.randn(2, 3, 10)
+    src_seq_len = torch.LongTensor([3, 2])
+    encoder_mask = seq_len_to_mask(src_seq_len)
+
+    return embed, encoder_output, encoder_mask
+
+
+class TestSequenceGenerator(unittest.TestCase):
+    def test_run(self):
+        # check that it runs: (1) build a decoder, (2) decode once
+        embed, encoder_output, encoder_mask = prepare_env()
+
+        for do_sample in [True, False]:
+            for num_beams in [1, 3, 5]:
+                with self.subTest(do_sample=do_sample, num_beams=num_beams):
+                    decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10,
+                                                 dropout=0.3, bind_decoder_input_output_embed=True, attention=True)
+                    state = decoder.init_state(encoder_output, encoder_mask)
+                    generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams,
+                                                  do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None,
+                                                  repetition_penalty=1, length_penalty=1.0, pad_token_id=0)
+                    generator.generate(state=state, tokens=None)
+
+                    decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
+                                                        d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1,
+                                                        bind_decoder_input_output_embed=True)
+                    state = decoder.init_state(encoder_output, encoder_mask)
+                    generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
+                                                  do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None,
+                                                  repetition_penalty=1, length_penalty=1.0, pad_token_id=0)
+                    generator.generate(state=state, tokens=None)
+
+                    # try some other hyper-parameter values
+                    decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
+                                                        d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10,
+                                                        dropout=0.1,
+                                                        bind_decoder_input_output_embed=True)
+                    state = decoder.init_state(encoder_output, encoder_mask)
+                    generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
+                                                  do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1,
+                                                  eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0)
+                    generator.generate(state=state, tokens=None)
+
+    def test_greedy_decode(self):
+        # check that generate follows the correct path
+        class GreedyDummyDecoder(Seq2SeqDecoder):
+            def __init__(self, decoder_output):
+                super().__init__()
+                self.cur_length = 0
+                self.decoder_output = decoder_output
+
+            def decode(self, tokens, state):
+                self.cur_length += 1
+                scores = self.decoder_output[:, self.cur_length]
+                return scores
+
+        class DummyState(State):
+            def __init__(self, decoder):
+                super().__init__()
+                self.decoder = decoder
+
+            def reorder_state(self, indices: torch.LongTensor):
+                self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0)
+
+        # greedy
+        for beam_search in [1, 3]:
+            decoder_output = torch.randn(2, 10, 5)
+            path = decoder_output.argmax(dim=-1)  # 2 x 10
+            decoder = GreedyDummyDecoder(decoder_output)
+            with self.subTest(beam_search=beam_search):
+                generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
+                                              do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1,
+                                              eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0)
+                decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
+
+                self.assertEqual(decode_path.eq(path).sum(), path.numel())
+
+        # greedy: check eos_token_id
+        for beam_search in [1, 3]:
+            decoder_output = torch.randn(2, 10, 5)
+            decoder_output[:, :7, 4].fill_(-100)
+            decoder_output[0, 7, 4] = 1000  # sequence 0 ends at the 8th step
+            decoder_output[1, 5, 4] = 1000
+            path = decoder_output.argmax(dim=-1)  # 2 x 10
+            decoder = GreedyDummyDecoder(decoder_output)
+            with self.subTest(beam_search=beam_search):
+                generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
+                                              do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1,
+                                              eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0)
+                decode_path = generator.generate(DummyState(decoder),
+                                                 tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
+                self.assertEqual(decode_path.size(1), 8)  # output length is 8
+                self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8)
+                self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6)