* add ConfusionMatrix, ConfusionMatrixMetric * add confusionmatrix to utils * add ConfusionMatrixmetric * add ConfusionMatrixMetric * init for test * begin test * test finish * doc finishtags/v0.5.5
| @@ -7,7 +7,8 @@ __all__ = [ | |||
| "AccuracyMetric", | |||
| "SpanFPreRecMetric", | |||
| "CMRC2018Metric", | |||
| "ClassifyFPreRecMetric" | |||
| "ClassifyFPreRecMetric", | |||
| "ConfusionMatrixMetric" | |||
| ] | |||
| import inspect | |||
| @@ -15,6 +16,7 @@ import warnings | |||
| from abc import abstractmethod | |||
| from collections import defaultdict | |||
| from typing import Union | |||
| from copy import deepcopy | |||
| import re | |||
| import numpy as np | |||
| @@ -27,6 +29,7 @@ from .utils import _check_arg_dict_list | |||
| from .utils import _get_func_signature | |||
| from .utils import seq_len_to_mask | |||
| from .vocabulary import Vocabulary | |||
| from .utils import ConfusionMatrix | |||
| class MetricBase(object): | |||
| @@ -276,6 +279,95 @@ class MetricBase(object): | |||
| return | |||
| class ConfusionMatrixMetric(MetricBase): | |||
| r""" | |||
| 分类问题计算混淆矩阵的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) | |||
| 最后返回结果为dict,{'confusion_matrix': ConfusionMatrix实例} | |||
| ConfusionMatrix实例的print()函数将输出矩阵字符串。 | |||
| pred_dict = {"pred": torch.Tensor([2,1,3])} | |||
| target_dict = {'target': torch.Tensor([2,2,1])} | |||
| metric = ConfusionMatrixMetric() | |||
| metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
| print(metric.get_metric()) | |||
| {'confusion_matrix': | |||
| target 1.0 2.0 3.0 all | |||
| pred | |||
| 1.0 0 1 0 1 | |||
| 2.0 0 1 0 1 | |||
| 3.0 1 0 0 1 | |||
| all 1 2 0 3} | |||
| """ | |||
| def __init__(self, vocab=None, pred=None, target=None, seq_len=None): | |||
| """ | |||
| :param vocab: vocab词表类,要求有to_word()方法。 | |||
| :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
| :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
| :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` | |||
| """ | |||
| super().__init__() | |||
| self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||
| self.confusion_matrix = ConfusionMatrix(vocab=vocab) | |||
| def evaluate(self, pred, target, seq_len=None): | |||
| """ | |||
| evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||
| :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||
| torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | |||
| :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), | |||
| torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) | |||
| :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, torch.Size([B]), 或者torch.Size([B]). | |||
| """ | |||
| if not isinstance(pred, torch.Tensor): | |||
| raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
| f"got {type(pred)}.") | |||
| if not isinstance(target, torch.Tensor): | |||
| raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
| f"got {type(target)}.") | |||
| if seq_len is not None and not isinstance(seq_len, torch.Tensor): | |||
| raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
| f"got {type(seq_len)}.") | |||
| if pred.dim() == target.dim(): | |||
| pass | |||
| elif pred.dim() == target.dim() + 1: | |||
| pred = pred.argmax(dim=-1) | |||
| if seq_len is None and target.dim() > 1: | |||
| warnings.warn("You are not passing `seq_len` to exclude pad.") | |||
| else: | |||
| raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||
| f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
| f"{pred.size()[:-1]}, got {target.size()}.") | |||
| target = target.to(pred) | |||
| if seq_len is not None and target.dim() > 1: | |||
| for p, t, l in zip(pred.tolist(), target.tolist(), seq_len.tolist()): | |||
| l=int(l) | |||
| self.confusion_matrix.add_pred_target(p[:l], t[:l]) | |||
| elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出 | |||
| for p, t in zip(pred.tolist(), target.tolist()): | |||
| self.confusion_matrix.add_pred_target(p, t) | |||
| else: | |||
| self.confusion_matrix.add_pred_target(pred.tolist(), target.tolist()) | |||
| def get_metric(self,reset=True): | |||
| """ | |||
| get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | |||
| :param bool reset: 在调用完get_metric后是否清空评价指标统计量. | |||
| :return dict evaluate_result: {"confusion_matrix": ConfusionMatrix} | |||
| """ | |||
| confusion = {'confusion_matrix': deepcopy(self.confusion_matrix)} | |||
| if reset: | |||
| self.confusion_matrix.clear() | |||
| return confusion | |||
| class AccuracyMetric(MetricBase): | |||
| """ | |||
| 准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) | |||
| @@ -8,18 +8,22 @@ __all__ = [ | |||
| "get_seq_len" | |||
| ] | |||
| import _pickle | |||
| import inspect | |||
| import os | |||
| import warnings | |||
| from collections import Counter, namedtuple | |||
| from copy import deepcopy | |||
| from typing import List | |||
| import _pickle | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| from typing import List | |||
| from ._logger import logger | |||
| from prettytable import PrettyTable | |||
| from ._logger import logger | |||
| from ._parallel_utils import _model_contains_inner_module | |||
| # from .vocabulary import Vocabulary | |||
| try: | |||
| from apex import amp | |||
| @@ -30,6 +34,98 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||
| 'varargs']) | |||
| class ConfusionMatrix: | |||
| """a dict can provide Confusion Matrix""" | |||
| def __init__(self, vocab=None): | |||
| """ | |||
| :param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 | |||
| """ | |||
| if vocab and not hasattr(vocab, 'to_word'): | |||
| raise TypeError(f"`vocab` in {_get_func_signature(self.__init__)} must be Fastnlp.core.Vocabulary," | |||
| f"got {type(vocab)}.") | |||
| self.confusiondict={} #key: pred index, value:target word ocunt | |||
| self.predcount={} #key:pred index, value:count | |||
| self.targetcount={} #key:target index, value:count | |||
| self.vocab=vocab | |||
| def add_pred_target(self, pred, target): #一组结果 | |||
| """ | |||
| 通过这个函数向ConfusionMatrix加入一组预测结果 | |||
| :param list pred: 预测的标签列表 | |||
| :param list target: 真实值的标签列表 | |||
| :return ConfusionMatrix | |||
| confusion=ConfusionMatrix() | |||
| pred = [2,1,3] | |||
| target = [2,2,1] | |||
| confusion.add_pred_target(pred, target) | |||
| print(confusion) | |||
| target 1 2 3 all | |||
| pred | |||
| 1 0 1 0 1 | |||
| 2 0 1 0 1 | |||
| 3 1 0 0 1 | |||
| all 1 2 0 3 | |||
| """ | |||
| for p,t in zip(pred,target): #<int, int> | |||
| self.predcount[p]=self.predcount.get(p,0)+ 1 | |||
| self.targetcount[t]=self.targetcount.get(t,0)+1 | |||
| if p in self.confusiondict: | |||
| self.confusiondict[p][t]=self.confusiondict[p].get(t,0) + 1 | |||
| else: | |||
| self.confusiondict[p]={} | |||
| self.confusiondict[p][t]= 1 | |||
| return self.confusiondict | |||
| def clear(self): | |||
| """ | |||
| 清除一些值,等待再次新加入 | |||
| :return: | |||
| """ | |||
| self.confusiondict={} | |||
| self.targetcount={} | |||
| self.predcount={} | |||
| def __repr__(self): | |||
| """ | |||
| :return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。 | |||
| """ | |||
| row2idx={} | |||
| idx2row={} | |||
| # 已知的所有键/label | |||
| totallabel=sorted(list(set(self.targetcount.keys()).union(set(self.predcount.keys())))) | |||
| lenth=len(totallabel) | |||
| # namedict key :idx value:word/idx | |||
| namedict=dict([(k,str(k if self.vocab == None else self.vocab.to_word(k))) for k in totallabel]) | |||
| for label,idx in zip(totallabel,range(lenth)): | |||
| idx2row[label]=idx #建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... | |||
| row2idx[idx]=label #建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,... | |||
| # 这里打印东西 | |||
| #表头 | |||
| head=["\ntarget"]+[str(namedict[row2idx[k]]) for k in row2idx.keys()]+["all"] | |||
| output="\t".join(head) + "\n" + "pred" + "\n" | |||
| #内容 | |||
| for i in row2idx.keys(): #第i行 | |||
| p=row2idx[i] | |||
| h=namedict[p] | |||
| l=[0 for _ in range(lenth)] | |||
| if self.confusiondict.get(p,None): | |||
| for t,c in self.confusiondict[p].items(): | |||
| l[idx2row[t]] = c #完成一行 | |||
| l=[h]+[str(n) for n in l]+[str(sum(l))] | |||
| output+="\t".join(l) +"\n" | |||
| #表尾 | |||
| tail=[self.targetcount.get(row2idx[k],0) for k in row2idx.keys()] | |||
| tail=["all"]+[str(n) for n in tail]+[str(sum(tail))] | |||
| output+="\t".join(tail) | |||
| return output | |||
| class Option(dict): | |||
| """a dict can treat keys as attributes""" | |||
| @@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric | |||
| from fastNLP.core.metrics import _pred_topk, _accuracy_topk | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from collections import Counter | |||
| from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric | |||
| from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric,ConfusionMatrixMetric | |||
| def _generate_tags(encoding_type, number_labels=4): | |||
| @@ -44,6 +44,141 @@ def _convert_res_to_fastnlp_res(metric_result): | |||
| allen_result[key] = round(value, 6) | |||
| return allen_result | |||
| class TestConfusionMatrixMetric(unittest.TestCase): | |||
| def test_ConfusionMatrixMetric1(self): | |||
| pred_dict = {"pred": torch.zeros(4,3)} | |||
| target_dict = {'target': torch.zeros(4)} | |||
| metric = ConfusionMatrixMetric() | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| print(metric.get_metric()) | |||
| def test_ConfusionMatrixMetric2(self): | |||
| # (2) with corrupted size | |||
| try: | |||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
| target_dict = {'target': torch.zeros(4)} | |||
| metric = ConfusionMatrixMetric() | |||
| metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
| print(metric.get_metric()) | |||
| except Exception as e: | |||
| print(e) | |||
| return | |||
| print("No exception catches.") | |||
| def test_ConfusionMatrixMetric3(self): | |||
| # (3) the second batch is corrupted size | |||
| try: | |||
| metric = ConfusionMatrixMetric() | |||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
| target_dict = {'target': torch.zeros(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
| target_dict = {'target': torch.zeros(4)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| print(metric.get_metric()) | |||
| except Exception as e: | |||
| print(e) | |||
| return | |||
| assert(True, False), "No exception catches." | |||
| def test_ConfusionMatrixMetric4(self): | |||
| # (4) check reset | |||
| metric = ConfusionMatrixMetric() | |||
| pred_dict = {"pred": torch.randn(4, 3, 2)} | |||
| target_dict = {'target': torch.ones(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| res = metric.get_metric() | |||
| self.assertTrue(isinstance(res, dict)) | |||
| print(res) | |||
| def test_ConfusionMatrixMetric5(self): | |||
| # (5) check numpy array is not acceptable | |||
| try: | |||
| metric = ConfusionMatrixMetric() | |||
| pred_dict = {"pred": np.zeros((4, 3, 2))} | |||
| target_dict = {'target': np.zeros((4, 3))} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| except Exception as e: | |||
| print(e) | |||
| return | |||
| self.assertTrue(True, False), "No exception catches." | |||
| def test_ConfusionMatrixMetric6(self): | |||
| # (6) check map, match | |||
| metric = ConfusionMatrixMetric(pred='predictions', target='targets') | |||
| pred_dict = {"predictions": torch.randn(4, 3, 2)} | |||
| target_dict = {'targets': torch.zeros(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| res = metric.get_metric() | |||
| print(res) | |||
| def test_ConfusionMatrixMetric7(self): | |||
| # (7) check map, include unused | |||
| try: | |||
| metric = ConfusionMatrixMetric(pred='prediction', target='targets') | |||
| pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} | |||
| target_dict = {'targets': torch.zeros(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| except Exception as e: | |||
| print(e) | |||
| return | |||
| self.assertTrue(True, False), "No exception catches." | |||
| def test_ConfusionMatrixMetric8(self): | |||
| # (8) check _fast_metric | |||
| try: | |||
| metric = ConfusionMatrixMetric() | |||
| pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3} | |||
| target_dict = {'targets': torch.zeros(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| print(metric.get_metric()) | |||
| except Exception as e: | |||
| print(e) | |||
| return | |||
| self.assertTrue(True, False), "No exception catches." | |||
| def test_duplicate(self): | |||
| # 0.4.1的潜在bug,不能出现形参重复的情况 | |||
| metric = ConfusionMatrixMetric(pred='predictions', target='targets') | |||
| pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0} | |||
| target_dict = {'targets':torch.zeros(4, 3), 'target': 0} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| print(metric.get_metric()) | |||
| def test_seq_len(self): | |||
| N = 256 | |||
| seq_len = torch.zeros(N).long() | |||
| seq_len[0] = 2 | |||
| pred = {'pred': torch.ones(N, 2)} | |||
| target = {'target': torch.ones(N, 2), 'seq_len': seq_len} | |||
| metric = ConfusionMatrixMetric() | |||
| metric(pred_dict=pred, target_dict=target) | |||
| metric.get_metric(reset=False) | |||
| seq_len[1:] = 1 | |||
| metric(pred_dict=pred, target_dict=target) | |||
| metric.get_metric() | |||
| def test_vocab(self): | |||
| vocab = Vocabulary() | |||
| word_list = "this is a word list".split() | |||
| vocab.update(word_list) | |||
| pred_dict = {"pred": torch.zeros(4,3)} | |||
| target_dict = {'target': torch.zeros(4)} | |||
| metric = ConfusionMatrixMetric(vocab=vocab) | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| print(metric.get_metric()) | |||
| class TestAccuracyMetric(unittest.TestCase): | |||
| def test_AccuracyMetric1(self): | |||
| # (1) only input, targets passed | |||
| @@ -133,7 +268,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| def test_AccuaryMetric8(self): | |||
| try: | |||
| metric = AccuracyMetric(pred='predictions', target='targets') | |||
| pred_dict = {"prediction": torch.zeros(4, 3, 2)} | |||
| pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||
| target_dict = {'targets': torch.zeros(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
| self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||