from fastNLP.core.metrics import MetricBase


class RelayMetric(MetricBase):
    """Segment-level precision / recall / F1 accumulated over batches.

    At each position the model predicts the (length - 1) of the segment that
    starts there (``pred``) together with a mask of positions where it
    believes a segment starts (``pred_mask``).  A prediction counts as a true
    positive when the predicted length equals the gold length (``target``)
    and both ``pred_mask`` and ``start_seg_mask`` are set at that position.
    """

    def __init__(self, pred=None, pred_mask=None, target=None, start_seg_mask=None):
        super().__init__()
        # Map the evaluate() argument names to keys of the model output / batch.
        self._init_param_map(pred=pred, pred_mask=pred_mask, target=target,
                             start_seg_mask=start_seg_mask)
        self._reset_counters()

    def _reset_counters(self):
        """Zero the counts accumulated by evaluate()."""
        self.tp = 0   # positions counted as correctly predicted segments
        self.rec = 0  # sum of start_seg_mask: recall denominator (gold count)
        self.pre = 0  # sum of pred_mask: precision denominator (predicted count)

    def evaluate(self, pred, pred_mask, target, start_seg_mask):
        """Accumulate counts for one batch.

        :param pred: predicted (length - 1) of the segment starting at each
            position.
        :param pred_mask: non-zero where the model predicts a segment starts.
        :param target: gold (length - 1) of the segment starting at each
            position.
        :param start_seg_mask: non-zero at gold segment positions.  The
            original docstring described it as "a segment ends here"; its name
            and its use as the recall denominator suggest gold segment starts
            -- TODO confirm against the data pipeline.
        :return: None; the counters are updated in place.
        """
        # Convert masks to bool so the conjunction is a logical AND.  A
        # bitwise AND on uint8 masks would miscount non-0/1 mask values.
        both_starts = pred_mask.bool() & start_seg_mask.bool()
        same_length = pred.long() == target.long()
        self.tp += (same_length & both_starts).sum().item()
        self.rec += start_seg_mask.sum().item()
        self.pre += pred_mask.sum().item()

    def get_metric(self, reset=True):
        """Compute performance after all batches have been evaluated.

        :param reset: if True, zero the accumulated counts afterwards.
        :return: dict with keys ``'f'``, ``'pre'``, ``'rec'``, each rounded to
            6 decimal places.
        """
        # The tiny epsilon avoids ZeroDivisionError when nothing was
        # predicted or there are no gold segments.
        pre = self.tp / (self.pre + 1e-12)
        rec = self.tp / (self.rec + 1e-12)
        f = 2 * pre * rec / (1e-12 + pre + rec)
        if reset:
            self._reset_counters()
        return {'f': round(f, 6), 'pre': round(pre, 6), 'rec': round(rec, 6)}