You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

metric.py 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from fastNLP.core.metrics import MetricBase
  class RelayMetric(MetricBase):
      """Accumulate precision / recall / F1 over segment predictions.

      At each position, ``pred`` is the predicted (length - 1) of the segment
      starting there. A position counts as a true positive when it is flagged
      in both ``pred_mask`` and ``start_seg_mask`` and ``pred`` equals
      ``target`` there. Precision is divided by the number of ``pred_mask``
      positions; recall is divided by the number of ``start_seg_mask``
      positions.
      """

      def __init__(self, pred=None, pred_mask=None, target=None, start_seg_mask=None):
          """
          :param pred: optional field name forwarded to ``_init_param_map``.
              NOTE(review): presumably ``None`` keeps the default mapping to
              the ``evaluate`` argument of the same name; confirm against
              ``MetricBase``.
          :param pred_mask: optional field name forwarded to ``_init_param_map``.
          :param target: optional field name forwarded to ``_init_param_map``.
          :param start_seg_mask: optional field name forwarded to
              ``_init_param_map``.
          """
          super().__init__()
          self._init_param_map(pred=pred, pred_mask=pred_mask, target=target,
                               start_seg_mask=start_seg_mask)
          self._reset_counts()

      def _reset_counts(self):
          """Zero every accumulated counter (shared by __init__ and get_metric)."""
          self.tp = 0   # hits: flagged in both masks and pred == target
          self.rec = 0  # sum of start_seg_mask (recall denominator)
          self.pre = 0  # sum of pred_mask (precision denominator)
          # Not read anywhere in this class. get_metric used to reset it
          # without __init__ ever creating it; it is now always initialised
          # so any external code inspecting it keeps working.
          self.bigger_than_L = 0

      def evaluate(self, pred, pred_mask, target, start_seg_mask):
          """Accumulate the counts for one batch.

          :param pred: predicted (length - 1) of the segment starting at each
              position.
          :param pred_mask: positions where a segment start is predicted.
          :param target: gold (length - 1) of the segment starting at each
              position.
          :param start_seg_mask: positions flagged as gold segments.
              NOTE(review): the original docstring said "a segment ends at the
              current position", while the name says "start". It is combined
              with ``target``, which is indexed by start position, so "start"
              looks right. Confirm.
          :return: None; counts are accumulated on ``self``.
          """
          correct = pred.long().eq(target.long())
          # Only positions flagged as both predicted and gold can count as hits.
          both_flagged = pred_mask.byte() & start_seg_mask.byte()
          self.tp += (correct & both_flagged).sum().item()
          self.rec += start_seg_mask.sum().item()
          self.pre += pred_mask.sum().item()

      def get_metric(self, reset=True):
          """Compute precision, recall and F1 once all batches are evaluated.

          :param reset: if True, zero the accumulated counters afterwards.
          :return: dict with keys 'f', 'pre' and 'rec', each rounded to 6
              decimal places.
          """
          # The 1e-12 terms avoid division by zero when a denominator is 0.
          pre = self.tp / (self.pre + 1e-12)
          rec = self.tp / (self.rec + 1e-12)
          f = 2 * pre * rec / (1e-12 + pre + rec)
          if reset:
              self._reset_counts()
          return {'f': round(f, 6), 'pre': round(pre, 6), 'rec': round(rec, 6)}