Browse Source

!30643 modify cn api docs

Merge pull request !30643 from changzherui/code_docs_mod_cn_api_0228
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
5554b062e9
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 13 additions and 21 deletions
  1. +2
    -5
      docs/api/api_python/nn/mindspore.nn.ConfusionMatrix.rst
  2. +8
    -14
      mindspore/python/mindspore/nn/metrics/confusion_matrix.py
  3. +1
    -1
      mindspore/python/mindspore/nn/metrics/dice.py
  4. +1
    -1
      mindspore/python/mindspore/ops/operations/debug_ops.py
  5. +1
    -0
      mindspore/python/mindspore/train/callback/_loss_monitor.py

+ 2
- 5
docs/api/api_python/nn/mindspore.nn.ConfusionMatrix.rst View File

@@ -31,10 +31,6 @@ mindspore.nn.ConfusionMatrix
numpy.ndarray,计算的结果。
**异常:**
- **RuntimeError** - 没有先调用update方法。
.. py:method:: update(*inputs)
使用y_pred和y更新内部评估结果。
@@ -46,5 +42,6 @@ mindspore.nn.ConfusionMatrix
**异常:**
- **RuntimeError** - 没有先调用update方法。
- **ValueError** - 输入参数的数量不等于2。
- **ValueError** - 如果预测值和标签的维度不一致。

+ 8
- 14
mindspore/python/mindspore/nn/metrics/confusion_matrix.py View File

@@ -28,7 +28,7 @@ class ConfusionMatrix(Metric):

Args:
num_classes (int): Number of classes in the dataset.
normalize (str): Normalization mode for confusion matrix. Choose from:
normalize (str): Normalization mode for confusion matrix. Default: "no_norm". Choose from:

- **'no_norm'** (None) - No Normalization is used. Default: None.
- **'target'** (str) - Normalization based on target value.
@@ -54,17 +54,11 @@ class ConfusionMatrix(Metric):
[[1. 1.]
[1. 1.]]
"""
TARGET = "target"
PREDICTION = "prediction"
ALL = "all"
NO_NORM = "no_norm"

def __init__(self, num_classes, normalize=NO_NORM, threshold=0.5):
def __init__(self, num_classes, normalize="no_norm", threshold=0.5):
super(ConfusionMatrix, self).__init__()

self.num_classes = validator.check_value_type("num_classes", num_classes, [int])
if normalize != ConfusionMatrix.TARGET and normalize != ConfusionMatrix.PREDICTION and \
normalize != ConfusionMatrix.ALL and normalize is not ConfusionMatrix.NO_NORM:
if normalize not in ["target", "prediction", "all", "no_norm"]:
raise ValueError("For 'ConfusionMatrix', the argument 'normalize' should be in "
"['all', 'prediction', 'label', 'no_norm'(None)], but got {}.".format(normalize))

@@ -90,7 +84,7 @@ class ConfusionMatrix(Metric):

Raises:
ValueError: If the number of inputs is not 2.
ValueError: If the lengths of `candidate_corpus` and `reference_corpus` are not equal.
ValueError: If the dim of y_pred and y are not equal.
"""
if len(inputs) != 2:
raise ValueError("For 'ConfusionMatrix.update', it needs 2 inputs (predicted value, true value), "
@@ -133,11 +127,11 @@ class ConfusionMatrix(Metric):
matrix_target = confusion_matrix / confusion_matrix.sum(axis=1, keepdims=True)
matrix_pred = confusion_matrix / confusion_matrix.sum(axis=0, keepdims=True)
matrix_all = confusion_matrix / confusion_matrix.sum()
normalize_dict = {ConfusionMatrix.TARGET: matrix_target,
ConfusionMatrix.PREDICTION: matrix_pred,
ConfusionMatrix.ALL: matrix_all}
normalize_dict = {"target": matrix_target,
"prediction": matrix_pred,
"all": matrix_all}

if self.normalize == ConfusionMatrix.NO_NORM:
if self.normalize == "no_norm":
return confusion_matrix

matrix = normalize_dict.get(self.normalize)


+ 1
- 1
mindspore/python/mindspore/nn/metrics/dice.py View File

@@ -73,7 +73,7 @@ class Dice(Metric):

Raises:
ValueError: If the number of the inputs is not 2.
RuntimeError: If y_pred and y do not have the same shape.
ValueError: If y_pred and y do not have the same shape.
"""
if len(inputs) != 2:
raise ValueError("For 'Dice.update', it needs 2 inputs (predicted value, true value), "


+ 1
- 1
mindspore/python/mindspore/ops/operations/debug_ops.py View File

@@ -420,7 +420,7 @@ class Print(PrimitiveWithInfer):
In pynative mode, please use python print function.
In graph mode, the bool, int and float would be converted into Tensor to print,
str remains unchanged.
This function is used for debug. When too many print data at the same time,
This function is used for debugging. When too much data is printed at the same time,
in order not to affect the main process, the framework may discard some data. At this time,
if you need to record the data completely, you can recommended to use the `Summary` function. Please check
`Summary <https://www.mindspore.cn/mindinsight/docs/zh-CN/master/summary_record.html?highlight=summary#>`_.


+ 1
- 0
mindspore/python/mindspore/train/callback/_loss_monitor.py View File

@@ -38,6 +38,7 @@ class LossMonitor(Callback):

Raises:
ValueError: If per_print_times is not an integer or less than zero.
ValueError: If has_trained_epoch is not an integer or less than zero.

Examples:
>>> from mindspore import Model, nn


Loading…
Cancel
Save