You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

Metric.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # __author__="Danqing Wang"
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ==============================================================================
  17. from __future__ import division
  18. import torch
  19. import torch.nn.functional as F
  20. from rouge import Rouge
  21. from fastNLP.core.const import Const
  22. from fastNLP.core.metrics import MetricBase
  23. # from tools.logger import *
  24. from fastNLP.core._logger import logger
  25. from tools.utils import pyrouge_score_all, pyrouge_score_all_multi
  26. class LossMetric(MetricBase):
  27. def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='mean'):
  28. super().__init__()
  29. self._init_param_map(pred=pred, target=target, mask=mask)
  30. self.padding_idx = padding_idx
  31. self.reduce = reduce
  32. self.loss = 0.0
  33. self.iteration = 0
  34. def evaluate(self, pred, target, mask):
  35. """
  36. :param pred: [batch, N, 2]
  37. :param target: [batch, N]
  38. :param input_mask: [batch, N]
  39. :return:
  40. """
  41. batch, N, _ = pred.size()
  42. pred = pred.view(-1, 2)
  43. target = target.view(-1)
  44. loss = F.cross_entropy(input=pred, target=target,
  45. ignore_index=self.padding_idx, reduction=self.reduce)
  46. loss = loss.view(batch, -1)
  47. loss = loss.masked_fill(mask.eq(False), 0)
  48. loss = loss.sum(1).mean()
  49. self.loss += loss
  50. self.iteration += 1
  51. def get_metric(self, reset=True):
  52. epoch_avg_loss = self.loss / self.iteration
  53. if reset:
  54. self.loss = 0.0
  55. self.iteration = 0
  56. metric = {"loss": -epoch_avg_loss}
  57. logger.info(metric)
  58. return metric
  59. class LabelFMetric(MetricBase):
  60. def __init__(self, pred=None, target=None):
  61. super().__init__()
  62. self._init_param_map(pred=pred, target=target)
  63. self.match = 0.0
  64. self.pred = 0.0
  65. self.true = 0.0
  66. self.match_true = 0.0
  67. self.total = 0.0
  68. def evaluate(self, pred, target):
  69. """
  70. :param pred: [batch, N] int
  71. :param target: [batch, N] int
  72. :return:
  73. """
  74. target = target.data
  75. pred = pred.data
  76. # logger.debug(pred.size())
  77. # logger.debug(pred[:5,:])
  78. batch, N = pred.size()
  79. self.pred += pred.sum()
  80. self.true += target.sum()
  81. self.match += (pred == target).sum()
  82. self.match_true += ((pred == target) & (pred == 1)).sum()
  83. self.total += batch * N
  84. def get_metric(self, reset=True):
  85. self.match,self.pred, self.true, self.match_true, self.total = self.match.float(),self.pred.float(), self.true.float(), self.match_true.float(), self.total
  86. logger.debug((self.match,self.pred, self.true, self.match_true, self.total))
  87. try:
  88. accu = self.match / self.total
  89. precision = self.match_true / self.pred
  90. recall = self.match_true / self.true
  91. F = 2 * precision * recall / (precision + recall)
  92. except ZeroDivisionError:
  93. F = 0.0
  94. logger.error("[Error] float division by zero")
  95. if reset:
  96. self.pred, self.true, self.match_true, self.match, self.total = 0, 0, 0, 0, 0
  97. ret = {"accu": accu.cpu(), "p":precision.cpu(), "r":recall.cpu(), "f": F.cpu()}
  98. logger.info(ret)
  99. return ret
  100. class RougeMetric(MetricBase):
  101. def __init__(self, hps, pred=None, text=None, refer=None):
  102. super().__init__()
  103. self._hps = hps
  104. self._init_param_map(pred=pred, text=text, summary=refer)
  105. self.hyps = []
  106. self.refers = []
  107. def evaluate(self, pred, text, summary):
  108. """
  109. :param prediction: [batch, N]
  110. :param text: [batch, N]
  111. :param summary: [batch, N]
  112. :return:
  113. """
  114. batch_size, N = pred.size()
  115. for j in range(batch_size):
  116. original_article_sents = text[j]
  117. sent_max_number = len(original_article_sents)
  118. refer = "\n".join(summary[j])
  119. hyps = "\n".join(original_article_sents[id] for id in range(len(pred[j])) if
  120. pred[j][id] == 1 and id < sent_max_number)
  121. if sent_max_number < self._hps.m and len(hyps) <= 1:
  122. print("sent_max_number is too short %d, Skip!", sent_max_number)
  123. continue
  124. if len(hyps) >= 1 and hyps != '.':
  125. self.hyps.append(hyps)
  126. self.refers.append(refer)
  127. elif refer == "." or refer == "":
  128. logger.error("Refer is None!")
  129. logger.debug(refer)
  130. elif hyps == "." or hyps == "":
  131. logger.error("hyps is None!")
  132. logger.debug("sent_max_number:%d", sent_max_number)
  133. logger.debug("pred:")
  134. logger.debug(pred[j])
  135. logger.debug(hyps)
  136. else:
  137. logger.error("Do not select any sentences!")
  138. logger.debug("sent_max_number:%d", sent_max_number)
  139. logger.debug(original_article_sents)
  140. logger.debug(refer)
  141. continue
  142. def get_metric(self, reset=True):
  143. pass
  144. class FastRougeMetric(RougeMetric):
  145. def __init__(self, hps, pred=None, text=None, refer=None):
  146. super().__init__(hps, pred, text, refer)
  147. def get_metric(self, reset=True):
  148. logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers))
  149. if len(self.hyps) == 0 or len(self.refers) == 0 :
  150. logger.error("During testing, no hyps or refers is selected!")
  151. return
  152. rouge = Rouge()
  153. scores_all = rouge.get_scores(self.hyps, self.refers, avg=True)
  154. if reset:
  155. self.hyps = []
  156. self.refers = []
  157. logger.info(scores_all)
  158. return scores_all
  159. class PyRougeMetric(RougeMetric):
  160. def __init__(self, hps, pred=None, text=None, refer=None):
  161. super().__init__(hps, pred, text, refer)
  162. def get_metric(self, reset=True):
  163. logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers))
  164. if len(self.hyps) == 0 or len(self.refers) == 0:
  165. logger.error("During testing, no hyps or refers is selected!")
  166. return
  167. if isinstance(self.refers[0], list):
  168. logger.info("Multi Reference summaries!")
  169. scores_all = pyrouge_score_all_multi(self.hyps, self.refers)
  170. else:
  171. scores_all = pyrouge_score_all(self.hyps, self.refers)
  172. if reset:
  173. self.hyps = []
  174. self.refers = []
  175. logger.info(scores_all)
  176. return scores_all