|
|
|
@@ -31,7 +31,7 @@ class ROC(Metric): |
|
|
|
range [0,num_classes-1]. Default: None. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> 1) binary classification example |
|
|
|
>>> # 1) binary classification example |
|
|
|
>>> x = Tensor(np.array([3, 1, 4, 2])) |
|
|
|
>>> y = Tensor(np.array([0, 1, 2, 3])) |
|
|
|
>>> metric = ROC(pos_label=2) |
|
|
|
@@ -42,7 +42,7 @@ class ROC(Metric): |
|
|
|
[0., 1, 1., 1., 1.] |
|
|
|
[5, 4, 3, 2, 1] |
|
|
|
>>> |
|
|
|
>>> 2) multiclass classification example |
|
|
|
>>> # 2) multiclass classification example |
|
|
|
>>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05], |
|
|
|
... [0.05, 0.05, 0.05, 0.75]])) |
|
|
|
>>> y = Tensor(np.array([0, 1, 2, 3])) |
|
|
|
@@ -101,6 +101,7 @@ class ROC(Metric): |
|
|
|
def update(self, *inputs): |
|
|
|
""" |
|
|
|
Update state with predictions and targets. |
|
|
|
|
|
|
|
Args: |
|
|
|
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. |
|
|
|
In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]` |
|
|
|
|