You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_char_embedding.py 1.0 kB

1234567891011121314151617181920212223242526
  1. import unittest
  2. import torch
  3. from fastNLP import Vocabulary, DataSet, Instance
  4. from fastNLP.embeddings.char_embedding import LSTMCharEmbedding, CNNCharEmbedding
  5. class TestCharEmbed(unittest.TestCase):
  6. def test_case_1(self):
  7. ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
  8. vocab = Vocabulary().from_dataset(ds, field_name='words')
  9. self.assertEqual(len(vocab), 5)
  10. embed = LSTMCharEmbedding(vocab, embed_size=60)
  11. x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
  12. y = embed(x)
  13. self.assertEqual(tuple(y.size()), (2, 3, 60))
  14. def test_case_2(self):
  15. ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
  16. vocab = Vocabulary().from_dataset(ds, field_name='words')
  17. self.assertEqual(len(vocab), 5)
  18. embed = CNNCharEmbedding(vocab, embed_size=60)
  19. x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
  20. y = embed(x)
  21. self.assertEqual(tuple(y.size()), (2, 3, 60))