You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.ConfusionMatrix.rst 1.7 kB

4 years ago
4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. mindspore.nn.ConfusionMatrix
  2. ============================
  3. .. py:class:: mindspore.nn.ConfusionMatrix(num_classes, normalize='no_norm', threshold=0.5)
  4. 计算混淆矩阵(confusion matrix),通常用于评估分类模型的性能,包括二分类和多分类场景。
  5. 如果您只想使用混淆矩阵,请使用该类。如果想计算"PPV"、"TPR"、"TNR"等,请使用'mindspore.nn.ConfusionMatrixMetric'类。
  6. **参数:**
  7. - **num_classes** (int) - 数据集中的类别数量。
  8. - **normalize** (str) - 计算ConfsMatrix的参数支持四种归一化模式,默认值:None。
  9. - **"no_norm"** (None):不使用标准化。
  10. - **"target"** (str):基于目标值的标准化。
  11. - **"prediction"** (str):基于预测值的标准化。
  12. - **"all"** (str):整个矩阵的标准化。
  13. - **threshold** (float) - 阈值,用于与输入Tensor进行比较。默认值:0.5。
  14. .. py:method:: clear()
  15. 重置评估结果。
  16. .. py:method:: eval()
  17. 计算混淆矩阵。
  18. **返回:**
  19. numpy.ndarray,计算的结果。
  20. .. py:method:: update(*inputs)
  21. 使用y_pred和y更新内部评估结果。
  22. **参数:**
  23. - ***inputs** (tuple) - 输入 `y_pred` 和 `y` 。 `y_pred` 和 `y` 是 `Tensor` 、列表或数组。
  24. `y_pred` 是预测值, `y` 是真实值, `y_pred` 的shape是 :math:`(N, C, ...)` 或 :math:`(N, ...)` , `y` 的shape是 :math:`(N, ...)` 。
  25. **异常:**
  26. - **ValueError** - 输入参数的数量不等于2。
  27. - **ValueError** - 如果预测值和标签的维度不一致。