import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch

from fastNLP import Vocabulary, DataSet, Instance
from fastNLP.embeddings.torch.char_embedding import LSTMCharEmbedding, CNNCharEmbedding


@pytest.mark.torch
class TestCharEmbed:
    """Smoke tests for the character-level word embeddings.

    Each test builds a tiny vocabulary from a two-instance DataSet, wraps it
    in a char embedding, and checks that a batch of word indices is mapped to
    one ``embed_size``-dimensional vector per word.
    """

    @staticmethod
    def _check_output_shape(embed_cls):
        """Build ``embed_cls`` over a toy vocabulary and verify its output shape.

        Args:
            embed_cls: char-embedding class constructed as
                ``embed_cls(vocab, embed_size=...)``.
        """
        ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
        vocab = Vocabulary().from_dataset(ds, field_name='words')
        # Three distinct words plus, presumably, Vocabulary's default padding
        # and unknown entries -- TODO confirm against Vocabulary defaults.
        assert len(vocab) == 5

        embed_size = 3
        embed = embed_cls(vocab, embed_size=embed_size)
        # Batch of 2 sequences with 3 word indices each; every index is below
        # len(vocab) so all of them are valid lookups.
        x = torch.tensor([[2, 1, 0], [4, 3, 4]], dtype=torch.long)
        y = embed(x)
        # One embed_size vector per input index: (batch, seq_len, embed_size).
        assert tuple(y.size()) == (*x.shape, embed_size)

    def test_case_1(self):
        """LSTM-based char embedding yields one vector per input word."""
        self._check_output_shape(LSTMCharEmbedding)

    def test_case_2(self):
        """CNN-based char embedding yields one vector per input word."""
        self._check_output_shape(CNNCharEmbedding)