You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_stridedslice_op.py 4.3 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. from mindspore import Tensor
  19. from mindspore.ops import operations as P
  20. context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
  21. @pytest.mark.level0
  22. @pytest.mark.platform_x86_gpu_training
  23. @pytest.mark.env_onecard
  24. def test_stridedslice():
  25. x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(np.float32))
  26. y = P.StridedSlice()(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
  27. expect = np.array([[[[62, 63],
  28. [67, 68]],
  29. [[82, 83],
  30. [87, 88]]]])
  31. assert np.allclose(y.asnumpy(), expect)
  32. y = P.StridedSlice()(x, (1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2))
  33. expect = np.array([[[[64, 62],
  34. [69, 67]],
  35. [[84, 82],
  36. [89, 87]]]])
  37. assert np.allclose(y.asnumpy(), expect)
  38. y = P.StridedSlice()(x, (1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1))
  39. expect = np.array([[[[64, 63, 62],
  40. [69, 68, 67]],
  41. [[84, 83, 82],
  42. [89, 88, 87]]]])
  43. assert np.allclose(y.asnumpy(), expect)
  44. # ME infer fault
  45. # y = P.StridedSlice()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2))
  46. # expect = np.array([[[[78, 76],
  47. # [73, 71],
  48. # [68, 66]],
  49. # [[98, 96],
  50. # [93, 91],
  51. # [88, 86]]]])
  52. # assert np.allclose(y.asnumpy(), expect)
  53. # y = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
  54. # expect = np.array([[[[ 62, 63],
  55. # [ 67, 68]],
  56. # [[ 82, 83],
  57. # [ 87, 88]],
  58. # [[102, 103],
  59. # [107, 108]]]])
  60. # assert np.allclose(y.asnumpy(), expect)
  61. op = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100)
  62. y = op(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
  63. expect = np.array([[[[60, 61, 62, 63],
  64. [65, 66, 67, 68],
  65. [70, 71, 72, 73],
  66. [75, 76, 77, 78]],
  67. [[80, 81, 82, 83],
  68. [85, 86, 87, 88],
  69. [90, 91, 92, 93],
  70. [95, 96, 97, 98]],
  71. [[100, 101, 102, 103],
  72. [105, 106, 107, 108],
  73. [110, 111, 112, 113],
  74. [115, 116, 117, 118]]]])
  75. assert np.allclose(y.asnumpy(), expect)
  76. x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32))
  77. y = P.StridedSlice()(x, (1, 0, 0), (2, -3, 3), (1, 1, 3))
  78. expect = np.array([[[20]]])
  79. assert np.allclose(y.asnumpy(), expect)
  80. x_np = np.arange(0, 4*5).reshape(4, 5).astype(np.float32)
  81. y = Tensor(x_np)[:, ::-1]
  82. expect = x_np[:, ::-1]
  83. assert np.allclose(y.asnumpy(), expect)
  84. x = Tensor(np.arange(0, 2 * 3 * 4 * 5 * 4 * 3 * 2).reshape(2, 3, 4, 5, 4, 3, 2).astype(np.float32))
  85. y = P.StridedSlice()(x, (1, 0, 0, 2, 1, 2, 0), (2, 2, 2, 4, 2, 3, 2), (1, 1, 1, 1, 1, 1, 2))
  86. expect = np.array([[[[[[[1498.]]],
  87. [[[1522.]]]],
  88. [[[[1618.]]],
  89. [[[1642.]]]]],
  90. [[[[[1978.]]],
  91. [[[2002.]]]],
  92. [[[[2098.]]],
  93. [[[2122.]]]]]]])
  94. assert np.allclose(y.asnumpy(), expect)