You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.Precision.txt 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. Class mindspore.nn.Precision(eval_type='classification')
  2. 计算'classification'单标签数据分类和'multilabel'多标签数据分类的精度。
  3. 此函数创建两个局部变量:math:`\text{true_positive}`和:math:`\text{false_positive}` 用于计算精度。计算方式为:math:`\text{true_positive}`除以:math:`\text{true_positive}`与:math:`\text{false_positive}`的和,是一个幂等操作,此值最终作为精度返回。
  4. .. math::
  5. \text{precision} = \frac{\text{true_positive}}{\text{true_positive} + \text{false_positive}}
  6. 注:
  7. 在多标签情况下,:math:`y`和:math:`y_{pred}`的元素必须为0或1。
  8. 参数:
  9. eval_type (str):支持'classification'和'multilabel'。默认值:'classification'。
  10. 示例:
  11. >>> import numpy as np
  12. >>> from mindspore import nn, Tensor
  13. >>>
  14. >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
  15. >>> y = Tensor(np.array([1, 0, 1]))
  16. >>> metric = nn.Precision('classification')
  17. >>> metric.clear()
  18. >>> metric.update(x, y)
  19. >>> precision = metric.eval()
  20. >>> print(precision)
  21. [0.5 1. ]
  22. clear()
  23. 内部评估结果清零。
  24. eval(average=False)
  25. 计算精度。
  26. 参数:
  27. average (bool):指定是否计算平均精度。默认值:False。
  28. 返回:
  29. numpy.float64,计算结果。
  30. update(*inputs)
  31. 使用预测值`y_pred`和真实标签`y`更新局部变量。
  32. 参数:
  33. inputs:输入`y_pred`和`y`。`y_pred` 和 `y` 支持Tensor、list或numpy.ndarray类型。
  34. 对于'classification'情况,`y_pred`在大多数情况下由范围:math:`[0, 1]`中的浮点数组成,shape为:math:`(N, C)`,其中:math:`N`是样本数,:math:`C`是类别数。
  35. `y` 由整数值组成,如果是one_hot编码格式,shape是:math:`(N,C)`;如果是类别索引,shape是:math:`(N,)`。
  36. 对于'multilabel'情况,`y_pred`和`y`只能是值为0或1的one-hot编码格式,其中值为1的索引表示正类别。`y_pred`和`y`的shape都是:math:`(N,C)`。
  37. 异常:
  38. ValueError:inputs数量不是2。