import os
import unittest

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader


class TestEmbedLoader(unittest.TestCase):
    """Tests for ``EmbedLoader.load_embedding`` against a small GloVe fixture."""

    # Small GloVe fixture; the file name indicates 50-dimensional vectors.
    glove_path = './test/data_for_tests/glove.6B.50d_test.txt'
    # Path handed to load_embedding as its pickle argument; test1 expects
    # the call to create a file here. It is removed before and after each test.
    pkl_path = './save'
    raw_texts = ["i am a cat",
                 "this is a test of new batch",
                 "ha ha",
                 "I am a good boy .",
                 "This is the most beautiful girl ."
                 ]
    texts = [text.strip().split() for text in raw_texts]
    # Vocabulary built once at class-definition time and shared by all tests.
    vocab = Vocabulary()
    vocab.update(texts)

    def _remove_pkl(self):
        """Delete the file at ``pkl_path`` if it exists; a missing file is fine."""
        try:
            os.remove(self.pkl_path)
        except FileNotFoundError:
            pass

    def setUp(self):
        # NOTE(review): load_embedding presumably also reads from pkl_path when
        # it already exists, so a file left by an earlier (possibly failed) run
        # could change what a test observes. Start clean, and always clean up,
        # even when an assertion below fails.
        self._remove_pkl()
        self.addCleanup(self._remove_pkl)

    def test1(self):
        """Loading with the fixture's dimension yields a (len(vocab), 50) matrix."""
        emb, _ = EmbedLoader.load_embedding(50, self.glove_path, 'glove', self.vocab, self.pkl_path)
        self.assertEqual(emb.shape[0], len(self.vocab))
        self.assertEqual(emb.shape[1], 50)
        # The original test removed this file unconditionally, which implicitly
        # required load_embedding to have written it; keep that check explicit.
        self.assertTrue(os.path.isfile(self.pkl_path),
                        msg="load_embedding did not write the pickle file")

    def test2(self):
        """Requesting a dimension that does not match the file raises ValueError."""
        with self.assertRaises(ValueError, msg="load mismatched embedding"):
            EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path)