You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.MeanSurfaceDistance.rst 2.7 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. mindspore.nn.MeanSurfaceDistance
  2. ===============================================
  3. .. py:class:: mindspore.nn.MeanSurfaceDistance(symmetric=False, distance_metric='euclidean')
  4. 计算从 `y_pred` 到 `y` 的平均表面距离。通常情况下,用来衡量分割任务中,预测情况和真实情况之间的差异度。
  5. 给定两个集合A和B,S(A)表示A的表面像素,任意v到S(A)的最短距离定义为:
  6. .. math::
  7. {\text{dis}}\left (v, S(A)\right ) = \underset{s_{A} \in S(A)}{\text{min }}\rVert v - s_{A} \rVert
  8. 从集合B到集合A的平均表面距离(Average Surface Distance)为:
  9. .. math::
  10. AvgSurDis(B \rightarrow A) = \frac{\sum_{s_{B} \in S(B)}^{} {\text{dis} \left
  11. ( s_{B}, S(A) \right )} } {\left | S(B) \right |}
  12. 其中 \|\|\*\|\| 表示距离度量。 \|\*\| 表示元素的数量。
  13. 从集合B到集合A以及从集合A到集合B的表面距离平均值为:
  14. .. math::
  15. MeanSurDis(A \leftrightarrow B) = \frac{\sum_{s_{A} \in S(A)}^{} {\text{dis} \left ( s_{A}, S(B) \right )}
  16. + \sum_{s_{B} \in S(B)}^{} {\text{dis} \left ( s_{B}, S(A) \right )} }{\left | S(A) \right | +
  17. \left | S(B) \right |}
  18. **参数:**
  19. - **distance_metric** (string) - 支持如下三种距离计算方法:"euclidean"、"chessboard"或"taxicab"。默认值:"euclidean"。
  20. - **symmetric** (bool) - 是否计算 `y_pred` 和 `y` 之间的对称平均平面距离。如果为False,计算方式为 :math:`AvgSurDis(y_pred\rightarrow y)` , 如果为True,计算方式为 :math:`MeanSurDis(y_pred \leftrightarrow y)` 。默认值:False。
  21. .. py:method:: clear()
  22. 内部评估结果清零。
  23. .. py:method:: eval()
  24. 计算平均表面距离。
  25. **返回:**
  26. numpy.float64,计算得到的平均表面距离值。
  27. **异常:**
  28. - **RuntimeError** - 如果没有先调用update方法。
  29. .. py:method:: update(*inputs)
  30. 使用 `y_pred`、`y` 和 `label_idx` 更新内部评估结果。
  31. **参数:**
  32. - **inputs** - `y_pred`、`y` 和 `label_idx`。`y_pred` 和 `y` 为Tensor,list或numpy.ndarray,`y_pred` 是预测的二值图像。`y` 是实际的二值图像。`label_idx` 数据类型为int或float,表示像素点的类别值。
  33. **异常:**
  34. - **ValueError** - 输入的数量不等于3。
  35. - **TypeError** - `label_idx` 的数据类型不是int或float。
  36. - **ValueError** - `label_idx` 的值不在y_pred或y中。
  37. - **ValueError** - `y_pred` 和 `y` 的shape不同。