You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_static_embedding.py 606 B

  1. import unittest
  2. from fastNLP.embeddings import StaticEmbedding
  3. from fastNLP import Vocabulary
  4. import torch
  5. class TestRandomSameEntry(unittest.TestCase):
  6. def test_same_vector(self):
  7. vocab = Vocabulary().add_word_lst(["The", "the", "THE"])
  8. embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
  9. words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]])
  10. words = embed(words)
  11. embed_0 = words[0, 0]
  12. for i in range(1, words.size(1)):
  13. assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))