You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_biaffine_parser.py 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import unittest
  2. import fastNLP
  3. from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric
  4. from .model_runner import *
  5. def prepare_parser_data():
  6. index = 'index'
  7. ds = DataSet({index: list(range(N_SAMPLES))})
  8. ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE),
  9. field_name=index, new_field_name=C.INPUTS(0),
  10. is_input=True)
  11. ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
  12. field_name=C.INPUTS(0), new_field_name=C.INPUTS(1),
  13. is_input=True)
  14. # target1 is heads, should in range(0, len(words))
  15. ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)),
  16. field_name=C.INPUTS(0), new_field_name=C.TARGETS(0),
  17. is_target=True)
  18. ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
  19. field_name=C.INPUTS(0), new_field_name=C.TARGETS(1),
  20. is_target=True)
  21. ds.apply_field(len, field_name=C.INPUTS(0), new_field_name=C.INPUT_LEN,
  22. is_input=True, is_target=True)
  23. return ds
  24. class TestBiaffineParser(unittest.TestCase):
  25. def test_train(self):
  26. model = BiaffineParser(init_embed=(VOCAB_SIZE, 10),
  27. pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
  28. rnn_hidden_size=10,
  29. arc_mlp_size=10,
  30. label_mlp_size=10,
  31. num_label=NUM_CLS, encoder='var-lstm')
  32. ds = prepare_parser_data()
  33. RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric())
  34. def test_train2(self):
  35. model = BiaffineParser(init_embed=(VOCAB_SIZE, 10),
  36. pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
  37. rnn_hidden_size=16,
  38. arc_mlp_size=10,
  39. label_mlp_size=10,
  40. num_label=NUM_CLS, encoder='transformer')
  41. ds = prepare_parser_data()
  42. RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric())