You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_sequence_labeling.py 1.2 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536
  1. import unittest
  2. from .model_runner import *
  3. from fastNLP.models.sequence_labeling import SeqLabeling, AdvSeqLabel
  4. from fastNLP.core.losses import LossInForward
  5. class TesSeqLabel(unittest.TestCase):
  6. def test_case1(self):
  7. # 测试能否正常运行CNN
  8. init_emb = (VOCAB_SIZE, 30)
  9. model = SeqLabeling(init_emb,
  10. hidden_size=30,
  11. num_classes=NUM_CLS)
  12. data = RUNNER.prepare_pos_tagging_data()
  13. data.set_input('target')
  14. loss = LossInForward()
  15. metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN)
  16. RUNNER.run_model(model, data, loss, metric)
  17. class TesAdvSeqLabel(unittest.TestCase):
  18. def test_case1(self):
  19. # 测试能否正常运行CNN
  20. init_emb = (VOCAB_SIZE, 30)
  21. model = AdvSeqLabel(init_emb,
  22. hidden_size=30,
  23. num_classes=NUM_CLS)
  24. data = RUNNER.prepare_pos_tagging_data()
  25. data.set_input('target')
  26. loss = LossInForward()
  27. metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN)
  28. RUNNER.run_model(model, data, loss, metric)