You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_stridedslice_op.py 6.8 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright 2019-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. from mindspore import Tensor
  19. from mindspore.ops import operations as P
  # Shared StridedSlice checks; each per-dtype test function calls this once.
def strided_slice(nptype):
    """Exercise P.StridedSlice (and Tensor slicing) on GPU for one element dtype.

    Every case slices an ``np.arange`` tensor cast to ``nptype`` and compares
    the result with a hand-written golden array cast to the same dtype.
    Wrap-around in narrow integer types, and the collapse to True/False for
    bool, therefore happen identically on both sides.

    Args:
        nptype: NumPy scalar type used for the input tensors and the goldens.

    Raises:
        AssertionError: if any slice differs from its golden in shape or value.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    # Input of shape (2, 3, 4, 5) holding 0..119.
    x = Tensor(np.arange(0, 2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(nptype))

    # Forward slice, NumPy equivalent: x[1:2, 0:2, 0:2, 2:4].
    y = P.StridedSlice()(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
    expect = np.array([[[[62, 63],
                         [67, 68]],
                        [[82, 83],
                         [87, 88]]]]).astype(nptype)
    # assert_array_equal checks shape as well as values. np.allclose would
    # broadcast, so it could accept a result missing a leading size-1 axis.
    np.testing.assert_array_equal(y.asnumpy(), expect)

    # Negative stride with begin past the end, NumPy equivalent:
    # x[1:2, 0:2, 0:2, 5:1:-2].
    y = P.StridedSlice()(x, (1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2))
    expect = np.array([[[[64, 62],
                         [69, 67]],
                        [[84, 82],
                         [89, 87]]]]).astype(nptype)
    np.testing.assert_array_equal(y.asnumpy(), expect)

    # Negative begin with negative stride, NumPy equivalent:
    # x[1:2, 0:2, 0:2, -1:1:-1].
    y = P.StridedSlice()(x, (1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1))
    expect = np.array([[[[64, 63, 62],
                         [69, 68, 67]],
                        [[84, 83, 82],
                         [89, 88, 87]]]]).astype(nptype)
    np.testing.assert_array_equal(y.asnumpy(), expect)

    # Negative strides on two axes, NumPy equivalent:
    # x[1:2, 0:2, -1:0:-1, -2:-5:-2].
    y = P.StridedSlice()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2))
    expect = np.array([[[[78, 76],
                         [73, 71],
                         [68, 66]],
                        [[98, 96],
                         [93, 91],
                         [88, 86]]]]).astype(nptype)
    np.testing.assert_array_equal(y.asnumpy(), expect)

    # TODO: a case with begin_mask=0b1000 and end_mask=0b0010 but no
    # ellipsis_mask was disabled upstream ("ME Infer fault"). Re-add it once
    # inference is fixed. NOTE(review): its recorded golden appeared to ignore
    # begin_mask (it kept columns 2..3), so recompute it before re-enabling.

    # Masks plus ellipsis. The golden below corresponds to x[1:2, :, :, 0:4].
    op = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100)
    y = op(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
    expect = np.array([[[[60, 61, 62, 63],
                         [65, 66, 67, 68],
                         [70, 71, 72, 73],
                         [75, 76, 77, 78]],
                        [[80, 81, 82, 83],
                         [85, 86, 87, 88],
                         [90, 91, 92, 93],
                         [95, 96, 97, 98]],
                        [[100, 101, 102, 103],
                         [105, 106, 107, 108],
                         [110, 111, 112, 113],
                         [115, 116, 117, 118]]]]).astype(nptype)
    np.testing.assert_array_equal(y.asnumpy(), expect)

    # 3-D input with a negative end and a stride larger than the extent,
    # NumPy equivalent: x[1:2, 0:-3, 0:3:3].
    x = Tensor(np.arange(0, 3 * 4 * 5).reshape(3, 4, 5).astype(nptype))
    y = P.StridedSlice()(x, (1, 0, 0), (2, -3, 3), (1, 1, 3))
    expect = np.array([[[20]]]).astype(nptype)
    np.testing.assert_array_equal(y.asnumpy(), expect)

    # Python slice syntax on a Tensor, checked against NumPy itself.
    x_np = np.arange(0, 4 * 5).reshape(4, 5).astype(nptype)
    y = Tensor(x_np)[:, ::-1]
    expect = x_np[:, ::-1]
    np.testing.assert_array_equal(y.asnumpy(), expect)

    # 7-D input, NumPy equivalent: x[1:2, 0:2, 0:2, 2:4, 1:2, 2:3, 0:2:2].
    x = Tensor(np.arange(0, 2 * 3 * 4 * 5 * 4 * 3 * 2).reshape(2, 3, 4, 5, 4, 3, 2).astype(nptype))
    y = P.StridedSlice()(x, (1, 0, 0, 2, 1, 2, 0), (2, 2, 2, 4, 2, 3, 2), (1, 1, 1, 1, 1, 1, 2))
    # Integer literals on purpose. Casting out-of-range floats (e.g. 1498.0)
    # to int8/uint8 is platform-dependent in NumPy, whereas an int -> int cast
    # wraps exactly like the input's astype(nptype).
    expect = np.array([[[[[[[1498]]],
                          [[[1522]]]],
                         [[[[1618]]],
                          [[[1642]]]]],
                        [[[[[1978]]],
                          [[[2002]]]],
                         [[[[2098]]],
                          [[[2122]]]]]]]).astype(nptype)
    np.testing.assert_array_equal(y.asnumpy(), expect)
  # float32 variant of the shared StridedSlice checks.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_float32():
    """Run every shared strided_slice case with float32 elements."""
    strided_slice(nptype=np.float32)
  # float16 variant of the shared StridedSlice checks.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_float16():
    """Run every shared strided_slice case with float16 elements."""
    strided_slice(nptype=np.float16)
  # int64 variant of the shared StridedSlice checks.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_int64():
    """Run every shared strided_slice case with int64 elements."""
    strided_slice(nptype=np.int64)
  # int32 variant of the shared StridedSlice checks.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_int32():
    """Run every shared strided_slice case with int32 elements."""
    strided_slice(nptype=np.int32)
  # int16 variant of the shared StridedSlice checks.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_int16():
    """Run every shared strided_slice case with int16 elements."""
    strided_slice(nptype=np.int16)
  # int8 variant of the shared StridedSlice checks.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_int8():
    """Run every shared strided_slice case with int8 elements."""
    strided_slice(nptype=np.int8)
  # uint64 variant of the shared StridedSlice checks.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_uint64():
    """Run every shared strided_slice case with uint64 elements."""
    strided_slice(nptype=np.uint64)
  # uint32 variant of the shared StridedSlice checks.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_uint32():
    """Run every shared strided_slice case with uint32 elements."""
    strided_slice(nptype=np.uint32)
  # uint16 variant of the shared StridedSlice checks.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_uint16():
    """Run every shared strided_slice case with uint16 elements."""
    strided_slice(nptype=np.uint16)
  # uint8 variant of the shared StridedSlice checks.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_uint8():
    """Run every shared strided_slice case with uint8 elements."""
    strided_slice(nptype=np.uint8)
  # bool variant of the shared StridedSlice checks, plus a float32 clamping check.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_bool():
    """Run the shared strided_slice cases with bool data, then check clamping.

    ``np.bool`` (a deprecated alias of the builtin ``bool``) was removed in
    NumPy 1.24 and only re-added in 2.0. Using ``np.bool_`` keeps this test
    importable and runnable on every NumPy version.
    """
    strided_slice(np.bool_)

    # Slice bounds beyond the axis length are clamped, so x[-8:, :8] on a
    # (4, 4, 4) tensor selects the whole tensor.
    x = Tensor(np.arange(0, 4 * 4 * 4).reshape(4, 4, 4).astype(np.float32))
    y = x[-8:, :8]
    expect = np.arange(0, 4 * 4 * 4, dtype=np.float32).reshape(4, 4, 4)
    # assert_array_equal also verifies the shape, unlike np.allclose.
    np.testing.assert_array_equal(y.asnumpy(), expect)