You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_stridedslice_grad_op.py 11 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # Copyright 2019-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import composite as C
  22. class StridedSliceNet(nn.Cell):
  23. def __init__(self, begin, end, stride, begin_mask=0, end_mask=0, ellipsis_mask=0):
  24. super(StridedSliceNet, self).__init__()
  25. self.begin = begin
  26. self.end = end
  27. self.strides = stride
  28. self.slice = P.StridedSlice(begin_mask, end_mask, ellipsis_mask)
  29. def construct(self, x):
  30. return self.slice(x, self.begin, self.end, self.strides)
  31. class GradData(nn.Cell):
  32. def __init__(self, network):
  33. super(GradData, self).__init__()
  34. self.grad = C.GradOperation(get_all=True, sens_param=False)
  35. self.network = network
  36. def construct(self, x):
  37. return self.grad(self.network)(x)
  38. def strided_slice_grad(nptype):
  39. context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
  40. x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(nptype))
  41. net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
  42. dx = GradData(net)(x)
  43. expect = np.array([[[[0., 0., 0., 0., 0.],
  44. [0., 0., 0., 0., 0.],
  45. [0., 0., 0., 0., 0.],
  46. [0., 0., 0., 0., 0.]],
  47. [[0., 0., 0., 0., 0.],
  48. [0., 0., 0., 0., 0.],
  49. [0., 0., 0., 0., 0.],
  50. [0., 0., 0., 0., 0.]],
  51. [[0., 0., 0., 0., 0.],
  52. [0., 0., 0., 0., 0.],
  53. [0., 0., 0., 0., 0.],
  54. [0., 0., 0., 0., 0.]]],
  55. [[[0., 0., 1., 1., 0.],
  56. [0., 0., 1., 1., 0.],
  57. [0., 0., 0., 0., 0.],
  58. [0., 0., 0., 0., 0.]],
  59. [[0., 0., 1., 1., 0.],
  60. [0., 0., 1., 1., 0.],
  61. [0., 0., 0., 0., 0.],
  62. [0., 0., 0., 0., 0.]],
  63. [[0., 0., 0., 0., 0.],
  64. [0., 0., 0., 0., 0.],
  65. [0., 0., 0., 0., 0.],
  66. [0., 0., 0., 0., 0.]]]]).astype(nptype)
  67. assert np.allclose(dx[0].asnumpy(), expect)
  68. net = StridedSliceNet((1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2))
  69. dx = GradData(net)(x)
  70. expect = np.array([[[[0., 0., 0., 0., 0.],
  71. [0., 0., 0., 0., 0.],
  72. [0., 0., 0., 0., 0.],
  73. [0., 0., 0., 0., 0.]],
  74. [[0., 0., 0., 0., 0.],
  75. [0., 0., 0., 0., 0.],
  76. [0., 0., 0., 0., 0.],
  77. [0., 0., 0., 0., 0.]],
  78. [[0., 0., 0., 0., 0.],
  79. [0., 0., 0., 0., 0.],
  80. [0., 0., 0., 0., 0.],
  81. [0., 0., 0., 0., 0.]]],
  82. [[[0., 0., 1., 0., 1.],
  83. [0., 0., 1., 0., 1.],
  84. [0., 0., 0., 0., 0.],
  85. [0., 0., 0., 0., 0.]],
  86. [[0., 0., 1., 0., 1.],
  87. [0., 0., 1., 0., 1.],
  88. [0., 0., 0., 0., 0.],
  89. [0., 0., 0., 0., 0.]],
  90. [[0., 0., 0., 0., 0.],
  91. [0., 0., 0., 0., 0.],
  92. [0., 0., 0., 0., 0.],
  93. [0., 0., 0., 0., 0.]]]]).astype(nptype)
  94. assert np.allclose(dx[0].asnumpy(), expect)
  95. net = StridedSliceNet((1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1))
  96. dx = GradData(net)(x)
  97. expect = np.array([[[[0., 0., 0., 0., 0.],
  98. [0., 0., 0., 0., 0.],
  99. [0., 0., 0., 0., 0.],
  100. [0., 0., 0., 0., 0.]],
  101. [[0., 0., 0., 0., 0.],
  102. [0., 0., 0., 0., 0.],
  103. [0., 0., 0., 0., 0.],
  104. [0., 0., 0., 0., 0.]],
  105. [[0., 0., 0., 0., 0.],
  106. [0., 0., 0., 0., 0.],
  107. [0., 0., 0., 0., 0.],
  108. [0., 0., 0., 0., 0.]]],
  109. [[[0., 0., 1., 1., 1.],
  110. [0., 0., 1., 1., 1.],
  111. [0., 0., 0., 0., 0.],
  112. [0., 0., 0., 0., 0.]],
  113. [[0., 0., 1., 1., 1.],
  114. [0., 0., 1., 1., 1.],
  115. [0., 0., 0., 0., 0.],
  116. [0., 0., 0., 0., 0.]],
  117. [[0., 0., 0., 0., 0.],
  118. [0., 0., 0., 0., 0.],
  119. [0., 0., 0., 0., 0.],
  120. [0., 0., 0., 0., 0.]]]]).astype(nptype)
  121. assert np.allclose(dx[0].asnumpy(), expect)
  122. net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1),
  123. begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100)
  124. dx = GradData(net)(x)
  125. expect = np.array([[[[0., 0., 0., 0., 0.],
  126. [0., 0., 0., 0., 0.],
  127. [0., 0., 0., 0., 0.],
  128. [0., 0., 0., 0., 0.]],
  129. [[0., 0., 0., 0., 0.],
  130. [0., 0., 0., 0., 0.],
  131. [0., 0., 0., 0., 0.],
  132. [0., 0., 0., 0., 0.]],
  133. [[0., 0., 0., 0., 0.],
  134. [0., 0., 0., 0., 0.],
  135. [0., 0., 0., 0., 0.],
  136. [0., 0., 0., 0., 0.]]],
  137. [[[1., 1., 1., 1., 0.],
  138. [1., 1., 1., 1., 0.],
  139. [1., 1., 1., 1., 0.],
  140. [1., 1., 1., 1., 0.]],
  141. [[1., 1., 1., 1., 0.],
  142. [1., 1., 1., 1., 0.],
  143. [1., 1., 1., 1., 0.],
  144. [1., 1., 1., 1., 0.]],
  145. [[1., 1., 1., 1., 0.],
  146. [1., 1., 1., 1., 0.],
  147. [1., 1., 1., 1., 0.],
  148. [1., 1., 1., 1., 0.]]]]).astype(nptype)
  149. assert np.allclose(dx[0].asnumpy(), expect)
  150. x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32))
  151. net = StridedSliceNet((1, 0, 0), (2, -3, 3), (1, 1, 3))
  152. dx = GradData(net)(x)
  153. expect = np.array([[[0., 0., 0., 0., 0.],
  154. [0., 0., 0., 0., 0.],
  155. [0., 0., 0., 0., 0.],
  156. [0., 0., 0., 0., 0.]],
  157. [[1., 0., 0., 0., 0.],
  158. [0., 0., 0., 0., 0.],
  159. [0., 0., 0., 0., 0.],
  160. [0., 0., 0., 0., 0.]],
  161. [[0., 0., 0., 0., 0.],
  162. [0., 0., 0., 0., 0.],
  163. [0., 0., 0., 0., 0.],
  164. [0., 0., 0., 0., 0.]]]).astype(nptype)
  165. assert np.allclose(dx[0].asnumpy(), expect)
  166. x = Tensor(np.arange(0, 1 * 1 * 1 * 2 * 3 * 4 * 5).reshape(1, 1, 1, 2, 3, 4, 5).astype(nptype))
  167. net = StridedSliceNet((0, 0, 0, 1, 1, 2, 2), (1, 1, 1, 2, 3, 3, 4), (1, 1, 1, 1, 1, 1, 1))
  168. dx = GradData(net)(x)
  169. expect = np.array([[[[[[[0., 0., 0., 0., 0.],
  170. [0., 0., 0., 0., 0.],
  171. [0., 0., 0., 0., 0.],
  172. [0., 0., 0., 0., 0.]],
  173. [[0., 0., 0., 0., 0.],
  174. [0., 0., 0., 0., 0.],
  175. [0., 0., 0., 0., 0.],
  176. [0., 0., 0., 0., 0.]],
  177. [[0., 0., 0., 0., 0.],
  178. [0., 0., 0., 0., 0.],
  179. [0., 0., 0., 0., 0.],
  180. [0., 0., 0., 0., 0.]]],
  181. [[[0., 0., 0., 0., 0.],
  182. [0., 0., 0., 0., 0.],
  183. [0., 0., 0., 0., 0.],
  184. [0., 0., 0., 0., 0.]],
  185. [[0., 0., 0., 0., 0.],
  186. [0., 0., 0., 0., 0.],
  187. [0., 0., 1., 1., 0.],
  188. [0., 0., 0., 0., 0.]],
  189. [[0., 0., 0., 0., 0.],
  190. [0., 0., 0., 0., 0.],
  191. [0., 0., 1., 1., 0.],
  192. [0., 0., 0., 0., 0.]]]]]]]).astype(nptype)
  193. assert np.allclose(dx[0].asnumpy(), expect)
  194. @pytest.mark.level0
  195. @pytest.mark.platform_x86_gpu_training
  196. @pytest.mark.env_onecard
  197. def test_strided_slice_grad_float64():
  198. strided_slice_grad(np.float64)
  199. @pytest.mark.level0
  200. @pytest.mark.platform_x86_gpu_training
  201. @pytest.mark.env_onecard
  202. def test_strided_slice_grad_float32():
  203. strided_slice_grad(np.float32)
  204. @pytest.mark.level0
  205. @pytest.mark.platform_x86_gpu_training
  206. @pytest.mark.env_onecard
  207. def test_strided_slice_grad_float16():
  208. strided_slice_grad(np.float16)
  209. @pytest.mark.level0
  210. @pytest.mark.platform_x86_gpu_training
  211. @pytest.mark.env_onecard
  212. def test_strided_slice_grad_int64():
  213. strided_slice_grad(np.int64)
  214. @pytest.mark.level0
  215. @pytest.mark.platform_x86_gpu_training
  216. @pytest.mark.env_onecard
  217. def test_strided_slice_grad_int32():
  218. strided_slice_grad(np.int32)
  219. @pytest.mark.level0
  220. @pytest.mark.platform_x86_gpu_training
  221. @pytest.mark.env_onecard
  222. def test_strided_slice_grad_int16():
  223. strided_slice_grad(np.int16)
  224. @pytest.mark.level0
  225. @pytest.mark.platform_x86_gpu_training
  226. @pytest.mark.env_onecard
  227. def test_strided_slice_grad_int8():
  228. strided_slice_grad(np.int8)
  229. @pytest.mark.level0
  230. @pytest.mark.platform_x86_gpu_training
  231. @pytest.mark.env_onecard
  232. def test_strided_slice_grad_uint64():
  233. strided_slice_grad(np.uint64)
  234. @pytest.mark.level0
  235. @pytest.mark.platform_x86_gpu_training
  236. @pytest.mark.env_onecard
  237. def test_strided_slice_grad_uint32():
  238. strided_slice_grad(np.uint32)
  239. @pytest.mark.level0
  240. @pytest.mark.platform_x86_gpu_training
  241. @pytest.mark.env_onecard
  242. def test_strided_slice_grad_uint16():
  243. strided_slice_grad(np.uint16)
  244. @pytest.mark.level0
  245. @pytest.mark.platform_x86_gpu_training
  246. @pytest.mark.env_onecard
  247. def test_strided_slice_grad_uint8():
  248. strided_slice_grad(np.uint8)
  249. @pytest.mark.level0
  250. @pytest.mark.platform_x86_gpu_training
  251. @pytest.mark.env_onecard
  252. def test_strided_slice_grad_bool():
  253. strided_slice_grad(np.bool)