You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

metrics.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import numpy as np
  2. import json
  3. from os.path import join
  4. import torch
  5. import logging
  6. import tempfile
  7. import subprocess as sp
  8. from datetime import timedelta
  9. from time import time
  10. from pyrouge import Rouge155
  11. from pyrouge.utils import log
  12. from fastNLP.core.losses import LossBase
  13. from fastNLP.core.metrics import MetricBase
  14. _ROUGE_PATH = '/path/to/RELEASE-1.5.5'
  15. class MyBCELoss(LossBase):
  16. def __init__(self, pred=None, target=None, mask=None):
  17. super(MyBCELoss, self).__init__()
  18. self._init_param_map(pred=pred, target=target, mask=mask)
  19. self.loss_func = torch.nn.BCELoss(reduction='none')
  20. def get_loss(self, pred, target, mask):
  21. loss = self.loss_func(pred, target.float())
  22. loss = (loss * mask.float()).sum()
  23. return loss
  24. class LossMetric(MetricBase):
  25. def __init__(self, pred=None, target=None, mask=None):
  26. super(LossMetric, self).__init__()
  27. self._init_param_map(pred=pred, target=target, mask=mask)
  28. self.loss_func = torch.nn.BCELoss(reduction='none')
  29. self.avg_loss = 0.0
  30. self.nsamples = 0
  31. def evaluate(self, pred, target, mask):
  32. batch_size = pred.size(0)
  33. loss = self.loss_func(pred, target.float())
  34. loss = (loss * mask.float()).sum()
  35. self.avg_loss += loss
  36. self.nsamples += batch_size
  37. def get_metric(self, reset=True):
  38. self.avg_loss = self.avg_loss / self.nsamples
  39. eval_result = {'loss': self.avg_loss}
  40. if reset:
  41. self.avg_loss = 0
  42. self.nsamples = 0
  43. return eval_result
  44. class RougeMetric(MetricBase):
  45. def __init__(self, data_path, dec_path, ref_path, n_total, n_ext=3, ngram_block=3, pred=None, target=None, mask=None):
  46. super(RougeMetric, self).__init__()
  47. self._init_param_map(pred=pred, target=target, mask=mask)
  48. self.data_path = data_path
  49. self.dec_path = dec_path
  50. self.ref_path = ref_path
  51. self.n_total = n_total
  52. self.n_ext = n_ext
  53. self.ngram_block = ngram_block
  54. self.cur_idx = 0
  55. self.ext = []
  56. self.start = time()
  57. @staticmethod
  58. def eval_rouge(dec_dir, ref_dir):
  59. assert _ROUGE_PATH is not None
  60. log.get_global_console_logger().setLevel(logging.WARNING)
  61. dec_pattern = '(\d+).dec'
  62. ref_pattern = '#ID#.ref'
  63. cmd = '-c 95 -r 1000 -n 2 -m'
  64. with tempfile.TemporaryDirectory() as tmp_dir:
  65. Rouge155.convert_summaries_to_rouge_format(
  66. dec_dir, join(tmp_dir, 'dec'))
  67. Rouge155.convert_summaries_to_rouge_format(
  68. ref_dir, join(tmp_dir, 'ref'))
  69. Rouge155.write_config_static(
  70. join(tmp_dir, 'dec'), dec_pattern,
  71. join(tmp_dir, 'ref'), ref_pattern,
  72. join(tmp_dir, 'settings.xml'), system_id=1
  73. )
  74. cmd = (join(_ROUGE_PATH, 'ROUGE-1.5.5.pl')
  75. + ' -e {} '.format(join(_ROUGE_PATH, 'data'))
  76. + cmd
  77. + ' -a {}'.format(join(tmp_dir, 'settings.xml')))
  78. output = sp.check_output(cmd.split(' '), universal_newlines=True)
  79. R_1 = float(output.split('\n')[3].split(' ')[3])
  80. R_2 = float(output.split('\n')[7].split(' ')[3])
  81. R_L = float(output.split('\n')[11].split(' ')[3])
  82. print(output)
  83. return R_1, R_2, R_L
  84. def evaluate(self, pred, target, mask):
  85. pred = pred + mask.float()
  86. pred = pred.cpu().data.numpy()
  87. ext_ids = np.argsort(-pred, 1)
  88. for sent_id in ext_ids:
  89. self.ext.append(sent_id)
  90. self.cur_idx += 1
  91. print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
  92. self.cur_idx, self.n_total, self.cur_idx/self.n_total*100, timedelta(seconds=int(time()-self.start))
  93. ), end='')
  94. def get_metric(self, use_ngram_block=True, reset=True):
  95. def check_n_gram(sentence, n, dic):
  96. tokens = sentence.split(' ')
  97. s_len = len(tokens)
  98. for i in range(s_len):
  99. if i + n > s_len:
  100. break
  101. if ' '.join(tokens[i: i + n]) in dic:
  102. return False
  103. return True # no n_gram overlap
  104. # load original data
  105. data = []
  106. with open(self.data_path) as f:
  107. for line in f:
  108. cur_data = json.loads(line)
  109. if 'text' in cur_data:
  110. new_data = {}
  111. new_data['article'] = cur_data['text']
  112. new_data['abstract'] = cur_data['summary']
  113. data.append(new_data)
  114. else:
  115. data.append(cur_data)
  116. # write decode sentences and references
  117. if use_ngram_block == True:
  118. print('\nStart {}-gram blocking !!!'.format(self.ngram_block))
  119. for i, ext_ids in enumerate(self.ext):
  120. dec, ref = [], []
  121. if use_ngram_block == False:
  122. n_sent = min(len(data[i]['article']), self.n_ext)
  123. for j in range(n_sent):
  124. idx = ext_ids[j]
  125. dec.append(data[i]['article'][idx])
  126. else:
  127. n_sent = len(ext_ids)
  128. dic = {}
  129. for j in range(n_sent):
  130. sent = data[i]['article'][ext_ids[j]]
  131. if check_n_gram(sent, self.ngram_block, dic) == True:
  132. dec.append(sent)
  133. # update dic
  134. tokens = sent.split(' ')
  135. s_len = len(tokens)
  136. for k in range(s_len):
  137. if k + self.ngram_block > s_len:
  138. break
  139. dic[' '.join(tokens[k: k + self.ngram_block])] = 1
  140. if len(dec) >= self.n_ext:
  141. break
  142. for sent in data[i]['abstract']:
  143. ref.append(sent)
  144. with open(join(self.dec_path, '{}.dec'.format(i)), 'w') as f:
  145. for sent in dec:
  146. print(sent, file=f)
  147. with open(join(self.ref_path, '{}.ref'.format(i)), 'w') as f:
  148. for sent in ref:
  149. print(sent, file=f)
  150. print('\nStart evaluating ROUGE score !!!')
  151. R_1, R_2, R_L = RougeMetric.eval_rouge(self.dec_path, self.ref_path)
  152. eval_result = {'ROUGE-1': R_1, 'ROUGE-2': R_2, 'ROUGE-L':R_L}
  153. if reset == True:
  154. self.cur_idx = 0
  155. self.ext = []
  156. self.start = time()
  157. return eval_result