You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cumsum_op.py 4.0 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.common.api import ms_function
  21. from mindspore.ops import operations as P
  22. x0 = np.random.rand(2, 3, 4, 4).astype(np.float32)
  23. axis0 = 3
  24. x1 = np.random.rand(2, 3, 4, 4).astype(np.float32)
  25. axis1 = 3
  26. x2 = np.random.rand(2, 3, 1, 4).astype(np.float32)
  27. axis2 = 2
  28. x3 = np.random.rand(2, 3, 1, 4).astype(np.float32)
  29. axis3 = 2
  30. x4 = np.random.rand(2, 3, 4, 4).astype(np.float32)
  31. axis4 = 1
  32. x5 = np.random.rand(2, 3).astype(np.float32)
  33. axis5 = 1
  34. x6 = np.random.rand(1, 1, 1, 1).astype(np.float32)
  35. axis6 = 0
# Target the GPU backend for every MindSpore computation in this module.
# This runs at import time, so it also takes effect for anything else that
# imports this module.
context.set_context(device_target='GPU')
  37. class CumSum(nn.Cell):
  38. def __init__(self):
  39. super(CumSum, self).__init__()
  40. self.x0 = Tensor(x0)
  41. self.axis0 = axis0
  42. self.x1 = Tensor(x1)
  43. self.axis1 = axis1
  44. self.x2 = Tensor(x2)
  45. self.axis2 = axis2
  46. self.x3 = Tensor(x3)
  47. self.axis3 = axis3
  48. self.x4 = Tensor(x4)
  49. self.axis4 = axis4
  50. self.x5 = Tensor(x5)
  51. self.axis5 = axis5
  52. self.x6 = Tensor(x6)
  53. self.axis6 = axis6
  54. @ms_function
  55. def construct(self):
  56. return (P.CumSum()(self.x0, self.axis0),
  57. P.CumSum()(self.x1, self.axis1),
  58. P.CumSum()(self.x2, self.axis2),
  59. P.CumSum()(self.x3, self.axis3),
  60. P.CumSum()(self.x4, self.axis4),
  61. P.CumSum()(self.x5, self.axis5),
  62. P.CumSum()(self.x6, self.axis6))
  63. @pytest.mark.level0
  64. @pytest.mark.platform_x86_gpu_training
  65. @pytest.mark.env_onecard
  66. def test_CumSum():
  67. cumsum = CumSum()
  68. output = cumsum()
  69. expect0 = np.cumsum(x0, axis=axis0)
  70. diff0 = abs(output[0].asnumpy() - expect0)
  71. error0 = np.ones(shape=expect0.shape) * 1.0e-5
  72. assert np.all(diff0 < error0)
  73. assert output[0].shape == expect0.shape
  74. expect1 = np.cumsum(x1, axis=axis1)
  75. diff1 = abs(output[1].asnumpy() - expect1)
  76. error1 = np.ones(shape=expect1.shape) * 1.0e-5
  77. assert np.all(diff1 < error1)
  78. assert output[1].shape == expect1.shape
  79. expect2 = np.cumsum(x2, axis=axis2)
  80. diff2 = abs(output[2].asnumpy() - expect2)
  81. error2 = np.ones(shape=expect2.shape) * 1.0e-5
  82. assert np.all(diff2 < error2)
  83. assert output[2].shape == expect2.shape
  84. expect3 = np.cumsum(x3, axis=axis3)
  85. diff3 = abs(output[3].asnumpy() - expect3)
  86. error3 = np.ones(shape=expect3.shape) * 1.0e-5
  87. assert np.all(diff3 < error3)
  88. assert output[3].shape == expect3.shape
  89. expect4 = np.cumsum(x4, axis=axis4)
  90. diff4 = abs(output[4].asnumpy() - expect4)
  91. error4 = np.ones(shape=expect4.shape) * 1.0e-5
  92. assert np.all(diff4 < error4)
  93. assert output[4].shape == expect4.shape
  94. expect5 = np.cumsum(x5, axis=axis5)
  95. diff5 = abs(output[5].asnumpy() - expect5)
  96. error5 = np.ones(shape=expect5.shape) * 1.0e-5
  97. assert np.all(diff5 < error5)
  98. assert output[5].shape == expect5.shape
  99. expect6 = np.cumsum(x6, axis=axis6)
  100. diff6 = abs(output[6].asnumpy() - expect6)
  101. error6 = np.ones(shape=expect6.shape) * 1.0e-5
  102. assert np.all(diff6 < error6)
  103. assert output[6].shape == expect6.shape