From c8e866959f7c1a9be155e510c71b59195777e394 Mon Sep 17 00:00:00 2001 From: Jiaqi Date: Wed, 17 Mar 2021 10:40:18 +0800 Subject: [PATCH] multdiceloss api --- mindspore/nn/loss/loss.py | 4 ++-- mindspore/nn/metrics/roc.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index d9b39b8674..836a422422 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -490,9 +490,9 @@ class MultiClassDiceLoss(_Loss): Inputs: - **y_pred** (Tensor) - Tensor of shape (N, C, ...). The y_pred dimension should be greater than 1. The data - type must be float16 or float32. + type must be float16 or float32. - **y** (Tensor) - Tensor of shape (N, C, ...). The y dimension should be greater than 1. The data type must be - float16 or float32. + loat16 or float32. Outputs: Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses. diff --git a/mindspore/nn/metrics/roc.py b/mindspore/nn/metrics/roc.py index 5efba74097..29ecccbca9 100644 --- a/mindspore/nn/metrics/roc.py +++ b/mindspore/nn/metrics/roc.py @@ -31,7 +31,7 @@ class ROC(Metric): range [0,num_classes-1]. Default: None. Examples: - >>> 1) binary classification example + >>> # 1) binary classification example >>> x = Tensor(np.array([3, 1, 4, 2])) >>> y = Tensor(np.array([0, 1, 2, 3])) >>> metric = ROC(pos_label=2) @@ -42,7 +42,7 @@ class ROC(Metric): [0., 1, 1., 1., 1.] [5, 4, 3, 2, 1] >>> - >>> 2) multiclass classification example + >>> # 2) multiclass classification example >>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05], ... [0.05, 0.05, 0.05, 0.75]])) >>> y = Tensor(np.array([0, 1, 2, 3])) @@ -101,6 +101,7 @@ class ROC(Metric): def update(self, *inputs): """ Update state with predictions and targets. + Args: inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]`