You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 3.9 kB

7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import sys
  2. sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])
  3. import torch
  4. import argparse
  5. from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
  6. from fastNLP.core.dataset import DataSet
  7. from fastNLP.core.instance import Instance
  8. parser = argparse.ArgumentParser()
  9. parser.add_argument('--pipe', type=str, default='')
  10. parser.add_argument('--gold_data', type=str, default='')
  11. parser.add_argument('--new_data', type=str)
  12. args = parser.parse_args()
  13. pipe = torch.load(args.pipe)['pipeline']
  14. for p in pipe:
  15. if p.field_name == 'word_list':
  16. print(p.field_name)
  17. p.field_name = 'gold_words'
  18. elif p.field_name == 'pos_list':
  19. print(p.field_name)
  20. p.field_name = 'gold_pos'
  21. data = ConllxDataLoader().load(args.gold_data)
  22. ds = DataSet()
  23. for ins1, ins2 in zip(add_seg_tag(data), data):
  24. ds.append(Instance(words=ins1[0], tag=ins1[1],
  25. gold_words=ins2[0], gold_pos=ins2[1],
  26. gold_heads=ins2[2], gold_head_tags=ins2[3]))
  27. ds = pipe(ds)
  28. seg_threshold = 0.
  29. pos_threshold = 0.
  30. parse_threshold = 0.74
  31. def get_heads(ins, head_f, word_f):
  32. head_pred = []
  33. for i, idx in enumerate(ins[head_f]):
  34. j = idx - 1 if idx != 0 else i
  35. head_pred.append(ins[word_f][j])
  36. return head_pred
  37. def evaluate(ins):
  38. seg_count = sum([1 for i, j in zip(ins['word_list'], ins['gold_words']) if i == j])
  39. pos_count = sum([1 for i, j in zip(ins['pos_list'], ins['gold_pos']) if i == j])
  40. head_count = sum([1 for i, j in zip(ins['heads'], ins['gold_heads']) if i == j])
  41. total = len(ins['gold_words'])
  42. return seg_count / total, pos_count / total, head_count / total
  43. def is_ok(x):
  44. seg, pos, head = x[1]
  45. return seg > seg_threshold and pos > pos_threshold and head > parse_threshold
  46. res_list = []
  47. for i, ins in enumerate(ds):
  48. res_list.append((i, evaluate(ins)))
  49. res_list = list(filter(is_ok, res_list))
  50. print('{} {}'.format(len(ds), len(res_list)))
  51. seg_cor, pos_cor, head_cor, label_cor, total = 0,0,0,0,0
  52. for i, _ in res_list:
  53. ins = ds[i]
  54. # print(i)
  55. # print('gold_words:\t', ins['gold_words'])
  56. # print('predict_words:\t', ins['word_list'])
  57. # print('gold_tag:\t', ins['gold_pos'])
  58. # print('predict_tag:\t', ins['pos_list'])
  59. # print('gold_heads:\t', ins['gold_heads'])
  60. # print('predict_heads:\t', ins['heads'].tolist())
  61. # print('gold_head_tags:\t', ins['gold_head_tags'])
  62. # print('predict_labels:\t', ins['labels'])
  63. # print()
  64. head_pred = ins['heads']
  65. head_gold = ins['gold_heads']
  66. label_pred = ins['labels']
  67. label_gold = ins['gold_head_tags']
  68. total += len(head_gold)
  69. seg_cor += sum([1 for i, j in zip(ins['word_list'], ins['gold_words']) if i == j])
  70. pos_cor += sum([1 for i, j in zip(ins['pos_list'], ins['gold_pos']) if i == j])
  71. length = len(head_gold)
  72. for i in range(length):
  73. head_cor += 1 if head_pred[i] == head_gold[i] else 0
  74. label_cor += 1 if head_pred[i] == head_gold[i] and label_gold[i] == label_pred[i] else 0
  75. print('SEG: {}, POS: {}, UAS: {}, LAS: {}'.format(seg_cor/total, pos_cor/total, head_cor/total, label_cor/total))
  76. colln_path = args.gold_data
  77. new_colln_path = args.new_data
  78. index_list = [x[0] for x in res_list]
  79. with open(colln_path, 'r', encoding='utf-8') as f1, \
  80. open(new_colln_path, 'w', encoding='utf-8') as f2:
  81. for idx, ins in enumerate(ds):
  82. if idx in index_list:
  83. length = len(ins['gold_words'])
  84. pad = ['_' for _ in range(length)]
  85. for x in zip(
  86. map(str, range(1, length+1)), ins['gold_words'], ins['gold_words'], ins['gold_pos'],
  87. pad, pad, map(str, ins['gold_heads']), ins['gold_head_tags']):
  88. new_lines = '\t'.join(x)
  89. f2.write(new_lines)
  90. f2.write('\n')
  91. f2.write('\n')