You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_biaffine_parser.py 4.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric
  2. import fastNLP
  3. import unittest
  4. data_file = """
  5. 1 The _ DET DT _ 3 det _ _
  6. 2 new _ ADJ JJ _ 3 amod _ _
  7. 3 rate _ NOUN NN _ 6 nsubj _ _
  8. 4 will _ AUX MD _ 6 aux _ _
  9. 5 be _ VERB VB _ 6 cop _ _
  10. 6 payable _ ADJ JJ _ 0 root _ _
  11. 9 cents _ NOUN NNS _ 4 nmod _ _
  12. 10 from _ ADP IN _ 12 case _ _
  13. 11 seven _ NUM CD _ 12 nummod _ _
  14. 12 cents _ NOUN NNS _ 4 nmod _ _
  15. 13 a _ DET DT _ 14 det _ _
  16. 14 share _ NOUN NN _ 12 nmod:npmod _ _
  17. 15 . _ PUNCT . _ 4 punct _ _
  18. 1 The _ DET DT _ 3 det _ _
  19. 2 new _ ADJ JJ _ 3 amod _ _
  20. 3 rate _ NOUN NN _ 6 nsubj _ _
  21. 4 will _ AUX MD _ 6 aux _ _
  22. 5 be _ VERB VB _ 6 cop _ _
  23. 6 payable _ ADJ JJ _ 0 root _ _
  24. 7 Feb. _ PROPN NNP _ 6 nmod:tmod _ _
  25. 8 15 _ NUM CD _ 7 nummod _ _
  26. 9 . _ PUNCT . _ 6 punct _ _
  27. 1 A _ DET DT _ 3 det _ _
  28. 2 record _ NOUN NN _ 3 compound _ _
  29. 3 date _ NOUN NN _ 7 nsubjpass _ _
  30. 4 has _ AUX VBZ _ 7 aux _ _
  31. 5 n't _ PART RB _ 7 neg _ _
  32. 6 been _ AUX VBN _ 7 auxpass _ _
  33. 7 set _ VERB VBN _ 0 root _ _
  34. 8 . _ PUNCT . _ 7 punct _ _
  35. """
  36. def init_data():
  37. ds = fastNLP.DataSet()
  38. v = {'word_seq': fastNLP.Vocabulary(),
  39. 'pos_seq': fastNLP.Vocabulary(),
  40. 'label_true': fastNLP.Vocabulary()}
  41. data = []
  42. for line in data_file.split('\n'):
  43. line = line.split()
  44. if len(line) == 0 and len(data) > 0:
  45. data = list(zip(*data))
  46. ds.append(fastNLP.Instance(word_seq=data[1],
  47. pos_seq=data[4],
  48. arc_true=data[6],
  49. label_true=data[7]))
  50. data = []
  51. elif len(line) > 0:
  52. data.append(line)
  53. for name in ['word_seq', 'pos_seq', 'label_true']:
  54. ds.apply(lambda x: ['<st>']+list(x[name])+['<ed>'], new_field_name=name)
  55. ds.apply(lambda x: v[name].add_word_lst(x[name]))
  56. for name in ['word_seq', 'pos_seq', 'label_true']:
  57. ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name)
  58. ds.apply(lambda x: [0]+list(map(int, x['arc_true']))+[1], new_field_name='arc_true')
  59. ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens')
  60. ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True)
  61. ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True)
  62. return ds, v['word_seq'], v['pos_seq'], v['label_true']
  63. class TestBiaffineParser(unittest.TestCase):
  64. def test_train(self):
  65. ds, v1, v2, v3 = init_data()
  66. model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30,
  67. pos_vocab_size=len(v2), pos_emb_dim=30,
  68. num_label=len(v3))
  69. trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
  70. loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
  71. n_epochs=10, use_cuda=False, use_tqdm=False)
  72. trainer.train(load_best_model=False)