| @@ -1,13 +1,6 @@ | |||
| import _pickle | |||
| import os | |||
| import numpy as np | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.field import TextField, LabelField | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| # the first vocab in dict with the index = 5 | |||
| @@ -53,258 +46,3 @@ def pickle_exist(pickle_path, pickle_name): | |||
| return True | |||
| else: | |||
| return False | |||
class Preprocessor(object):
    """Convert data of strings into data of indices (deprecated; use DataSet).

    During pre-processing, the following pickle files will be built inside the
    given pickle directory:

        - "word2id.pkl", a Vocabulary object, mapping words to indices.
        - "class2id.pkl", a Vocabulary object, mapping labels to indices.
        - "data_train.pkl", a DataSet object for training.
        - "data_dev.pkl", a DataSet object for validation, if train_dev_split > 0.
        - "data_test.pkl", a DataSet object for testing, if test_data is not None.

    Preprocessors check whether those files already exist in the directory and
    reuse them in future calls instead of recomputing.
    """

    def __init__(self, label_is_seq=False, share_vocab=False, add_char_field=False):
        """
        :param label_is_seq: bool, whether label is a sequence. If True, label vocabulary will preserve
                several special tokens for sequence processing.
        :param share_vocab: bool, whether word sequence and label sequence share the same vocabulary. Typically, this
                is only available when label_is_seq is True. Default: False.
        :param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
        """
        print("Preprocessor is about to deprecate. Please use DataSet class.")
        self.data_vocab = Vocabulary()
        if label_is_seq is True:
            if share_vocab is True:
                # Word and label sequences are indexed by the same vocabulary.
                self.label_vocab = self.data_vocab
            else:
                self.label_vocab = Vocabulary()
        else:
            self.label_vocab = Vocabulary(need_default=False)
        self.character_vocab = Vocabulary(need_default=False)
        self.add_char_field = add_char_field

    @property
    def vocab_size(self):
        """int: size of the word vocabulary (including any special tokens)."""
        return len(self.data_vocab)

    @property
    def num_classes(self):
        """int: size of the label vocabulary."""
        return len(self.label_vocab)

    @property
    def char_vocab_size(self):
        """int: size of the character vocabulary, built lazily on first access.

        Bug fix: the old check ``if self.character_vocab is None`` could never
        be true because ``__init__`` always assigns a Vocabulary, so
        ``build_char_dict()`` was never called and the size stayed at its
        initial value. Build the character vocabulary when it is still empty.
        """
        if len(self.character_vocab) == 0:
            self.build_char_dict()
        return len(self.character_vocab)

    def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
        """Main pre-processing pipeline.

        :param train_dev_data: three-level list, with either single label or multiple labels in a sample.
        :param test_data: three-level list, with either single label or multiple labels in a sample. (optional)
        :param pickle_path: str, the path to save the pickle files.
        :param train_dev_split: float, between [0, 1]. The ratio of training data used as validation set.
        :param cross_val: bool, whether to do cross validation.
        :param n_fold: int, the number of folds of cross validation. Only useful when cross_val is True.
        :return results: multiple datasets after pre-processing. If test_data is provided, return one more dataset.
                If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
                is a list of DataSet objects; Otherwise, each dataset is a DataSet object.
        """
        # Vocabularies are cached on disk; rebuild only when the caches are missing.
        if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
            self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
            self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
        else:
            self.data_vocab, self.label_vocab = self.build_dict(train_dev_data)
            save_pickle(self.data_vocab, pickle_path, "word2id.pkl")
            save_pickle(self.label_vocab, pickle_path, "class2id.pkl")
        self.build_reverse_dict()

        train_set = []
        dev_set = []
        if not cross_val:
            if not pickle_exist(pickle_path, "data_train.pkl"):
                if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"):
                    # Consistency: reuse data_split() instead of duplicating its slicing logic here.
                    data_train, data_dev = self.data_split(train_dev_data, train_dev_split)
                    train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab)
                    dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab)
                    save_pickle(dev_set, pickle_path, "data_dev.pkl")
                    print("{} of the training data is split for validation. ".format(train_dev_split))
                else:
                    train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab)
                save_pickle(train_set, pickle_path, "data_train.pkl")
            else:
                train_set = load_pickle(pickle_path, "data_train.pkl")
                if pickle_exist(pickle_path, "data_dev.pkl"):
                    dev_set = load_pickle(pickle_path, "data_dev.pkl")
        else:
            # cross_val is True
            if not pickle_exist(pickle_path, "data_train_0.pkl"):
                # Build the folds once and cache each of them on disk.
                data_cv = self.cv_split(train_dev_data, n_fold)
                for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
                    data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab)
                    data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab)
                    save_pickle(data_train_cv, pickle_path, "data_train_{}.pkl".format(i))
                    save_pickle(data_dev_cv, pickle_path, "data_dev_{}.pkl".format(i))
                    train_set.append(data_train_cv)
                    dev_set.append(data_dev_cv)
                print("{}-fold cross validation.".format(n_fold))
            else:
                for i in range(n_fold):
                    train_set.append(load_pickle(pickle_path, "data_train_{}.pkl".format(i)))
                    dev_set.append(load_pickle(pickle_path, "data_dev_{}.pkl".format(i)))

        # prepare test data if provided
        test_set = []
        if test_data is not None:
            if not pickle_exist(pickle_path, "data_test.pkl"):
                test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab)
                save_pickle(test_set, pickle_path, "data_test.pkl")

        # return preprocessed results
        results = [train_set]
        if cross_val or train_dev_split > 0:
            results.append(dev_set)
        # Bug fix: use the same "is not None" test as above, so that a provided
        # but empty test set is still returned instead of being silently dropped.
        if test_data is not None:
            results.append(test_set)
        if len(results) == 1:
            return results[0]
        return tuple(results)

    def build_dict(self, data):
        """Update the word and label vocabularies from raw (words, label) pairs."""
        for example in data:
            word, label = example
            self.data_vocab.update(word)
            self.label_vocab.update(label)
        return self.data_vocab, self.label_vocab

    def build_char_dict(self):
        """Collect every character occurring in the word vocabulary."""
        char_collection = set()
        for word in self.data_vocab.word2idx:
            # set.update("") is a no-op, so empty words need no special case.
            char_collection.update(word)
        self.character_vocab.update(list(char_collection))

    def build_reverse_dict(self):
        """Build index-to-word and index-to-label reverse mappings."""
        self.data_vocab.build_reverse_vocab()
        self.label_vocab.build_reverse_vocab()

    def data_split(self, data, train_dev_split):
        """Split data into train and dev set.

        :param data: list of samples.
        :param train_dev_split: float in [0, 1], the fraction used as dev set.
        :return (data_train, data_dev)
        """
        split = int(len(data) * train_dev_split)
        data_dev = data[:split]
        data_train = data[split:]
        return data_train, data_dev

    def cv_split(self, data, n_fold):
        """Split data for cross validation.

        :param data: list of samples.
        :param n_fold: int, the number of folds.
        :return data_cv: a list of ``n_fold`` tuples::

                [
                    (data_train, data_dev),  # 1st fold
                    (data_train, data_dev),  # 2nd fold
                    ...
                ]
        """
        data_copy = data.copy()
        np.random.shuffle(data_copy)
        fold_size = round(len(data_copy) / n_fold)
        data_cv = []
        for i in range(n_fold - 1):
            start = i * fold_size
            end = (i + 1) * fold_size
            data_cv.append((data_copy[:start] + data_copy[end:], data_copy[start:end]))
        # The last fold absorbs any rounding remainder.
        start = (n_fold - 1) * fold_size
        data_cv.append((data_copy[:start], data_copy[start:]))
        return data_cv

    def convert_to_dataset(self, data, vocab, label_vocab):
        """Convert lists of tokens/labels into an indexed DataSet object.

        :param data: list of (words, label) examples.
        :param vocab: a Vocabulary, mapping word (str) to index (int).
        :param label_vocab: a Vocabulary, mapping label (str) to index (int).
        :return data_set: a DataSet object
        """
        use_word_seq = False
        use_label_seq = False
        use_label_str = False

        # construct a DataSet object and fill it with Instances
        data_set = DataSet()
        for example in data:
            words, label = example[0], example[1]
            instance = Instance()

            if isinstance(words, list):
                x = TextField(words, is_target=False)
                instance.add_field("word_seq", x)
                use_word_seq = True
            else:
                raise NotImplementedError("words is a {}".format(type(words)))

            if isinstance(label, list):
                y = TextField(label, is_target=True)
                instance.add_field("label_seq", y)
                use_label_seq = True
            elif isinstance(label, str):
                y = LabelField(label, is_target=True)
                instance.add_field("label", y)
                use_label_str = True
            else:
                raise NotImplementedError("label is a {}".format(type(label)))

            data_set.append(instance)

        # convert strings to indices
        if use_word_seq:
            data_set.index_field("word_seq", vocab)
        if use_label_seq:
            data_set.index_field("label_seq", label_vocab)
        if use_label_str:
            data_set.index_field("label", label_vocab)
        return data_set
class SeqLabelPreprocess(Preprocessor):
    """Deprecated alias of :class:`Preprocessor` kept for backward compatibility."""

    def __init__(self):
        print("[FastNLP warning] SeqLabelPreprocess is about to deprecate. Please use Preprocess directly.")
        super().__init__()
class ClassPreprocess(Preprocessor):
    """Deprecated alias of :class:`Preprocessor` kept for backward compatibility."""

    def __init__(self):
        print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.")
        super().__init__()
| @@ -13,69 +13,3 @@ class BaseModel(torch.nn.Module): | |||
| def fit(self, train_data, dev_data=None, **train_args): | |||
| trainer = Trainer(**train_args) | |||
| trainer.train(self, train_data, dev_data) | |||
class Vocabulary(object):
    """A look-up table that allows you to access `Lexeme` objects. The `Vocab`
    instance also provides access to the `StringStore`, and owns underlying
    data that is shared between `Doc` objects.
    """

    def __init__(self):
        """Create the vocabulary.

        RETURNS (Vocab): The newly constructed object.
        """
        # Backing storage starts out unset; nothing is loaded at construction.
        self.data_frame = None
class Document(object):
    """A sequence of Token objects. Access sentences and named entities, export
    annotations to numpy arrays, losslessly serialize to compressed binary
    strings. The `Doc` object holds an array of `Token` objects. The
    Python-level `Token` and `Span` objects are views of this array, i.e.
    they don't own the data themselves. -- spacy
    """

    def __init__(self, vocab, words=None, spaces=None):
        """Create a Doc object.

        vocab (Vocab): A vocabulary object, which must match any models you
            want to use (e.g. tokenizer, parser, entity recognizer).
        words (list or None): A list of unicode strings, to add to the document
            as words. If `None`, defaults to empty list.
        spaces (list or None): A list of boolean values, of the same length as
            words. True means that the word is followed by a space, False means
            it is not. If `None`, defaults to `[True]*len(words)`
        RETURNS (Doc): The newly constructed object.

        Raises ValueError if `spaces` is given but its length differs from
        `words`.
        """
        self.vocab = vocab
        # Bug fix: the docstring promises that words=None defaults to an empty
        # list, but the old code stored None and crashed on len(self.words).
        self.words = words if words is not None else []
        if spaces is None:
            self.spaces = [True] * len(self.words)
        elif len(spaces) != len(self.words):
            raise ValueError("mismatch between spaces and words")
        else:
            self.spaces = spaces

    def get_chunker(self, vocab):
        # Placeholder: chunking is not implemented yet.
        return None

    def push_back(self, vocab):
        # Placeholder: intentionally a no-op.
        pass
class Token(object):
    """An individual token - i.e. a word, punctuation symbol, whitespace,
    etc.
    """

    def __init__(self, vocab, doc, offset):
        """Construct a `Token` object.

        vocab (Vocabulary): A storage container for lexical types.
        doc (Document): The parent document.
        offset (int): The index of the token within the document.
        """
        self.vocab = vocab
        self.doc = doc
        self.i = offset
        # The token payload is looked up in the parent document by position.
        self.token = doc[offset]
| @@ -103,7 +103,7 @@ class CharLM(nn.Module): | |||
| x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1) | |||
| # [num_seq, seq_len, total_num_filters] | |||
| x, hidden = self.lstm(x) | |||
| x = self.lstm(x) | |||
| # [seq_len, num_seq, hidden_size] | |||
| x = self.dropout(x) | |||
| @@ -1,12 +1,14 @@ | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from torch import nn | |||
| # from torch.nn.init import xavier_uniform | |||
| from fastNLP.modules.utils import initial_parameter | |||
| # from torch.nn.init import xavier_uniform | |||
| class ConvCharEmbedding(nn.Module): | |||
| def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5),initial_method = None): | |||
| def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None): | |||
| """ | |||
| Character Level Word Embedding | |||
| :param char_emb_size: the size of character level embedding. Default: 50 | |||
| @@ -21,7 +23,7 @@ class ConvCharEmbedding(nn.Module): | |||
| nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4)) | |||
| for i in range(len(kernels))]) | |||
| initial_parameter(self,initial_method) | |||
| initial_parameter(self, initial_method) | |||
| def forward(self, x): | |||
| """ | |||
| @@ -56,7 +58,7 @@ class LSTMCharEmbedding(nn.Module): | |||
| :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size. | |||
| """ | |||
| def __init__(self, char_emb_size=50, hidden_size=None , initial_method= None): | |||
| def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None): | |||
| super(LSTMCharEmbedding, self).__init__() | |||
| self.hidden_size = char_emb_size if hidden_size is None else hidden_size | |||
| @@ -66,6 +68,7 @@ class LSTMCharEmbedding(nn.Module): | |||
| bias=True, | |||
| batch_first=True) | |||
| initial_parameter(self, initial_method) | |||
| def forward(self, x): | |||
| """ | |||
| :param x:[ n_batch*n_word, word_length, char_emb_size] | |||
| @@ -79,20 +82,3 @@ class LSTMCharEmbedding(nn.Module): | |||
| _, hidden = self.lstm(x, (h0, c0)) | |||
| return hidden[0].squeeze().unsqueeze(2) | |||
| if __name__ == "__main__": | |||
| batch_size = 128 | |||
| char_emb = 100 | |||
| word_length = 1 | |||
| x = torch.Tensor(batch_size, char_emb, word_length) | |||
| x = x.transpose(1, 2) | |||
| cce = ConvCharEmbedding(char_emb) | |||
| y = cce(x) | |||
| print("CNN Char Emb input: ", x.shape) | |||
| print("CNN Char Emb output: ", y.shape) # [128, 100] | |||
| lce = LSTMCharEmbedding(char_emb) | |||
| o = lce(x) | |||
| print("LSTM Char Emb input: ", x.shape) | |||
| print("LSTM Char Emb size: ", o.shape) | |||
| @@ -1,24 +1,8 @@ | |||
| from fastNLP.core.loss import Loss | |||
| from fastNLP.core.preprocess import Preprocessor | |||
| from fastNLP.core.trainer import Trainer | |||
| from fastNLP.loader.dataset_loader import LMDataSetLoader | |||
| from fastNLP.models.char_language_model import CharLM | |||
| PICKLE = "./save/" | |||
def train():
    """Train a character language model through the deprecated Preprocessor pipeline."""
    train_data = LMDataSetLoader().load()

    preprocessor = Preprocessor(label_is_seq=True, share_vocab=True)
    train_set = preprocessor.run(train_data, pickle_path=PICKLE)

    model = CharLM(50, 50, preprocessor.vocab_size, preprocessor.char_vocab_size)

    trainer = Trainer(task="language_model", loss=Loss("cross_entropy"))
    trainer.train(model, train_set)
| if __name__ == "__main__": | |||
| @@ -1,72 +0,0 @@ | |||
| import os | |||
| import unittest | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.preprocess import SeqLabelPreprocess | |||
# Toy sequence-labelling corpus: each sample is [token list, tag list].
# The two distinct sentences are repeated five times each so that the
# train/dev split and the cross-validation folds are never empty.
data = [
    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
    [['Hello', 'world', '!'], ['a', 'n', '.']],
    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
    [['Hello', 'world', '!'], ['a', 'n', '.']],
    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
    [['Hello', 'world', '!'], ['a', 'n', '.']],
    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
    [['Hello', 'world', '!'], ['a', 'n', '.']],
    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
    [['Hello', 'world', '!'], ['a', 'n', '.']],
]
class TestCase1(unittest.TestCase):
    def test(self):
        """Train/dev split without test data: run() returns (train, dev) DataSets."""
        if os.path.exists("./save"):
            # Empty the pickle cache so stale files cannot be reused.
            for root, dirs, files in os.walk("./save", topdown=False):
                for file_name in files:
                    os.remove(os.path.join(root, file_name))
                for dir_name in dirs:
                    os.rmdir(os.path.join(root, dir_name))

        result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
                                          pickle_path="./save")
        self.assertEqual(len(result), 2)
        for data_set in result:
            self.assertEqual(type(data_set), DataSet)

        os.system("rm -rf save")
        print("pickle path deleted")
class TestCase2(unittest.TestCase):
    def test(self):
        """Train/dev split plus test data: run() returns (train, dev, test) DataSets."""
        if os.path.exists("./save"):
            # Wipe any cached pickles left behind by a previous run.
            for root, dirs, files in os.walk("./save", topdown=False):
                for file_name in files:
                    os.remove(os.path.join(root, file_name))
                for dir_name in dirs:
                    os.rmdir(os.path.join(root, dir_name))

        result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
                                          pickle_path="./save", train_dev_split=0.4,
                                          cross_val=False)
        self.assertEqual(len(result), 3)
        for data_set in result:
            self.assertEqual(type(data_set), DataSet)

        os.system("rm -rf save")
        print("pickle path deleted")
class TestCase3(unittest.TestCase):
    def test(self):
        """Cross validation: each returned entry is a list of per-fold DataSets."""
        num_folds = 2
        result = SeqLabelPreprocess().run(test_data=None, train_dev_data=data,
                                          pickle_path="./save", train_dev_split=0.4,
                                          cross_val=True, n_fold=num_folds)
        self.assertEqual(len(result), 2)
        train_folds, dev_folds = result
        self.assertEqual(len(train_folds), num_folds)
        self.assertEqual(len(dev_folds), num_folds)
        for data_set in train_folds + dev_folds:
            self.assertEqual(type(data_set), DataSet)

        os.system("rm -rf save")
        print("pickle path deleted")
| @@ -0,0 +1,25 @@ | |||
| import unittest | |||
| import numpy as np | |||
| import torch | |||
| from fastNLP.models.char_language_model import CharLM | |||
class TestCharLM(unittest.TestCase):
    def test_case_1(self):
        """CharLM maps (num_seq, seq_len, word_len) char ids to per-token logits."""
        char_emb_dim, word_emb_dim = 50, 50
        vocab_size, num_char = 1000, 24
        max_word_len = 21
        num_seq, seq_len = 64, 32

        model = CharLM(char_emb_dim, word_emb_dim, vocab_size, num_char)

        batch = torch.from_numpy(
            np.random.randint(0, num_char, size=(num_seq, seq_len, max_word_len + 2)))
        self.assertEqual(tuple(batch.shape), (num_seq, seq_len, max_word_len + 2))

        logits = model(batch)
        self.assertEqual(tuple(logits.shape), (num_seq * seq_len, vocab_size))
| @@ -0,0 +1,28 @@ | |||
| import unittest | |||
| import torch | |||
| from fastNLP.modules.encoder.char_embedding import ConvCharEmbedding, LSTMCharEmbedding | |||
class TestCharEmbed(unittest.TestCase):
    def test_case_1(self):
        """Both char-embedding encoders emit (batch, emb_dim, 1) for 1-char words."""
        n_batch, emb_dim, n_chars = 128, 100, 1
        inputs = torch.Tensor(n_batch, emb_dim, n_chars).transpose(1, 2)

        conv_emb = ConvCharEmbedding(emb_dim)
        conv_out = conv_emb(inputs)
        self.assertEqual(tuple(inputs.shape), (n_batch, n_chars, emb_dim))
        print("CNN Char Emb input: ", inputs.shape)
        self.assertEqual(tuple(conv_out.shape), (n_batch, emb_dim, 1))
        print("CNN Char Emb output: ", conv_out.shape)  # [128, 100]

        lstm_emb = LSTMCharEmbedding(emb_dim)
        lstm_out = lstm_emb(inputs)
        self.assertEqual(tuple(inputs.shape), (n_batch, n_chars, emb_dim))
        print("LSTM Char Emb input: ", inputs.shape)
        self.assertEqual(tuple(lstm_out.shape), (n_batch, emb_dim, 1))
        print("LSTM Char Emb size: ", lstm_out.shape)
| @@ -1,9 +1,11 @@ | |||
| import unittest | |||
| import numpy as np | |||
| import torch | |||
| import unittest | |||
| from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM | |||
| class TestMaskedRnn(unittest.TestCase): | |||
| def test_case_1(self): | |||
| masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) | |||
| @@ -16,13 +18,20 @@ class TestMaskedRnn(unittest.TestCase): | |||
| y = masked_rnn(x, mask=mask) | |||
    def test_case_2(self):
        """Unidirectional VarMaskedFastLSTM: forward with/without mask, plus step().

        NOTE(review): this body looks like two merged diff revisions of the
        same test (a tiny hand-built batch followed by a random-batch version);
        both halves are kept as-is.
        """
        # Tiny hand-built batch: 1 sequence of length 2, feature size 1.
        masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=False, batch_first=True)
        x = torch.tensor([[[1.0], [2.0]]])
        print(x.size())
        y = masked_rnn(x)
        # Same forward pass with an all-ones mask (no positions dropped).
        mask = torch.tensor([[[1], [1]]])
        y = masked_rnn(x, mask=mask)
        xx = torch.tensor([[[1.0]]])
        #y, hidden = masked_rnn.step(xx)
        #step() still has a bug
        #y, hidden = masked_rnn.step(xx, mask=mask)
        # Random batch: check output shapes of step() and full forward passes.
        input_size = 12
        batch = 16
        hidden = 10
        masked_rnn = VarMaskedFastLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)
        # step() consumes a single timestep of shape (batch, input_size).
        x = torch.randn((batch, input_size))
        output, _ = masked_rnn.step(x)
        self.assertEqual(tuple(output.shape), (batch, hidden))
        # Full forward over a 32-step sequence, without a mask...
        xx = torch.randn((batch, 32, input_size))
        y, _ = masked_rnn(xx)
        self.assertEqual(tuple(y.shape), (batch, 32, hidden))
        # ...and with a random binary mask cast to the input's dtype/device.
        xx = torch.randn((batch, 32, input_size))
        mask = torch.from_numpy(np.random.randint(0, 2, size=(batch, 32))).to(xx)
        y, _ = masked_rnn(xx, mask=mask)
        self.assertEqual(tuple(y.shape), (batch, 32, hidden))