mindspore.nn.Metric
====================
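
.. py:class:: mindspore.nn.Metric

    Base class for computing evaluation metrics.

    Evaluating a metric requires calling its three methods `clear`, `update`, and `eval`, and a custom metric that inherits from this class must implement all three. `update` computes the intermediate internal results during evaluation, `eval` computes the final evaluation result, and `clear` resets the intermediate results.

    Do not use this class directly; use one of its subclasses instead, such as :class:`mindspore.nn.MAE` or :class:`mindspore.nn.Recall`.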
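
    The call sequence described above can be illustrated with a minimal pure-Python sketch. It does not use MindSpore: `MetricSketch` and `MAESketch` are hypothetical names, and modelling the "must override" requirement with Python's standard `abc` module is an assumption made for illustration, not a description of how `mindspore.nn.Metric` is implemented.

    .. code-block:: python

        from abc import ABC, abstractmethod


        class MetricSketch(ABC):
            """Hypothetical stand-in for a metric base class (not mindspore.nn.Metric)."""

            @abstractmethod
            def clear(self):
                """Reset the intermediate results."""

            @abstractmethod
            def update(self, *inputs):
                """Accumulate intermediate results from one batch of inputs."""

            @abstractmethod
            def eval(self):
                """Compute the final result from the accumulated state."""


        class MAESketch(MetricSketch):
            """Hypothetical mean absolute error metric over plain Python lists."""

            def __init__(self):
                self.clear()

            def clear(self):
                # Intermediate state: running sum of absolute errors and sample count.
                self._abs_error_sum = 0.0
                self._samples = 0

            def update(self, *inputs):
                # This sketch expects exactly (predictions, labels) of equal length.
                if len(inputs) != 2:
                    raise ValueError("update expects (predictions, labels)")
                preds, labels = inputs
                if len(preds) != len(labels):
                    raise ValueError("predictions and labels differ in length")
                for pred, label in zip(preds, labels):
                    self._abs_error_sum += abs(pred - label)
                    self._samples += 1

            def eval(self):
                if self._samples == 0:
                    raise RuntimeError("update must be called before eval")
                return self._abs_error_sum / self._samples


        metric = MAESketch()
        metric.clear()                          # start from an empty state
        metric.update([1.0, 2.0], [1.5, 2.0])   # batch 1: errors 0.5 and 0.0
        metric.update([3.0], [1.0])             # batch 2: error 2.0
        print(metric.eval())                    # (0.5 + 0.0 + 2.0) / 3

    Because `update` only accumulates a running sum and a sample count, `eval` averages over all three samples rather than averaging the two per-batch results, and `clear` starts a new evaluation from an empty state. Instantiating `MetricSketch` itself raises `TypeError`, which mirrors the advice to use subclasses rather than the base class.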

    .. py:method:: clear()
        :abstract:

        Clears the internal evaluation results.

        .. note::
            All subclasses must override this interface.

    .. py:method:: eval()
        :abstract:

        Computes the final evaluation result.

        .. note::
            All subclasses must override this interface.

    .. py:method:: indexes
        :property:

        Gets the current value of `indexes`. The default is None; the value can be changed with `set_indexes`.
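
    .. py:method:: set_indexes(indexes)

        Sets the `indexes` used to rearrange the inputs of `update`.

        For example, if the inputs to `update` are (label0, label1, logits) and `indexes` is set to [2, 1], then (logits, label1) are used as the actual inputs of `update`.

        .. note::
            When inheriting from this class to define a custom metric, decorate the `update` method with `mindspore.nn.rearrange_inputs`; otherwise the configured `indexes` value has no effect.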

        **Parameters:**

        **indexes** (List(int)) - The target order of the logits and labels.

        **Returns:**

        :class:`Metric`, the instance itself.

        **Examples:**

        >>> import numpy as np
        >>> from mindspore import nn, Tensor
        >>>
        >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
        >>> y = Tensor(np.array([1, 0, 1]))
        >>> y2 = Tensor(np.array([0, 0, 1]))
        >>> metric = nn.Accuracy('classification').set_indexes([0, 2])
        >>> metric.clear()
        >>> # indexes is [0, 2], so x is used as the predictions and y2 as the ground-truth labels
        >>> metric.update(x, y, y2)
        >>> accuracy = metric.eval()
        >>> print(accuracy)
        0.3333333333333333
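
        The rearrangement can be reproduced without MindSpore in the following sketch, which mirrors the example above using plain Python lists. `rearrange_inputs_sketch` and `AccuracySketch` are hypothetical names; the decorator only imitates the documented behavior of `mindspore.nn.rearrange_inputs` (forward the `update` arguments selected by `indexes`, in that order) and is not its actual implementation.

        .. code-block:: python

            import functools


            def rearrange_inputs_sketch(update):
                """Hypothetical decorator: reorder update() arguments using self.indexes."""

                @functools.wraps(update)
                def wrapper(self, *inputs):
                    indexes = self.indexes
                    if indexes is not None:
                        # Keep only the inputs named by indexes, in the order given.
                        inputs = [inputs[i] for i in indexes]
                    return update(self, *inputs)

                return wrapper


            class AccuracySketch:
                """Hypothetical classification accuracy over plain Python lists."""

                def __init__(self):
                    self._indexes = None
                    self.clear()

                @property
                def indexes(self):
                    return self._indexes

                def set_indexes(self, indexes):
                    self._indexes = indexes
                    return self  # return the instance so the call can be chained

                def clear(self):
                    self._correct = 0
                    self._total = 0

                @rearrange_inputs_sketch
                def update(self, logits, labels):
                    for row, label in zip(logits, labels):
                        # Predicted class = index of the largest score in the row.
                        pred = max(range(len(row)), key=row.__getitem__)
                        self._correct += int(pred == label)
                        self._total += 1

                def eval(self):
                    return self._correct / self._total


            x = [[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]
            y = [1, 0, 1]
            y2 = [0, 0, 1]
            metric = AccuracySketch().set_indexes([0, 2])
            metric.clear()
            metric.update(x, y, y2)  # indexes [0, 2] forwards (x, y2) to update
            print(metric.eval())

        With inputs (label0, label1, logits) and `indexes` set to [2, 1], the same wrapper would forward (logits, label1). The decorator is what makes `indexes` take effect: without it, `update(x, y, y2)` would fail, because the undecorated `update` accepts only two data arguments.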
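
    .. py:method:: update(*inputs)
        :abstract:

        Updates the internal evaluation results.

        .. note::
            All subclasses must override this interface.

        **Parameters:**

        **inputs** - A variable-length list of inputs, usually the predictions and the corresponding ground-truth labels.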