import unittest import torch from fastNLP import Vocabulary, DataSet, Instance from fastNLP.embeddings.char_embedding import LSTMCharEmbedding, CNNCharEmbedding class TestCharEmbed(unittest.TestCase): def test_case_1(self): ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])]) vocab = Vocabulary().from_dataset(ds, field_name='words') self.assertEqual(len(vocab), 5) embed = LSTMCharEmbedding(vocab, embed_size=60) x = torch.LongTensor([[2, 1, 0], [4, 3, 4]]) y = embed(x) self.assertEqual(tuple(y.size()), (2, 3, 60)) def test_case_2(self): ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])]) vocab = Vocabulary().from_dataset(ds, field_name='words') self.assertEqual(len(vocab), 5) embed = CNNCharEmbedding(vocab, embed_size=60) x = torch.LongTensor([[2, 1, 0], [4, 3, 4]]) y = embed(x) self.assertEqual(tuple(y.size()), (2, 3, 60))