#!/usr/bin/python
# -*- coding: utf-8 -*-
"""ROUGE scoring helpers, greedy oracle-label selection and small torch mask utilities."""

import re
import os
import shutil
import copy
import datetime
import sys

import numpy as np
import torch
from rouge import Rouge

from .logger import *

sys.setrecursionlimit(10000)

# PTB-style bracket / quote tokens mapped back to their literal characters.
REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}",
         "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'}

# Compiled once; matches exactly the keys of REMAP.
_REMAP_PATTERN = re.compile(r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''")

# Locations used by the pyrouge wrappers below (kept exactly as in the
# original code; note that pyrouge_score uses a different ROUGE home than
# the two batch variants).
_PYROUGE_TMP_BASE = '/remote-home/dqwang/'
_ROUGE_HOME_LOCAL = '/home/dqwang/ROUGE/RELEASE-1.5.5'
_ROUGE_HOME_REMOTE = '/remote-home/dqwang/ROUGE/RELEASE-1.5.5'

# Letters used to tag multiple references of one sample; the model filename
# pattern is 'Reference.[A-Z].#ID#.txt', so A-Z is the full usable range.
_REFERENCE_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def clean(x):
    """Return ``x`` with every token listed in ``REMAP`` replaced by its literal."""
    return _REMAP_PATTERN.sub(lambda m: REMAP.get(m.group()), x)


def rouge_eval(hyps, refer):
    """Return the mean of the ROUGE-1/2/L F-scores of ``hyps`` against ``refer``.

    Any error raised while scoring yields ``0.0`` instead of propagating, so
    callers can score many candidates without guarding each call.
    """
    rouge = Rouge()
    try:
        score = rouge.get_scores(hyps, refer)[0]
        mean_score = np.mean([score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]])
    except Exception:
        # Deliberately best-effort. ``Exception`` (not a bare ``except``) so
        # that KeyboardInterrupt / SystemExit still propagate.
        mean_score = 0.0
    return mean_score


def rouge_all(hyps, refer):
    """Return the first score entry from ``Rouge().get_scores(hyps, refer)``.

    Unlike ``rouge_eval``, errors from the scorer are not caught here.
    """
    rouge = Rouge()
    score = rouge.get_scores(hyps, refer)[0]
    return score


def eval_label(match_true, pred, true, total, match):
    """Compute accuracy, precision, recall and F from match counts.

    ``match_true``, ``pred``, ``true`` and ``match`` must provide ``.float()``
    (presumably torch tensors -- confirm against the caller); ``total`` is
    used as the divisor as-is.

    Returns:
        Tuple ``(accu, precision, recall, F)``. If a ``ZeroDivisionError``
        is raised, the error is logged and every value not yet computed is
        returned as ``0.0``.
    """
    match_true, pred, true, match = match_true.float(), pred.float(), true.float(), match.float()
    # Defaults keep the return statement valid even if a division below raises
    # before all four values were assigned.
    accu = precision = recall = F = 0.0
    try:
        accu = match / total
        precision = match_true / pred
        recall = match_true / true
        F = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        F = 0.0
        logger.error("[Error] float division by zero")
    return accu, precision, recall, F


def _prepare_pyrouge_dirs(system_name, model_name):
    """Create empty system/model directories under a timestamped root.

    Returns ``(root, system_path, model_path)``.
    """
    # NOTE(review): the root name has one-second resolution, so two runs
    # started within the same second share (and wipe) the same directories.
    now_time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    root = os.path.join(_PYROUGE_TMP_BASE, now_time)
    system_path = os.path.join(root, system_name)
    model_path = os.path.join(root, model_name)
    for path in (system_path, model_path):
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)
    return root, system_path, model_path


def _write_utf8(path, text):
    """Write ``text`` to ``path`` encoded as UTF-8."""
    with open(path, 'wb') as f:
        f.write(text.encode('utf-8'))


def _scores_from_output(output_dict):
    """Reshape pyrouge's flat result dict into ``{'rouge-N': {'p', 'r', 'f'}}``."""
    return {
        'rouge-%s' % n: {
            'p': output_dict['rouge_%s_precision' % n],
            'r': output_dict['rouge_%s_recall' % n],
            'f': output_dict['rouge_%s_f_score' % n],
        }
        for n in ('1', '2', 'l')
    }


def _pyrouge_evaluate(rouge_home, system_name, model_name,
                      system_pattern, model_pattern,
                      system_files, model_files):
    """Write the given files, run ROUGE-1.5.5 through pyrouge, return scores.

    ``system_files`` / ``model_files`` are lists of ``(filename, text)`` pairs
    written into the system / model directory respectively. The scratch
    directory is removed afterwards, also when evaluation fails.
    """
    from pyrouge import Rouge155

    root, system_path, model_path = _prepare_pyrouge_dirs(system_name, model_name)
    try:
        for filename, text in system_files:
            _write_utf8(os.path.join(system_path, filename), text)
        for filename, text in model_files:
            _write_utf8(os.path.join(model_path, filename), text)

        r = Rouge155(rouge_home)
        r.system_dir = system_path
        r.model_dir = model_path
        r.system_filename_pattern = system_pattern
        r.model_filename_pattern = model_pattern

        output = r.convert_and_evaluate(rouge_args="-e %s/data -a -m -n 2 -d" % rouge_home)
        return _scores_from_output(r.output_to_dict(output))
    finally:
        # Best-effort cleanup so a failing ROUGE run does not leak the
        # scratch directory.
        shutil.rmtree(root, ignore_errors=True)


def pyrouge_score(hyps, refer, remap = True):
    """Score a single hypothesis string against a single reference with pyrouge.

    Args:
        hyps: hypothesis text.
        refer: reference text.
        remap: when truthy, both texts are passed through ``clean`` first.

    Returns:
        ``{'rouge-1'|'rouge-2'|'rouge-l': {'p', 'r', 'f'}}``.
    """
    if remap:
        refer = clean(refer)
        hyps = clean(hyps)

    # NOTE(review): the reference is written to ``system_dir`` and the
    # hypothesis to ``model_dir`` -- the opposite of pyrouge_score_all_multi.
    # If pyrouge treats ``system`` as the candidate side, 'p' and 'r' are
    # exchanged here ('f' is unaffected). Kept as-is; confirm before changing.
    return _pyrouge_evaluate(
        rouge_home=_ROUGE_HOME_LOCAL,
        system_name='gold',
        model_name='system',
        system_pattern=r'Reference.(\d+).txt',
        model_pattern='Model.[A-Z].#ID#.txt',
        system_files=[('Reference.0.txt', refer)],
        model_files=[('Model.A.0.txt', hyps)],
    )


def pyrouge_score_all(hyps_list, refer_list, remap = True):
    """Score parallel lists of hypotheses and references with pyrouge.

    Args:
        hyps_list: hypothesis texts.
        refer_list: reference texts, same length as ``hyps_list``.
        remap: when truthy, every text is passed through ``clean`` first.

    Returns:
        ``{'rouge-1'|'rouge-2'|'rouge-l': {'p', 'r', 'f'}}``.
    """
    assert len(hyps_list) == len(refer_list)

    system_files = []
    model_files = []
    for i, (hyps, refer) in enumerate(zip(hyps_list, refer_list)):
        if remap:
            refer = clean(refer)
            hyps = clean(hyps)
        system_files.append(('Reference.%d.txt' % i, refer))
        model_files.append(('Model.A.%d.txt' % i, hyps))

    # NOTE(review): references go to ``system_dir`` and hypotheses to
    # ``model_dir`` here, the opposite of pyrouge_score_all_multi; this
    # likely exchanges 'p' and 'r'. Kept as-is; confirm before changing.
    return _pyrouge_evaluate(
        rouge_home=_ROUGE_HOME_REMOTE,
        system_name='gold',
        model_name='system',
        system_pattern=r'Reference.(\d+).txt',
        model_pattern='Model.[A-Z].#ID#.txt',
        system_files=system_files,
        model_files=model_files,
    )


def pyrouge_score_all_multi(hyps_list, refer_list, remap = True):
    """Score hypotheses against multiple references each with pyrouge.

    Args:
        hyps_list: hypothesis texts.
        refer_list: for each hypothesis, a sequence of reference texts
            (at most 26, tagged ``A``..``Z`` in the file names).
        remap: when truthy, every text is passed through ``clean`` first.

    Returns:
        ``{'rouge-1'|'rouge-2'|'rouge-l': {'p', 'r', 'f'}}``.

    Raises:
        IndexError: if a sample has more than 26 references.
    """
    assert len(hyps_list) == len(refer_list)

    system_files = []
    model_files = []
    for i, hyps in enumerate(hyps_list):
        system_files.append(('Model.%d.txt' % i, clean(hyps) if remap else hyps))
        for j, refer in enumerate(refer_list[i]):
            model_files.append((
                "Reference.%s.%d.txt" % (_REFERENCE_IDS[j], i),
                clean(refer) if remap else refer,
            ))

    return _pyrouge_evaluate(
        rouge_home=_ROUGE_HOME_REMOTE,
        system_name='system',
        model_name='gold',
        system_pattern=r'Model.(\d+).txt',
        model_pattern='Reference.[A-Z].#ID#.txt',
        system_files=system_files,
        model_files=model_files,
    )


def _best_extension(hyps_list, refer, selected):
    """Find the unselected sentence whose addition scores highest.

    Each candidate is joined (in document order, newline separated) with the
    already selected sentences and scored with ``rouge_eval``.

    Returns:
        ``(best_score, best_idx)``; ``(0.0, -1)`` when no candidate scores
        above zero.
    """
    best_score, best_idx = 0.0, -1
    for i in range(len(hyps_list)):
        if i in selected:
            continue
        candidate = sorted(selected + [i])
        hyps = "\n".join([hyps_list[idx] for idx in candidate])
        score = rouge_eval(hyps, refer)
        if score > best_score:
            best_score, best_idx = score, i
    return best_score, best_idx


def cal_label(article, abstract):
    """Greedily pick sentence indices of ``article`` that maximise ROUGE.

    Starts from the single best-scoring sentence and keeps adding the
    sentence that gives the highest combined score, as long as that score is
    non-zero and not lower than the best score so far.

    Args:
        article: sequence of sentence strings.
        abstract: reference summary string.

    Returns:
        List of selected sentence indices in selection order; ``[]`` for an
        empty article.
    """
    hyps_list = article
    refer = abstract
    if len(hyps_list) == 0:
        return []

    scores = [rouge_eval(hyps, refer) for hyps in hyps_list]
    selected = [int(np.argmax(scores))]
    best_rouge = np.max(scores)

    while len(selected) < len(hyps_list):
        cur_max_rouge, cur_max_idx = _best_extension(hyps_list, refer, selected)
        if cur_max_rouge != 0.0 and cur_max_rouge >= best_rouge:
            selected.append(cur_max_idx)
            best_rouge = cur_max_rouge
        else:
            break
    return selected


def cal_label_limited3(article, abstract):
    """Greedily pick up to three sentence indices of ``article`` by ROUGE.

    Like ``cal_label`` but without the "must not get worse" stopping rule:
    sentences are added until three are selected, the article is exhausted,
    or no remaining sentence yields a score above zero.

    Args:
        article: sequence of sentence strings.
        abstract: reference summary string.

    Returns:
        List of at most three selected sentence indices in selection order;
        ``[]`` for an empty article.
    """
    hyps_list = article
    refer = abstract
    if len(hyps_list) == 0:
        return []

    # rouge_eval already maps scoring failures to 0.0, so no guard is needed.
    scores = [rouge_eval(hyps, refer) for hyps in hyps_list]
    selected = [int(np.argmax(scores))]

    while len(selected) < len(hyps_list) and len(selected) < 3:
        _, cur_max_idx = _best_extension(hyps_list, refer, selected)
        if cur_max_idx < 0:
            # No candidate scored above zero; appending the -1 sentinel would
            # silently alias the last sentence, so stop instead.
            break
        selected.append(cur_max_idx)
    return selected


def flip(x, dim):
    """Return a copy of tensor ``x`` reversed along dimension ``dim``.

    ``dim`` may be negative. The result has the same size as ``x``.
    """
    xsize = x.size()
    dim = x.dim() + dim if dim < 0 else dim
    # Collapse to (leading, size[dim], trailing) so the flip is a plain
    # index on axis 1.
    x = x.contiguous()
    x = x.view(-1, *xsize[dim:])
    x = x.view(x.size(0), x.size(1), -1)
    # Build the reversed index on x's own device (not merely "cpu or the
    # default cuda device") so multi-GPU tensors work.
    reverse_idx = torch.arange(x.size(1) - 1, -1, -1, dtype=torch.long, device=x.device)
    x = x[:, reverse_idx, :]
    return x.view(xsize)


def get_attn_key_pad_mask(seq_k, seq_q):
    """Build a mask marking the zero (padding) entries of the key sequence.

    ``seq_k.eq(0.0)`` is unsqueezed at dim 1 and expanded to
    ``seq_q.size(1)`` along that dim, i.e. b x lq x lk for a 2-D ``seq_k``.
    """
    # Expand to fit the shape of key query attention matrix.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(0.0)
    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk
    return padding_mask


def get_non_pad_mask(seq):
    """Return a float mask with a trailing unit dim: 1.0 where ``seq`` is non-zero.

    ``seq`` must be 2-dimensional.
    """
    assert seq.dim() == 2
    return seq.ne(0.0).type(torch.float).unsqueeze(-1)