You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.HShrink.rst 1.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. mindspore.nn.HShrink
  2. =============================
  3. .. py:class:: mindspore.nn.HShrink(lambd=0.5)
  4. 按元素计算Hard Shrink函数,公式定义如下:
  5. .. math::
  6. \text{HardShrink}(x) =
  7. \begin{cases}
  8. x, & \text{ if } x > \lambda \\
  9. x, & \text{ if } x < -\lambda \\
  10. 0, & \text{ otherwise }
  11. \end{cases}
  12. **参数:**
  13. **lambd** (float) - Hard Shrink公式定义的阈值。默认值:0.5。
  14. **输入:**
  15. - **input_x** (Tensor) - Hard Shrink的输入,数据类型为float16或float32。
  16. **输出:**
  17. Tensor,shape和数据类型与输入相同。
  18. **支持平台:**
  19. ``Ascend``
  20. **异常:**
  21. - **TypeError** - `lambd` 不是float。
  22. - **TypeError** - `input_x` 的dtype既不是float16也不是float32。
  23. **样例:**
  24. >>> input_x = Tensor(np.array([[ 0.5, 1, 2.0],[0.0533,0.0776,-2.1233]]),mstype.float32)
  25. >>> hshrink = nn.HShrink()
  26. >>> output = hshrink(input_x)
  27. >>> print(output)
  28. [[ 0. 1. 2. ]
  29. [ 0. 0. -2.1233]]