You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cumsum_op.py 9.0 kB


  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
  22. axis0 = 0
  23. axis1 = 1
  24. axis2 = 2
  25. axis3 = 3
  26. axis4 = 4
  27. axis5 = -1
  28. axis6 = -2
  29. x0 = np.random.rand(3, 3, 4, 5, 3).astype(np.float32)
  30. x1 = np.random.rand(2, 3, 4, 5, 3).astype(np.float16)
  31. x2 = np.random.randint(-10000, 10000, size=(2, 3, 4, 5, 3)).astype(np.int32)
  32. x3 = np.random.randint(-5, 5, size=(2, 3, 4, 5, 3)).astype(np.int8)
  33. x4 = np.random.randint(0, 10, size=(2, 3, 4, 5, 3)).astype(np.uint8)
  34. x5 = np.random.rand(3).astype(np.float32)
  35. list1 = [x0, x1, x2, x3, x4]
  36. list2 = [axis0, axis1, axis2, axis3, axis4, axis5, axis6]
  37. class CumSum(nn.Cell):
  38. def __init__(self, exclusive=False, reverse=False):
  39. super(CumSum, self).__init__()
  40. self.cumsum_op = P.CumSum(exclusive, reverse)
  41. self.x0 = Tensor(x0)
  42. self.axis0 = axis0
  43. self.x1 = Tensor(x0)
  44. self.axis1 = axis1
  45. self.x2 = Tensor(x0)
  46. self.axis2 = axis2
  47. self.x3 = Tensor(x0)
  48. self.axis3 = axis3
  49. self.x4 = Tensor(x0)
  50. self.axis4 = axis4
  51. self.x5 = Tensor(x0)
  52. self.axis5 = axis5
  53. self.x6 = Tensor(x0)
  54. self.axis6 = axis6
  55. self.x7 = Tensor(x1)
  56. self.axis7 = axis0
  57. self.x8 = Tensor(x1)
  58. self.axis8 = axis1
  59. self.x9 = Tensor(x1)
  60. self.axis9 = axis2
  61. self.x10 = Tensor(x1)
  62. self.axis10 = axis3
  63. self.x11 = Tensor(x1)
  64. self.axis11 = axis4
  65. self.x12 = Tensor(x1)
  66. self.axis12 = axis5
  67. self.x13 = Tensor(x1)
  68. self.axis13 = axis6
  69. self.x14 = Tensor(x2)
  70. self.axis14 = axis0
  71. self.x15 = Tensor(x2)
  72. self.axis15 = axis1
  73. self.x16 = Tensor(x2)
  74. self.axis16 = axis2
  75. self.x17 = Tensor(x2)
  76. self.axis17 = axis3
  77. self.x18 = Tensor(x2)
  78. self.axis18 = axis4
  79. self.x19 = Tensor(x2)
  80. self.axis19 = axis5
  81. self.x20 = Tensor(x2)
  82. self.axis20 = axis6
  83. self.x21 = Tensor(x3)
  84. self.axis21 = axis0
  85. self.x22 = Tensor(x3)
  86. self.axis22 = axis1
  87. self.x23 = Tensor(x3)
  88. self.axis23 = axis2
  89. self.x24 = Tensor(x3)
  90. self.axis24 = axis3
  91. self.x25 = Tensor(x3)
  92. self.axis25 = axis4
  93. self.x26 = Tensor(x3)
  94. self.axis26 = axis5
  95. self.x27 = Tensor(x3)
  96. self.axis27 = axis6
  97. self.x28 = Tensor(x4)
  98. self.axis28 = axis0
  99. self.x29 = Tensor(x4)
  100. self.axis29 = axis1
  101. self.x30 = Tensor(x4)
  102. self.axis30 = axis2
  103. self.x31 = Tensor(x4)
  104. self.axis31 = axis3
  105. self.x32 = Tensor(x4)
  106. self.axis32 = axis4
  107. self.x33 = Tensor(x4)
  108. self.axis33 = axis5
  109. self.x34 = Tensor(x4)
  110. self.axis34 = axis6
  111. self.x35 = Tensor(x5)
  112. self.axis35 = axis0
  113. def construct(self):
  114. return (self.cumsum_op(self.x0, self.axis0),
  115. self.cumsum_op(self.x1, self.axis1),
  116. self.cumsum_op(self.x2, self.axis2),
  117. self.cumsum_op(self.x3, self.axis3),
  118. self.cumsum_op(self.x4, self.axis4),
  119. self.cumsum_op(self.x5, self.axis5),
  120. self.cumsum_op(self.x6, self.axis6),
  121. self.cumsum_op(self.x7, self.axis7),
  122. self.cumsum_op(self.x8, self.axis8),
  123. self.cumsum_op(self.x9, self.axis9),
  124. self.cumsum_op(self.x10, self.axis10),
  125. self.cumsum_op(self.x11, self.axis11),
  126. self.cumsum_op(self.x12, self.axis12),
  127. self.cumsum_op(self.x13, self.axis13),
  128. self.cumsum_op(self.x14, self.axis14),
  129. self.cumsum_op(self.x15, self.axis15),
  130. self.cumsum_op(self.x16, self.axis16),
  131. self.cumsum_op(self.x17, self.axis17),
  132. self.cumsum_op(self.x18, self.axis18),
  133. self.cumsum_op(self.x19, self.axis19),
  134. self.cumsum_op(self.x20, self.axis20),
  135. self.cumsum_op(self.x21, self.axis21),
  136. self.cumsum_op(self.x22, self.axis22),
  137. self.cumsum_op(self.x23, self.axis23),
  138. self.cumsum_op(self.x24, self.axis24),
  139. self.cumsum_op(self.x25, self.axis25),
  140. self.cumsum_op(self.x26, self.axis26),
  141. self.cumsum_op(self.x27, self.axis27),
  142. self.cumsum_op(self.x28, self.axis28),
  143. self.cumsum_op(self.x29, self.axis29),
  144. self.cumsum_op(self.x30, self.axis30),
  145. self.cumsum_op(self.x31, self.axis31),
  146. self.cumsum_op(self.x32, self.axis32),
  147. self.cumsum_op(self.x33, self.axis33),
  148. self.cumsum_op(self.x34, self.axis34),
  149. self.cumsum_op(self.x35, self.axis35))
  150. @pytest.mark.level0
  151. @pytest.mark.platform_x86_cpu
  152. @pytest.mark.env_onecard
  153. def test_cumsum():
  154. cumsum = CumSum()
  155. output = cumsum()
  156. k = 0
  157. for i in list1:
  158. for j in list2:
  159. expect = np.cumsum(i, axis=j)
  160. diff = abs(output[k].asnumpy() - expect)
  161. error = np.ones(shape=expect.shape) * 1.0e-5
  162. assert np.all(diff < error)
  163. assert output[k].shape == expect.shape
  164. k += 1
  165. expect = np.cumsum(x5, axis=axis0)
  166. diff = abs(output[k].asnumpy() - expect)
  167. error = np.ones(shape=expect.shape) * 1.0e-5
  168. assert np.all(diff < error)
  169. assert output[k].shape == expect.shape
  170. def test_cumsum2():
  171. cumsum = CumSum(exclusive=False, reverse=True)
  172. output = cumsum()
  173. k = 0
  174. for i in list1:
  175. for j in list2:
  176. result1 = np.flip(i, axis=j)
  177. result2 = np.cumsum(result1, axis=j)
  178. expect = np.flip(result2, axis=j)
  179. diff = abs(output[k].asnumpy() - expect)
  180. error = np.ones(shape=expect.shape) * 1.0e-5
  181. assert np.all(diff < error)
  182. assert output[k].shape == expect.shape
  183. k += 1
  184. result1 = np.flip(x5, axis=axis0)
  185. result2 = np.cumsum(result1, axis=axis0)
  186. expect = np.flip(result2, axis=axis0)
  187. diff = abs(output[k].asnumpy() - expect)
  188. error = np.ones(shape=expect.shape) * 1.0e-5
  189. assert np.all(diff < error)
  190. assert output[k].shape == expect.shape
  191. def test_cumsum3():
  192. cumsum = CumSum(exclusive=True, reverse=False)
  193. output = cumsum()
  194. k = 0
  195. for i in list1:
  196. for j in list2:
  197. result1 = np.insert(i, 0, [0], axis=j)
  198. result2 = np.delete(result1, -1, axis=j)
  199. expect = np.cumsum(result2, axis=j)
  200. diff = abs(output[k].asnumpy() - expect)
  201. error = np.ones(shape=expect.shape) * 1.0e-5
  202. assert np.all(diff < error)
  203. assert output[k].shape == expect.shape
  204. k += 1
  205. result1 = np.insert(x5, 0, [0], axis=axis0)
  206. result2 = np.delete(result1, -1, axis=axis0)
  207. expect = np.cumsum(result2, axis=axis0)
  208. diff = abs(output[k].asnumpy() - expect)
  209. error = np.ones(shape=expect.shape) * 1.0e-5
  210. assert np.all(diff < error)
  211. assert output[k].shape == expect.shape
  212. def test_cumsum4():
  213. cumsum = CumSum(exclusive=True, reverse=True)
  214. output = cumsum()
  215. k = 0
  216. for i in list1:
  217. for j in list2:
  218. result1 = np.flip(i, axis=j)
  219. result2 = np.insert(result1, 0, [0], axis=j)
  220. result3 = np.delete(result2, -1, axis=j)
  221. result4 = np.cumsum(result3, axis=j)
  222. expect = np.flip(result4, axis=j)
  223. diff = abs(output[k].asnumpy() - expect)
  224. error = np.ones(shape=expect.shape) * 1.0e-5
  225. assert np.all(diff < error)
  226. assert output[k].shape == expect.shape
  227. k += 1
  228. result1 = np.flip(x5, axis=axis0)
  229. result2 = np.insert(result1, 0, [0], axis=axis0)
  230. result3 = np.delete(result2, -1, axis=axis0)
  231. result4 = np.cumsum(result3, axis=axis0)
  232. expect = np.flip(result4, axis=axis0)
  233. diff = abs(output[k].asnumpy() - expect)
  234. error = np.ones(shape=expect.shape) * 1.0e-5
  235. assert np.all(diff < error)
  236. assert output[k].shape == expect.shape