You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Metric.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # __author__="Danqing Wang"
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ==============================================================================
  17. from __future__ import division
  18. import torch
  19. from rouge import Rouge
  20. from fastNLP.core.const import Const
  21. from fastNLP.core.metrics import MetricBase
  22. from tools.logger import *
  23. from tools.utils import pyrouge_score_all, pyrouge_score_all_multi
  24. class LabelFMetric(MetricBase):
  25. def __init__(self, pred=None, target=None):
  26. super().__init__()
  27. self._init_param_map(pred=pred, target=target)
  28. self.match = 0.0
  29. self.pred = 0.0
  30. self.true = 0.0
  31. self.match_true = 0.0
  32. self.total = 0.0
  33. def evaluate(self, pred, target):
  34. """
  35. :param pred: [batch, N] int
  36. :param target: [batch, N] int
  37. :return:
  38. """
  39. target = target.data
  40. pred = pred.data
  41. # logger.debug(pred.size())
  42. # logger.debug(pred[:5,:])
  43. batch, N = pred.size()
  44. self.pred += pred.sum()
  45. self.true += target.sum()
  46. self.match += (pred == target).sum()
  47. self.match_true += ((pred == target) & (pred == 1)).sum()
  48. self.total += batch * N
  49. def get_metric(self, reset=True):
  50. self.match,self.pred, self.true, self.match_true, self.total = self.match.float(),self.pred.float(), self.true.float(), self.match_true.float(), self.total
  51. logger.debug((self.match,self.pred, self.true, self.match_true, self.total))
  52. try:
  53. accu = self.match / self.total
  54. precision = self.match_true / self.pred
  55. recall = self.match_true / self.true
  56. F = 2 * precision * recall / (precision + recall)
  57. except ZeroDivisionError:
  58. F = 0.0
  59. logger.error("[Error] float division by zero")
  60. if reset:
  61. self.pred, self.true, self.match_true, self.match, self.total = 0, 0, 0, 0, 0
  62. ret = {"accu": accu.cpu(), "p":precision.cpu(), "r":recall.cpu(), "f": F.cpu()}
  63. logger.info(ret)
  64. return ret
  65. class RougeMetric(MetricBase):
  66. def __init__(self, hps, pred=None, text=None, refer=None):
  67. super().__init__()
  68. self._hps = hps
  69. self._init_param_map(pred=pred, text=text, summary=refer)
  70. self.hyps = []
  71. self.refers = []
  72. def evaluate(self, pred, text, summary):
  73. """
  74. :param prediction: [batch, N]
  75. :param text: [batch, N]
  76. :param summary: [batch, N]
  77. :return:
  78. """
  79. batch_size, N = pred.size()
  80. for j in range(batch_size):
  81. original_article_sents = text[j]
  82. sent_max_number = len(original_article_sents)
  83. refer = "\n".join(summary[j])
  84. hyps = "\n".join(original_article_sents[id] for id in range(len(pred[j])) if
  85. pred[j][id] == 1 and id < sent_max_number)
  86. if sent_max_number < self._hps.m and len(hyps) <= 1:
  87. print("sent_max_number is too short %d, Skip!", sent_max_number)
  88. continue
  89. if len(hyps) >= 1 and hyps != '.':
  90. self.hyps.append(hyps)
  91. self.refers.append(refer)
  92. elif refer == "." or refer == "":
  93. logger.error("Refer is None!")
  94. logger.debug(refer)
  95. elif hyps == "." or hyps == "":
  96. logger.error("hyps is None!")
  97. logger.debug("sent_max_number:%d", sent_max_number)
  98. logger.debug("pred:")
  99. logger.debug(pred[j])
  100. logger.debug(hyps)
  101. else:
  102. logger.error("Do not select any sentences!")
  103. logger.debug("sent_max_number:%d", sent_max_number)
  104. logger.debug(original_article_sents)
  105. logger.debug(refer)
  106. continue
  107. def get_metric(self, reset=True):
  108. pass
  109. class FastRougeMetric(RougeMetric):
  110. def __init__(self, hps, pred=None, text=None, refer=None):
  111. super().__init__(hps, pred, text, refer)
  112. def get_metric(self, reset=True):
  113. logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers))
  114. if len(self.hyps) == 0 or len(self.refers) == 0 :
  115. logger.error("During testing, no hyps or refers is selected!")
  116. return
  117. rouge = Rouge()
  118. scores_all = rouge.get_scores(self.hyps, self.refers, avg=True)
  119. if reset:
  120. self.hyps = []
  121. self.refers = []
  122. logger.info(scores_all)
  123. return scores_all
  124. class PyRougeMetric(RougeMetric):
  125. def __init__(self, hps, pred=None, text=None, refer=None):
  126. super().__init__(hps, pred, text, refer)
  127. def get_metric(self, reset=True):
  128. logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers))
  129. if len(self.hyps) == 0 or len(self.refers) == 0:
  130. logger.error("During testing, no hyps or refers is selected!")
  131. return
  132. if isinstance(self.refers[0], list):
  133. logger.info("Multi Reference summaries!")
  134. scores_all = pyrouge_score_all_multi(self.hyps, self.refers)
  135. else:
  136. scores_all = pyrouge_score_all(self.hyps, self.refers)
  137. if reset:
  138. self.hyps = []
  139. self.refers = []
  140. logger.info(scores_all)
  141. return scores_all