You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_stridedslice_grad_op.py 13 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import composite as C
  22. class StridedSliceNet(nn.Cell):
  23. def __init__(self, begin, end, stride, begin_mask=0, end_mask=0, ellipsis_mask=0):
  24. super(StridedSliceNet, self).__init__()
  25. self.begin = begin
  26. self.end = end
  27. self.strides = stride
  28. self.slice = P.StridedSlice(begin_mask, end_mask, ellipsis_mask)
  29. def construct(self, x):
  30. return self.slice(x, self.begin, self.end, self.strides)
  31. class GradData(nn.Cell):
  32. def __init__(self, network):
  33. super(GradData, self).__init__()
  34. self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=False)
  35. self.network = network
  36. def construct(self, x):
  37. return self.grad(self.network)(x)
  38. def strided_slice_grad(nptype):
  39. context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
  40. x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(nptype))
  41. net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
  42. dx = GradData(net)(x)
  43. expect = np.array([[[[0., 0., 0., 0., 0.],
  44. [0., 0., 0., 0., 0.],
  45. [0., 0., 0., 0., 0.],
  46. [0., 0., 0., 0., 0.]],
  47. [[0., 0., 0., 0., 0.],
  48. [0., 0., 0., 0., 0.],
  49. [0., 0., 0., 0., 0.],
  50. [0., 0., 0., 0., 0.]],
  51. [[0., 0., 0., 0., 0.],
  52. [0., 0., 0., 0., 0.],
  53. [0., 0., 0., 0., 0.],
  54. [0., 0., 0., 0., 0.]]],
  55. [[[0., 0., 1., 1., 0.],
  56. [0., 0., 1., 1., 0.],
  57. [0., 0., 0., 0., 0.],
  58. [0., 0., 0., 0., 0.]],
  59. [[0., 0., 1., 1., 0.],
  60. [0., 0., 1., 1., 0.],
  61. [0., 0., 0., 0., 0.],
  62. [0., 0., 0., 0., 0.]],
  63. [[0., 0., 0., 0., 0.],
  64. [0., 0., 0., 0., 0.],
  65. [0., 0., 0., 0., 0.],
  66. [0., 0., 0., 0., 0.]]]]).astype(nptype)
  67. assert np.allclose(dx[0].asnumpy(), expect)
  68. net = StridedSliceNet((1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2))
  69. dx = GradData(net)(x)
  70. expect = np.array([[[[0., 0., 0., 0., 0.],
  71. [0., 0., 0., 0., 0.],
  72. [0., 0., 0., 0., 0.],
  73. [0., 0., 0., 0., 0.]],
  74. [[0., 0., 0., 0., 0.],
  75. [0., 0., 0., 0., 0.],
  76. [0., 0., 0., 0., 0.],
  77. [0., 0., 0., 0., 0.]],
  78. [[0., 0., 0., 0., 0.],
  79. [0., 0., 0., 0., 0.],
  80. [0., 0., 0., 0., 0.],
  81. [0., 0., 0., 0., 0.]]],
  82. [[[0., 0., 1., 0., 1.],
  83. [0., 0., 1., 0., 1.],
  84. [0., 0., 0., 0., 0.],
  85. [0., 0., 0., 0., 0.]],
  86. [[0., 0., 1., 0., 1.],
  87. [0., 0., 1., 0., 1.],
  88. [0., 0., 0., 0., 0.],
  89. [0., 0., 0., 0., 0.]],
  90. [[0., 0., 0., 0., 0.],
  91. [0., 0., 0., 0., 0.],
  92. [0., 0., 0., 0., 0.],
  93. [0., 0., 0., 0., 0.]]]]).astype(nptype)
  94. assert np.allclose(dx[0].asnumpy(), expect)
  95. net = StridedSliceNet((1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1))
  96. dx = GradData(net)(x)
  97. expect = np.array([[[[0., 0., 0., 0., 0.],
  98. [0., 0., 0., 0., 0.],
  99. [0., 0., 0., 0., 0.],
  100. [0., 0., 0., 0., 0.]],
  101. [[0., 0., 0., 0., 0.],
  102. [0., 0., 0., 0., 0.],
  103. [0., 0., 0., 0., 0.],
  104. [0., 0., 0., 0., 0.]],
  105. [[0., 0., 0., 0., 0.],
  106. [0., 0., 0., 0., 0.],
  107. [0., 0., 0., 0., 0.],
  108. [0., 0., 0., 0., 0.]]],
  109. [[[0., 0., 1., 1., 1.],
  110. [0., 0., 1., 1., 1.],
  111. [0., 0., 0., 0., 0.],
  112. [0., 0., 0., 0., 0.]],
  113. [[0., 0., 1., 1., 1.],
  114. [0., 0., 1., 1., 1.],
  115. [0., 0., 0., 0., 0.],
  116. [0., 0., 0., 0., 0.]],
  117. [[0., 0., 0., 0., 0.],
  118. [0., 0., 0., 0., 0.],
  119. [0., 0., 0., 0., 0.],
  120. [0., 0., 0., 0., 0.]]]]).astype(nptype)
  121. assert np.allclose(dx[0].asnumpy(), expect)
  122. # ME infer fault
  123. # y = GradData()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2))
  124. # expect = np.array([[[[0., 0., 0., 0., 0.],
  125. # [0., 0., 0., 0., 0.],
  126. # [0., 0., 0., 0., 0.],
  127. # [0., 0., 0., 0., 0.]],
  128. # [[0., 0., 0., 0., 0.],
  129. # [0., 0., 0., 0., 0.],
  130. # [0., 0., 0., 0., 0.],
  131. # [0., 0., 0., 0., 0.]],
  132. # [[0., 0., 0., 0., 0.],
  133. # [0., 0., 0., 0., 0.],
  134. # [0., 0., 0., 0., 0.],
  135. # [0., 0., 0., 0., 0.]]],
  136. # [[[0., 0., 0., 0., 0.],
  137. # [0., 1., 0., 1., 0.],
  138. # [0., 1., 0., 1., 0.],
  139. # [0., 1., 0., 1., 0.]],
  140. # [[0., 0., 0., 0., 0.],
  141. # [0., 1., 0., 1., 0.],
  142. # [0., 1., 0., 1., 0.],
  143. # [0., 1., 0., 1., 0.]],begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100
  144. # [[0., 0., 0., 0., 0.],
  145. # [0., 0., 0., 0., 0.],
  146. # [0., 0., 0., 0., 0.],
  147. # [0., 0., 0., 0., 0.]]]])
  148. # assert np.allclose(y.asnumpy(), expect)
  149. # y = Grad(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
  150. # expect = np.array([[[[0., 0., 0., 0., 0.],
  151. # [0., 0., 0., 0., 0.],
  152. # [0., 0., 0., 0., 0.],
  153. # [0., 0., 0., 0., 0.]],
  154. # [[0., 0., 0., 0., 0.],
  155. # [0., 0., 0., 0., 0.],
  156. # [0., 0., 0., 0., 0.],
  157. # [0., 0., 0., 0., 0.]],
  158. # [[0., 0., 0., 0., 0.],
  159. # [0., 0., 0., 0., 0.],
  160. # [0., 0., 0., 0., 0.],
  161. # [0., 0., 0., 0., 0.]]],
  162. # [[[0., 0., 1., 1., 0.],
  163. # [0., 0., 1., 1., 0.],
  164. # [0., 0., 0., 0., 0.],
  165. # [0., 0., 0., 0., 0.]],
  166. # [[0., 0., 1., 1., 0.],
  167. # [0., 0., 1., 1., 0.],
  168. # [0., 0., 0., 0., 0.],
  169. # [0., 0., 0., 0., 0.]],
  170. # [[0., 0., 0., 0., 0.],
  171. # [0., 0., 0., 0., 0.],
  172. # [0., 0., 0., 0., 0.],
  173. # [0., 0., 0., 0., 0.]]]])
  174. # assert np.allclose(y.asnumpy(), expect)
  175. net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1),
  176. begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100)
  177. dx = GradData(net)(x)
  178. expect = np.array([[[[0., 0., 0., 0., 0.],
  179. [0., 0., 0., 0., 0.],
  180. [0., 0., 0., 0., 0.],
  181. [0., 0., 0., 0., 0.]],
  182. [[0., 0., 0., 0., 0.],
  183. [0., 0., 0., 0., 0.],
  184. [0., 0., 0., 0., 0.],
  185. [0., 0., 0., 0., 0.]],
  186. [[0., 0., 0., 0., 0.],
  187. [0., 0., 0., 0., 0.],
  188. [0., 0., 0., 0., 0.],
  189. [0., 0., 0., 0., 0.]]],
  190. [[[1., 1., 1., 1., 0.],
  191. [1., 1., 1., 1., 0.],
  192. [1., 1., 1., 1., 0.],
  193. [1., 1., 1., 1., 0.]],
  194. [[1., 1., 1., 1., 0.],
  195. [1., 1., 1., 1., 0.],
  196. [1., 1., 1., 1., 0.],
  197. [1., 1., 1., 1., 0.]],
  198. [[1., 1., 1., 1., 0.],
  199. [1., 1., 1., 1., 0.],
  200. [1., 1., 1., 1., 0.],
  201. [1., 1., 1., 1., 0.]]]]).astype(nptype)
  202. assert np.allclose(dx[0].asnumpy(), expect)
  203. x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32))
  204. net = StridedSliceNet((1, 0, 0), (2, -3, 3), (1, 1, 3))
  205. dx = GradData(net)(x)
  206. expect = np.array([[[0., 0., 0., 0., 0.],
  207. [0., 0., 0., 0., 0.],
  208. [0., 0., 0., 0., 0.],
  209. [0., 0., 0., 0., 0.]],
  210. [[1., 0., 0., 0., 0.],
  211. [0., 0., 0., 0., 0.],
  212. [0., 0., 0., 0., 0.],
  213. [0., 0., 0., 0., 0.]],
  214. [[0., 0., 0., 0., 0.],
  215. [0., 0., 0., 0., 0.],
  216. [0., 0., 0., 0., 0.],
  217. [0., 0., 0., 0., 0.]]]).astype(nptype)
  218. assert np.allclose(dx[0].asnumpy(), expect)
  219. x = Tensor(np.arange(0, 1 * 1 * 1 * 2 * 3 * 4 * 5).reshape(1, 1, 1, 2, 3, 4, 5).astype(nptype))
  220. net = StridedSliceNet((0, 0, 0, 1, 1, 2, 2), (1, 1, 1, 2, 3, 3, 4), (1, 1, 1, 1, 1, 1, 1))
  221. dx = GradData(net)(x)
  222. expect = np.array([[[[[[[0., 0., 0., 0., 0.],
  223. [0., 0., 0., 0., 0.],
  224. [0., 0., 0., 0., 0.],
  225. [0., 0., 0., 0., 0.]],
  226. [[0., 0., 0., 0., 0.],
  227. [0., 0., 0., 0., 0.],
  228. [0., 0., 0., 0., 0.],
  229. [0., 0., 0., 0., 0.]],
  230. [[0., 0., 0., 0., 0.],
  231. [0., 0., 0., 0., 0.],
  232. [0., 0., 0., 0., 0.],
  233. [0., 0., 0., 0., 0.]]],
  234. [[[0., 0., 0., 0., 0.],
  235. [0., 0., 0., 0., 0.],
  236. [0., 0., 0., 0., 0.],
  237. [0., 0., 0., 0., 0.]],
  238. [[0., 0., 0., 0., 0.],
  239. [0., 0., 0., 0., 0.],
  240. [0., 0., 1., 1., 0.],
  241. [0., 0., 0., 0., 0.]],
  242. [[0., 0., 0., 0., 0.],
  243. [0., 0., 0., 0., 0.],
  244. [0., 0., 1., 1., 0.],
  245. [0., 0., 0., 0., 0.]]]]]]]).astype(nptype)
  246. assert np.allclose(dx[0].asnumpy(), expect)
  247. @pytest.mark.level0
  248. @pytest.mark.platform_x86_gpu_training
  249. @pytest.mark.env_onecard
  250. def test_strided_slice_grad_float32():
  251. strided_slice_grad(np.float32)
  252. @pytest.mark.level0
  253. @pytest.mark.platform_x86_gpu_training
  254. @pytest.mark.env_onecard
  255. def test_strided_slice_grad_int16():
  256. strided_slice_grad(np.int16)
  257. @pytest.mark.level0
  258. @pytest.mark.platform_x86_gpu_training
  259. @pytest.mark.env_onecard
  260. def test_strided_slice_grad_uint8():
  261. strided_slice_grad(np.uint8)
  262. @pytest.mark.level0
  263. @pytest.mark.platform_x86_gpu_training
  264. @pytest.mark.env_onecard
  265. def test_strided_slice_grad_bool():
  266. strided_slice_grad(np.bool)