You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_char_embedding.py 1.1 kB

123456789101112131415161718192021222324252627282930
  1. import pytest
  2. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  3. if _NEED_IMPORT_TORCH:
  4. import torch
  5. from fastNLP import Vocabulary, DataSet, Instance
  6. from fastNLP.embeddings.torch.char_embedding import LSTMCharEmbedding, CNNCharEmbedding
  7. @pytest.mark.torch
  8. class TestCharEmbed:
  9. # @pytest.mark.test
  10. def test_case_1(self):
  11. ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
  12. vocab = Vocabulary().from_dataset(ds, field_name='words')
  13. assert len(vocab)==5
  14. embed = LSTMCharEmbedding(vocab, embed_size=3)
  15. x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
  16. y = embed(x)
  17. assert tuple(y.size()) == (2, 3, 3)
  18. # @pytest.mark.test
  19. def test_case_2(self):
  20. ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
  21. vocab = Vocabulary().from_dataset(ds, field_name='words')
  22. assert len(vocab)==5
  23. embed = CNNCharEmbedding(vocab, embed_size=3)
  24. x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
  25. y = embed(x)
  26. assert tuple(y.size()) == (2, 3, 3)