You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_split_op.py 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
  22. class OpNetWrapper(nn.Cell):
  23. def __init__(self, op):
  24. super(OpNetWrapper, self).__init__()
  25. self.op = op
  26. def construct(self, *inputs):
  27. return self.op(*inputs)
  28. @pytest.mark.level0
  29. @pytest.mark.platform_x86_cpu
  30. @pytest.mark.env_onecard
  31. def test_out1_axis0():
  32. op = P.Split(0, 1)
  33. op_wrapper = OpNetWrapper(op)
  34. input_x = Tensor(np.arange(24).astype(np.int32).reshape((2, 2, 6)))
  35. outputs = op_wrapper(input_x)
  36. print(outputs)
  37. assert outputs[0].shape == (2, 2, 6)
  38. assert np.allclose(outputs[0].asnumpy()[0, 0, :], [0, 1, 2, 3, 4, 5])
  39. @pytest.mark.level0
  40. @pytest.mark.platform_x86_cpu
  41. @pytest.mark.env_onecard
  42. def test_out2_axis2():
  43. op = P.Split(2, 2)
  44. op_wrapper = OpNetWrapper(op)
  45. input_x = Tensor(np.arange(24).astype(np.int32).reshape((2, 2, 6)))
  46. outputs = op_wrapper(input_x)
  47. print(outputs)
  48. assert outputs[0].shape == (2, 2, 3)
  49. assert outputs[1].shape == (2, 2, 3)
  50. assert np.allclose(outputs[0].asnumpy()[0, 0, :], [0, 1, 2])
  51. assert np.allclose(outputs[1].asnumpy()[0, 0, :], [3, 4, 5])
  52. @pytest.mark.level0
  53. @pytest.mark.platform_x86_cpu
  54. @pytest.mark.env_onecard
  55. def test_out2_axis1neg():
  56. op = P.Split(-1, 2)
  57. op_wrapper = OpNetWrapper(op)
  58. input_x = Tensor(np.arange(24).astype(np.float32).reshape((2, 2, 6)))
  59. outputs = op_wrapper(input_x)
  60. print(outputs)
  61. assert np.allclose(outputs[0].asnumpy()[0, :, :], [[0., 1., 2.], [6., 7., 8.]])
  62. assert np.allclose(outputs[1].asnumpy()[0, :, :], [[3., 4., 5.], [9., 10., 11.]])
  63. if __name__ == '__main__':
  64. test_out1_axis0()
  65. test_out2_axis2()
  66. test_out2_axis1neg()