You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.Metric.rst 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. mindspore.nn.Metric
  2. ====================
  3. .. py:class:: mindspore.nn.Metric
  4. 用于计算评估指标的基类。
  5. 在计算评估指标时需要调用 `clear` 、 `update` 和 `eval` 三个方法,在继承该类自定义评估指标时,也需要实现这三个方法。其中,`update` 用于计算中间过程的内部结果,`eval` 用于计算最终评估结果,`clear` 用于重置中间结果。
  6. 请勿直接使用该类,需使用子类如 :class:`mindspore.nn.MAE` 、 :class:`mindspore.nn.Recall` 等。
  7. .. py:method:: clear()
  8. :abstractmethod:
  9. 清除内部评估结果。
  10. .. note::
  11. 所有子类都必须重写此接口。
  12. .. py:method:: eval()
  13. :abstractmethod:
  14. 计算最终评估结果。
  15. .. note::
  16. 所有子类都必须重写此接口。
  17. .. py:method:: indexes
  18. :property:
  19. 获取当前的 `indexes` 值。默认为None,调用 `set_indexes` 方法可修改 `indexes` 值。
  20. .. py:method:: set_indexes(indexes)
  21. 该接口用于重排 `update` 的输入。
  22. 给定(label0, label1, logits)作为 `update` 的输入,将 `indexes` 设置为[2, 1],则最终使用(logits, label1)作为 `update` 的真实输入。
  23. .. note::
  24. 在继承该类自定义评估函数时,需要用装饰器 `mindspore.nn.rearrange_inputs` 修饰 `update` 方法,否则配置的 `indexes` 值不生效。
  25. **参数:**
  26. - **indexes** (List(int)) - logits和标签的目标顺序。
  27. **输出:**
  28. :class:`Metric` ,类实例本身。
  29. **异常:**
  30. - **ValueError** - 如果输入的index类型不是list或其元素类型不全为int。
  31. **样例:**
  32. >>> import numpy as np
  33. >>> from mindspore import nn, Tensor
  34. >>>
  35. >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
  36. >>> y = Tensor(np.array([1, 0, 1]))
  37. >>> y2 = Tensor(np.array([0, 0, 1]))
  38. >>> metric = nn.Accuracy('classification').set_indexes([0, 2])
  39. >>> metric.clear()
  40. >>> # indexes为[0, 2],使用x作为预测值,y2作为真实标签
  41. >>> metric.update(x, y, y2)
  42. >>> accuracy = metric.eval()
  43. >>> print(accuracy)
  44. 0.3333333333333333
  45. .. py:method:: update(*inputs)
  46. :abstractmethod:
  47. 更新内部评估结果。
  48. .. note::
  49. 所有子类都必须重写此接口。
  50. **参数:**
  51. - **inputs** - 可变长度输入参数列表。通常是预测值和对应的真实标签。