You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.RootMeanSquareDistance.rst 2.6 kB

4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. mindspore.nn.RootMeanSquareDistance
  2. ======================================
  3. .. py:class:: mindspore.nn.RootMeanSquareDistance(symmetric=False, distance_metric='euclidean')
  4. 计算从 `y_pred` 到 `y` 的均方根表面距离。
  5. 给定两个集合A和B,S(A)表示A的表面像素,任意v到S(A)的最短距离定义为:
  6. .. math::
  7. {\text{dis}}\left (v, S(A)\right ) = \underset{s_{A} \in S(A)}{\text{min }}\rVert v - s_{A} \rVert
  8. 从集合B到集合A的均方根表面距离(Root Mean Square Surface Distance)为:
  9. .. math::
  10. RmsSurDis(B \rightarrow A) = \sqrt{\frac{\sum_{s_{B} \in S(B)}^{} {\text{dis}^2 \left ( s_{B}, S(A)
  11. \right )} }{\left | S(B) \right |}}
  12. 其中 \|\|\*\|\| 表示距离度量。 \|\*\| 表示元素的数量。
  13. 从集合B到集合A以及从集合A到集合B的表面距离平均值为:
  14. .. math::
  15. RmsSurDis(A \leftrightarrow B) = \sqrt{\frac{\sum_{s_{A} \in S(A)}^{} {\text{dis} \left ( s_{A},
  16. S(B) \right ) ^{2}} + \sum_{s_{B} \in S(B)}^{} {\text{dis} \left ( s_{B}, S(A) \right ) ^{2}}}{\left | S(A)
  17. \right | + \left | S(B) \right |}}
  18. **参数:**
  19. - **distance_metric** (string) - 支持如下三种距离计算方法:"euclidean"、"chessboard" 或 "taxicab"。默认值:"euclidean"。
  20. - **symmetric** (bool) - 是否计算 `y_pred` 和 `y` 之间的对称平均平面距离。如果为False,计算方式为 :math:`RmsSurDis(y_{pred} , y)`, 如果为True,计算方式为 :math:`RmsSurDis(y_{pred} \leftrightarrow y)`。默认值:False。
  21. .. py:method:: clear()
  22. 内部评估结果清零。
  23. .. py:method:: eval()
  24. 计算均方根表面距离。
  25. **返回:**
  26. numpy.float64,计算得到的均方根表面距离值。
  27. **异常:**
  28. - **RuntimeError** - 如果没有先调用update方法,则会报错。
  29. .. py:method:: update(*inputs)
  30. 使用 `y_pred`、`y` 和 `label_idx` 更新内部评估结果。
  31. **参数:**
  32. - **inputs** - `y_pred`、`y` 和 `label_idx`。`y_pred` 和 `y` 为Tensor,list或numpy.ndarray,`y_pred` 是预测的二值图像。`y` 是实际的二值图像。`label_idx` 数据类型为int或float,表示像素点的类别值。
  33. **异常:**
  34. - **ValueError** - 输入的数量不等于3。
  35. - **TypeError** - `label_idx` 的数据类型不是int或float。
  36. - **ValueError** - `label_idx` 的值不在y_pred或y中。
  37. - **ValueError** - `y_pred` 和 `y` 的shape不同。