
test_stack_embeddings.py 1.4 kB

import unittest

import torch

from fastNLP import Vocabulary, DataSet, Instance
from fastNLP.embeddings import LSTMCharEmbedding, CNNCharEmbedding, StackEmbedding


class TestCharEmbed(unittest.TestCase):
    def test_case_1(self):
        ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['hello', 'Jack'])])
        vocab = Vocabulary().from_dataset(ds, field_name='words')
        self.assertEqual(len(vocab), 5)
        cnn_embed = CNNCharEmbedding(vocab, embed_size=60)
        lstm_embed = LSTMCharEmbedding(vocab, embed_size=70)
        embed = StackEmbedding([cnn_embed, lstm_embed])
        x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
        y = embed(x)
        # the stacked output width is the sum of the sub-embedding sizes: 60 + 70 = 130
        self.assertEqual(tuple(y.size()), (2, 3, 130))

    def test_case_2(self):
        # test that two embeddings only need to share the same token-to-index mapping to be concatenated
        ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['hello', 'Jack'])])
        vocab1 = Vocabulary().from_dataset(ds, field_name='words')
        vocab2 = Vocabulary().from_dataset(ds, field_name='words')
        self.assertEqual(len(vocab1), 5)
        cnn_embed = CNNCharEmbedding(vocab1, embed_size=60)
        lstm_embed = LSTMCharEmbedding(vocab2, embed_size=70)
        embed = StackEmbedding([cnn_embed, lstm_embed])
        x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
        y = embed(x)
        self.assertEqual(tuple(y.size()), (2, 3, 130))
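Both cases build a StackEmbedding from a CNNCharEmbedding and an LSTMCharEmbedding over the same (or identically indexed) vocabulary and check that the stacked output width is the sum of the individual embed_size values. A minimal entry point such as the sketch below (an addition, not part of the original file) would let the module be executed directly with Python instead of only through a test runner; it relies on nothing beyond the standard unittest API.

if __name__ == '__main__':
    # Run both test cases when this file is executed directly,
    # e.g. `python test_stack_embeddings.py`.
    unittest.main()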