You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 11 kB


  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import re
  4. import os
  5. import shutil
  6. import copy
  7. import datetime
  8. import numpy as np
  9. from rouge import Rouge
  10. from .logger import *
  11. # from data import *
  12. import sys
  13. sys.setrecursionlimit(10000)
  14. REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}",
  15. "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'}
  16. def clean(x):
  17. return re.sub(
  18. r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''",
  19. lambda m: REMAP.get(m.group()), x)
  20. def rouge_eval(hyps, refer):
  21. rouge = Rouge()
  22. # print(hyps)
  23. # print(refer)
  24. # print(rouge.get_scores(hyps, refer))
  25. try:
  26. score = rouge.get_scores(hyps, refer)[0]
  27. mean_score = np.mean([score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]])
  28. except:
  29. mean_score = 0.0
  30. return mean_score
  31. def rouge_all(hyps, refer):
  32. rouge = Rouge()
  33. score = rouge.get_scores(hyps, refer)[0]
  34. # mean_score = np.mean([score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]])
  35. return score
  36. def eval_label(match_true, pred, true, total, match):
  37. match_true, pred, true, match = match_true.float(), pred.float(), true.float(), match.float()
  38. try:
  39. accu = match / total
  40. precision = match_true / pred
  41. recall = match_true / true
  42. F = 2 * precision * recall / (precision + recall)
  43. except ZeroDivisionError:
  44. F = 0.0
  45. logger.error("[Error] float division by zero")
  46. return accu, precision, recall, F
  47. def pyrouge_score(hyps, refer, remap = True):
  48. from pyrouge import Rouge155
  49. nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
  50. PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime)
  51. SYSTEM_PATH = os.path.join(PYROUGE_ROOT,'gold')
  52. MODEL_PATH = os.path.join(PYROUGE_ROOT,'system')
  53. if os.path.exists(SYSTEM_PATH):
  54. shutil.rmtree(SYSTEM_PATH)
  55. os.makedirs(SYSTEM_PATH)
  56. if os.path.exists(MODEL_PATH):
  57. shutil.rmtree(MODEL_PATH)
  58. os.makedirs(MODEL_PATH)
  59. if remap == True:
  60. refer = clean(refer)
  61. hyps = clean(hyps)
  62. system_file = os.path.join(SYSTEM_PATH, 'Reference.0.txt')
  63. model_file = os.path.join(MODEL_PATH, 'Model.A.0.txt')
  64. with open(system_file, 'wb') as f:
  65. f.write(refer.encode('utf-8'))
  66. with open(model_file, 'wb') as f:
  67. f.write(hyps.encode('utf-8'))
  68. r = Rouge155('/home/dqwang/ROUGE/RELEASE-1.5.5')
  69. r.system_dir = SYSTEM_PATH
  70. r.model_dir = MODEL_PATH
  71. r.system_filename_pattern = 'Reference.(\d+).txt'
  72. r.model_filename_pattern = 'Model.[A-Z].#ID#.txt'
  73. output = r.convert_and_evaluate(rouge_args="-e /home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d")
  74. output_dict = r.output_to_dict(output)
  75. shutil.rmtree(PYROUGE_ROOT)
  76. scores = {}
  77. scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {}
  78. scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score']
  79. scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score']
  80. scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score']
  81. return scores
  82. def pyrouge_score_all(hyps_list, refer_list, remap = True):
  83. from pyrouge import Rouge155
  84. nowTime=datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
  85. PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime)
  86. SYSTEM_PATH = os.path.join(PYROUGE_ROOT,'gold')
  87. MODEL_PATH = os.path.join(PYROUGE_ROOT,'system')
  88. if os.path.exists(SYSTEM_PATH):
  89. shutil.rmtree(SYSTEM_PATH)
  90. os.makedirs(SYSTEM_PATH)
  91. if os.path.exists(MODEL_PATH):
  92. shutil.rmtree(MODEL_PATH)
  93. os.makedirs(MODEL_PATH)
  94. assert len(hyps_list) == len(refer_list)
  95. for i in range(len(hyps_list)):
  96. system_file = os.path.join(SYSTEM_PATH, 'Reference.%d.txt' % i)
  97. model_file = os.path.join(MODEL_PATH, 'Model.A.%d.txt' % i)
  98. refer = clean(refer_list[i]) if remap else refer_list[i]
  99. hyps = clean(hyps_list[i]) if remap else hyps_list[i]
  100. with open(system_file, 'wb') as f:
  101. f.write(refer.encode('utf-8'))
  102. with open(model_file, 'wb') as f:
  103. f.write(hyps.encode('utf-8'))
  104. r = Rouge155('/remote-home/dqwang/ROUGE/RELEASE-1.5.5')
  105. r.system_dir = SYSTEM_PATH
  106. r.model_dir = MODEL_PATH
  107. r.system_filename_pattern = 'Reference.(\d+).txt'
  108. r.model_filename_pattern = 'Model.[A-Z].#ID#.txt'
  109. output = r.convert_and_evaluate(rouge_args="-e /remote-home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d")
  110. output_dict = r.output_to_dict(output)
  111. shutil.rmtree(PYROUGE_ROOT)
  112. scores = {}
  113. scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {}
  114. scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score']
  115. scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score']
  116. scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score']
  117. return scores
  118. def pyrouge_score_all_multi(hyps_list, refer_list, remap = True):
  119. from pyrouge import Rouge155
  120. nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
  121. PYROUGE_ROOT = os.path.join('/remote-home/dqwang/', nowTime)
  122. SYSTEM_PATH = os.path.join(PYROUGE_ROOT, 'system')
  123. MODEL_PATH = os.path.join(PYROUGE_ROOT, 'gold')
  124. if os.path.exists(SYSTEM_PATH):
  125. shutil.rmtree(SYSTEM_PATH)
  126. os.makedirs(SYSTEM_PATH)
  127. if os.path.exists(MODEL_PATH):
  128. shutil.rmtree(MODEL_PATH)
  129. os.makedirs(MODEL_PATH)
  130. assert len(hyps_list) == len(refer_list)
  131. for i in range(len(hyps_list)):
  132. system_file = os.path.join(SYSTEM_PATH, 'Model.%d.txt' % i)
  133. # model_file = os.path.join(MODEL_PATH, 'Reference.A.%d.txt' % i)
  134. hyps = clean(hyps_list[i]) if remap else hyps_list[i]
  135. with open(system_file, 'wb') as f:
  136. f.write(hyps.encode('utf-8'))
  137. referType = ["A", "B", "C", "D", "E", "F", "G"]
  138. for j in range(len(refer_list[i])):
  139. model_file = os.path.join(MODEL_PATH, "Reference.%s.%d.txt" % (referType[j], i))
  140. refer = clean(refer_list[i][j]) if remap else refer_list[i][j]
  141. with open(model_file, 'wb') as f:
  142. f.write(refer.encode('utf-8'))
  143. r = Rouge155('/remote-home/dqwang/ROUGE/RELEASE-1.5.5')
  144. r.system_dir = SYSTEM_PATH
  145. r.model_dir = MODEL_PATH
  146. r.system_filename_pattern = 'Model.(\d+).txt'
  147. r.model_filename_pattern = 'Reference.[A-Z].#ID#.txt'
  148. output = r.convert_and_evaluate(rouge_args="-e /remote-home/dqwang/ROUGE/RELEASE-1.5.5/data -a -m -n 2 -d")
  149. output_dict = r.output_to_dict(output)
  150. shutil.rmtree(PYROUGE_ROOT)
  151. scores = {}
  152. scores['rouge-1'], scores['rouge-2'], scores['rouge-l'] = {}, {}, {}
  153. scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'] = output_dict['rouge_1_precision'], output_dict['rouge_1_recall'], output_dict['rouge_1_f_score']
  154. scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'] = output_dict['rouge_2_precision'], output_dict['rouge_2_recall'], output_dict['rouge_2_f_score']
  155. scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'] = output_dict['rouge_l_precision'], output_dict['rouge_l_recall'], output_dict['rouge_l_f_score']
  156. return scores
  157. def cal_label(article, abstract):
  158. hyps_list = article
  159. refer = abstract
  160. scores = []
  161. for hyps in hyps_list:
  162. mean_score = rouge_eval(hyps, refer)
  163. scores.append(mean_score)
  164. selected = []
  165. selected.append(int(np.argmax(scores)))
  166. selected_sent_cnt = 1
  167. best_rouge = np.max(scores)
  168. while selected_sent_cnt < len(hyps_list):
  169. cur_max_rouge = 0.0
  170. cur_max_idx = -1
  171. for i in range(len(hyps_list)):
  172. if i not in selected:
  173. temp = copy.deepcopy(selected)
  174. temp.append(i)
  175. hyps = "\n".join([hyps_list[idx] for idx in np.sort(temp)])
  176. cur_rouge = rouge_eval(hyps, refer)
  177. if cur_rouge > cur_max_rouge:
  178. cur_max_rouge = cur_rouge
  179. cur_max_idx = i
  180. if cur_max_rouge != 0.0 and cur_max_rouge >= best_rouge:
  181. selected.append(cur_max_idx)
  182. selected_sent_cnt += 1
  183. best_rouge = cur_max_rouge
  184. else:
  185. break
  186. # label = np.zeros(len(hyps_list), dtype=int)
  187. # label[np.array(selected)] = 1
  188. # return list(label)
  189. return selected
  190. def cal_label_limited3(article, abstract):
  191. hyps_list = article
  192. refer = abstract
  193. scores = []
  194. for hyps in hyps_list:
  195. try:
  196. mean_score = rouge_eval(hyps, refer)
  197. scores.append(mean_score)
  198. except ValueError:
  199. scores.append(0.0)
  200. selected = []
  201. selected.append(np.argmax(scores))
  202. selected_sent_cnt = 1
  203. best_rouge = np.max(scores)
  204. while selected_sent_cnt < len(hyps_list) and selected_sent_cnt < 3:
  205. cur_max_rouge = 0.0
  206. cur_max_idx = -1
  207. for i in range(len(hyps_list)):
  208. if i not in selected:
  209. temp = copy.deepcopy(selected)
  210. temp.append(i)
  211. hyps = "\n".join([hyps_list[idx] for idx in np.sort(temp)])
  212. cur_rouge = rouge_eval(hyps, refer)
  213. if cur_rouge > cur_max_rouge:
  214. cur_max_rouge = cur_rouge
  215. cur_max_idx = i
  216. selected.append(cur_max_idx)
  217. selected_sent_cnt += 1
  218. best_rouge = cur_max_rouge
  219. # logger.info(selected)
  220. # label = np.zeros(len(hyps_list), dtype=int)
  221. # label[np.array(selected)] = 1
  222. # return list(label)
  223. return selected
  224. import torch
  225. def flip(x, dim):
  226. xsize = x.size()
  227. dim = x.dim() + dim if dim < 0 else dim
  228. x = x.contiguous()
  229. x = x.view(-1, *xsize[dim:]).contiguous()
  230. x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1,
  231. -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :]
  232. return x.view(xsize)
  233. def get_attn_key_pad_mask(seq_k, seq_q):
  234. ''' For masking out the padding part of key sequence. '''
  235. # Expand to fit the shape of key query attention matrix.
  236. len_q = seq_q.size(1)
  237. padding_mask = seq_k.eq(0.0)
  238. padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk
  239. return padding_mask
  240. def get_non_pad_mask(seq):
  241. assert seq.dim() == 2
  242. return seq.ne(0.0).type(torch.float).unsqueeze(-1)