You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_char_language_model.py 692 B

12345678910111213141516171819202122232425
  1. import unittest
  2. import numpy as np
  3. import torch
  4. from fastNLP.models.char_language_model import CharLM
  5. class TestCharLM(unittest.TestCase):
  6. def test_case_1(self):
  7. char_emb_dim = 50
  8. word_emb_dim = 50
  9. vocab_size = 1000
  10. num_char = 24
  11. max_word_len = 21
  12. num_seq = 64
  13. seq_len = 32
  14. model = CharLM(char_emb_dim, word_emb_dim, vocab_size, num_char)
  15. x = torch.from_numpy(np.random.randint(0, num_char, size=(num_seq, seq_len, max_word_len + 2)))
  16. self.assertEqual(tuple(x.shape), (num_seq, seq_len, max_word_len + 2))
  17. y = model(x)
  18. self.assertEqual(tuple(y.shape), (num_seq * seq_len, vocab_size))