You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_stack_embeddings.py 739 B

1234567891011121314151617181920
  1. import unittest
  2. import torch
  3. from fastNLP import Vocabulary, DataSet, Instance
  4. from fastNLP.embeddings import LSTMCharEmbedding, CNNCharEmbedding, StackEmbedding
  5. class TestCharEmbed(unittest.TestCase):
  6. def test_case_1(self):
  7. ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['hello', 'Jack'])])
  8. vocab = Vocabulary().from_dataset(ds, field_name='words')
  9. self.assertEqual(len(vocab), 5)
  10. cnn_embed = CNNCharEmbedding(vocab, embed_size=60)
  11. lstm_embed = LSTMCharEmbedding(vocab, embed_size=70)
  12. embed = StackEmbedding([cnn_embed, lstm_embed])
  13. x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
  14. y = embed(x)
  15. self.assertEqual(tuple(y.size()), (2, 3, 130))