You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.Tril.rst 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. mindspore.nn.Tril
  2. =================
  3. .. py:class:: mindspore.nn.Tril
  4. 返回一个Tensor,指定主对角线以上的元素被置为零。
  5. 将矩阵元素沿主对角线分为上三角和下三角(包含对角线)。
  6. 参数 `k` 控制对角线的选择。若 `k` 为0,则沿主对角线分割并保留下三角所有元素。若 `k` 为正值,则沿主对角线向上选择对角线 `k` ,并保留下三角所有元素。若 `k` 为负值,则沿主对角线向下选择对角线 `k` ,并保留下三角所有元素。
  7. **输入:**
  8. - **x** (Tensor):输入Tensor。数据类型为`number <https://www.mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.html#mindspore.dtype>`_。
  9. - **k** (int):对角线的索引。默认值:0。假设输入的矩阵的维度分别为d1,d2,则k的范围应在[-min(d1, d2)+1, min(d1, d2)-1],超出该范围时输出值与输入 `x` 一致。
  10. **输出:**
  11. Tensor,数据类型和shape与 `x` 相同。
  12. **异常:**
  13. - **TypeError:** `k` 不是int。
  14. - **ValueError:** `x` 的维度小于1。
  15. **支持平台:**
  16. ``Ascend`` ``GPU`` ``CPU``
  17. **样例:**
  18. >>> # case1: k = 0
  19. >>> x = Tensor(np.array([[ 1, 2, 3, 4],
  20. ... [ 5, 6, 7, 8],
  21. ... [10, 11, 12, 13],
  22. ... [14, 15, 16, 17]]))
  23. >>> tril = nn.Tril()
  24. >>> result = tril(x)
  25. >>> print(result)
  26. [[ 1 0 0 0]
  27. [ 5 6 0 0]
  28. [10 11 12 0]
  29. [14 15 16 17]]
  30. >>> # case2: k = 1
  31. >>> x = Tensor(np.array([[ 1, 2, 3, 4],
  32. ... [ 5, 6, 7, 8],
  33. ... [10, 11, 12, 13],
  34. ... [14, 15, 16, 17]]))
  35. >>> tril = nn.Tril()
  36. >>> result = tril(x, 1)
  37. >>> print(result)
  38. [[ 1 2 0 0]
  39. [ 5 6 7 0]
  40. [10 11 12 13]
  41. [14 15 16 17]]
  42. >>> # case3: k = 2
  43. >>> x = Tensor(np.array([[ 1, 2, 3, 4],
  44. ... [ 5, 6, 7, 8],
  45. ... [10, 11, 12, 13],
  46. ... [14, 15, 16, 17]]))
  47. >>> tril = nn.Tril()
  48. >>> result = tril(x, 2)
  49. >>> print(result)
  50. [[ 1 2 3 0]
  51. [ 5 6 7 8]
  52. [10 11 12 13]
  53. [14 15 16 17]]
  54. >>> # case4: k = -1
  55. >>> x = Tensor(np.array([[ 1, 2, 3, 4],
  56. ... [ 5, 6, 7, 8],
  57. ... [10, 11, 12, 13],
  58. ... [14, 15, 16, 17]]))
  59. >>> tril = nn.Tril()
  60. >>> result = tril(x, -1)
  61. >>> print(result)
  62. [[ 0 0 0 0]
  63. [ 5 0 0 0]
  64. [10 11 0 0]
  65. [14 15 16 0]]