Browse Source

!13456 multdiceloss api

From: @lijiaqi0612
Reviewed-by: @zh_qh,@kisnwang
Signed-off-by: @kisnwang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
fd811b4dd0
2 changed files with 5 additions and 4 deletions
  1. +2
    -2
      mindspore/nn/loss/loss.py
  2. +3
    -2
      mindspore/nn/metrics/roc.py

+ 2
- 2
mindspore/nn/loss/loss.py View File

@@ -490,9 +490,9 @@ class MultiClassDiceLoss(_Loss):

Inputs:
- **y_pred** (Tensor) - Tensor of shape (N, C, ...). The y_pred dimension should be greater than 1. The data
type must be float16 or float32.
type must be float16 or float32.
- **y** (Tensor) - Tensor of shape (N, C, ...). The y dimension should be greater than 1. The data type must be
float16 or float32.
loat16 or float32.

Outputs:
Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses.


+ 3
- 2
mindspore/nn/metrics/roc.py View File

@@ -31,7 +31,7 @@ class ROC(Metric):
range [0,num_classes-1]. Default: None.

Examples:
>>> 1) binary classification example
>>> # 1) binary classification example
>>> x = Tensor(np.array([3, 1, 4, 2]))
>>> y = Tensor(np.array([0, 1, 2, 3]))
>>> metric = ROC(pos_label=2)
@@ -42,7 +42,7 @@ class ROC(Metric):
[0., 1, 1., 1., 1.]
[5, 4, 3, 2, 1]
>>>
>>> 2) multiclass classification example
>>> # 2) multiclass classification example
>>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05],
... [0.05, 0.05, 0.05, 0.75]]))
>>> y = Tensor(np.array([0, 1, 2, 3]))
@@ -101,6 +101,7 @@ class ROC(Metric):
def update(self, *inputs):
"""
Update state with predictions and targets.

Args:
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray.
In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]`


Loading…
Cancel
Save