|
|
|
@@ -28,7 +28,7 @@ class ConfusionMatrix(Metric): |
|
|
|
|
|
|
|
Args: |
|
|
|
num_classes (int): Number of classes in the dataset. |
|
|
|
normalize (str): Normalization mode for confusion matrix. Choose from: |
|
|
|
normalize (str): Normalization mode for confusion matrix. Default: "no_norm". Choose from: |
|
|
|
|
|
|
|
- **'no_norm'** (None) - No Normalization is used. Default: None. |
|
|
|
- **'target'** (str) - Normalization based on target value. |
|
|
|
@@ -54,17 +54,11 @@ class ConfusionMatrix(Metric): |
|
|
|
[[1. 1.] |
|
|
|
[1. 1.]] |
|
|
|
""" |
|
|
|
TARGET = "target" |
|
|
|
PREDICTION = "prediction" |
|
|
|
ALL = "all" |
|
|
|
NO_NORM = "no_norm" |
|
|
|
|
|
|
|
def __init__(self, num_classes, normalize=NO_NORM, threshold=0.5): |
|
|
|
def __init__(self, num_classes, normalize="no_norm", threshold=0.5): |
|
|
|
super(ConfusionMatrix, self).__init__() |
|
|
|
|
|
|
|
self.num_classes = validator.check_value_type("num_classes", num_classes, [int]) |
|
|
|
if normalize != ConfusionMatrix.TARGET and normalize != ConfusionMatrix.PREDICTION and \ |
|
|
|
normalize != ConfusionMatrix.ALL and normalize is not ConfusionMatrix.NO_NORM: |
|
|
|
if normalize not in ["target", "prediction", "all", "no_norm"]: |
|
|
|
raise ValueError("For 'ConfusionMatrix', the argument 'normalize' should be in " |
|
|
|
"['all', 'prediction', 'label', 'no_norm'(None)], but got {}.".format(normalize)) |
|
|
|
|
|
|
|
@@ -90,7 +84,7 @@ class ConfusionMatrix(Metric): |
|
|
|
|
|
|
|
Raises: |
|
|
|
ValueError: If the number of inputs is not 2. |
|
|
|
ValueError: If the lengths of `candidate_corpus` and `reference_corpus` are not equal. |
|
|
|
ValueError: If the dim of y_pred and y are not equal. |
|
|
|
""" |
|
|
|
if len(inputs) != 2: |
|
|
|
raise ValueError("For 'ConfusionMatrix.update', it needs 2 inputs (predicted value, true value), " |
|
|
|
@@ -133,11 +127,11 @@ class ConfusionMatrix(Metric): |
|
|
|
matrix_target = confusion_matrix / confusion_matrix.sum(axis=1, keepdims=True) |
|
|
|
matrix_pred = confusion_matrix / confusion_matrix.sum(axis=0, keepdims=True) |
|
|
|
matrix_all = confusion_matrix / confusion_matrix.sum() |
|
|
|
normalize_dict = {ConfusionMatrix.TARGET: matrix_target, |
|
|
|
ConfusionMatrix.PREDICTION: matrix_pred, |
|
|
|
ConfusionMatrix.ALL: matrix_all} |
|
|
|
normalize_dict = {"target": matrix_target, |
|
|
|
"prediction": matrix_pred, |
|
|
|
"all": matrix_all} |
|
|
|
|
|
|
|
if self.normalize == ConfusionMatrix.NO_NORM: |
|
|
|
if self.normalize == "no_norm": |
|
|
|
return confusion_matrix |
|
|
|
|
|
|
|
matrix = normalize_dict.get(self.normalize) |
|
|
|
|