You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.Tril.rst 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. mindspore.nn.Tril
  2. =================
  3. .. py:class:: mindspore.nn.Tril
  4. 返回一个Tensor,其中第 `k` 个对角线以上的元素被置为零。
  5. 矩阵的下三角把矩阵分成对角线上和对角线下的元素。
  6. 参数 `k` 控制着矩阵的对角线。如果 `k` 为0,则保留主对角线上和下面的所有元素。正值包括主对角线上方尽可能多的对角线,类似地,负值排除主对角线下方尽可能多的对角线。
  7. **输入:**
  8. - **x** (Tensor):输入Tensor。数据类型为Number。shape为 :math:`(N,*)`,其中 :math:`*` 表示任意的附加维度数。
  9. - **k** (Int):对角线的索引。默认值:0。
  10. **输出:**
  11. Tensor,shape和数据类型与 `x` 相同。
  12. **异常:**
  13. - **TypeError:** `k` 不是int。
  14. - **ValueError:** `x` 的维度小于1。
  15. **支持平台:**
  16. ``Ascend`` ``GPU`` ``CPU``
  17. **样例:**
  18. >>> x = Tensor(np.array([[ 1, 2, 3, 4],
  19. ... [ 5, 6, 7, 8],
  20. ... [10, 11, 12, 13],
  21. ... [14, 15, 16, 17]]))
  22. >>> tril = nn.Tril()
  23. >>> result = tril(x)
  24. >>> print(result)
  25. [[ 1 0 0 0]
  26. [ 5 6 0 0]
  27. [10 11 12 0]
  28. [14 15 16 17]]
  29. >>> x = Tensor(np.array([[ 1, 2, 3, 4],
  30. ... [ 5, 6, 7, 8],
  31. ... [10, 11, 12, 13],
  32. ... [14, 15, 16, 17]]))
  33. >>> tril = nn.Tril()
  34. >>> result = tril(x, 1)
  35. >>> print(result)
  36. [[ 1 2 0 0]
  37. [ 5 6 7 0]
  38. [10 11 12 13]
  39. [14 15 16 17]]
  40. >>> x = Tensor(np.array([[ 1, 2, 3, 4],
  41. ... [ 5, 6, 7, 8],
  42. ... [10, 11, 12, 13],
  43. ... [14, 15, 16, 17]]))
  44. >>> tril = nn.Tril()
  45. >>> result = tril(x, 2)
  46. >>> print(result)
  47. [[ 1 2 3 0]
  48. [ 5 6 7 8]
  49. [10 11 12 13]
  50. [14 15 16 17]]
  51. >>> x = Tensor(np.array([[ 1, 2, 3, 4],
  52. ... [ 5, 6, 7, 8],
  53. ... [10, 11, 12, 13],
  54. ... [14, 15, 16, 17]]))
  55. >>> tril = nn.Tril()
  56. >>> result = tril(x, -1)
  57. >>> print(result)
  58. [[ 0 0 0 0]
  59. [ 5 0 0 0]
  60. [10 11 0 0]
  61. [14 15 16 0]]