You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_randperm.py 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore
  18. import mindspore.context as context
  19. import mindspore.nn as nn
  20. from mindspore import Tensor
  21. from mindspore.ops import operations as P
  class RandpermNet(nn.Cell):
      """``Cell`` wrapper that applies a single ``P.Randperm`` primitive.

      The primitive is constructed once from ``max_length``, ``pad`` and
      ``dtype``; ``construct`` just forwards its input to it.
      """

      def __init__(self, max_length, pad, dtype):
          super().__init__()
          # Build the primitive eagerly; construct() only invokes it.
          self.randperm = P.Randperm(max_length, pad, dtype)

      def construct(self, x):
          # Forward the input tensor straight to the stored primitive.
          return self.randperm(x)
  def randperm(max_length, pad, dtype, n):
      """Run ``Randperm`` on GPU in graph mode and validate its output.

      Args:
          max_length: Output length passed to ``P.Randperm``.
          pad: Value that every slot after the first ``n`` must hold.
          dtype: MindSpore dtype requested for the output.
          n: Number of values to permute; fed to the network as a
              one-element int32 tensor.

      Raises:
          AssertionError: If the output does not have shape
              ``(max_length,)`` and the requested dtype, if its first ``n``
              values are not a permutation of ``0 .. n-1``, or if any later
              value differs from ``pad``.
      """
      context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
      x = Tensor(np.array([n]).astype(np.int32))
      randperm_net = RandpermNet(max_length, pad, dtype)
      output = randperm_net(x).asnumpy()

      # verify shape and dtype: without these checks, a kernel that returns
      # the wrong length or element type would still pass the value checks.
      assert output.shape == (max_length,), \
          f"expected shape ({max_length},), got {output.shape}"
      expected_dtype = mindspore.dtype_to_nptype(dtype)
      assert output.dtype == expected_dtype, \
          f"expected dtype {expected_dtype}, got {output.dtype}"

      # verify permutation: the first n values, once sorted, must be 0..n-1.
      np.testing.assert_array_equal(np.arange(n), np.sort(output[:n]))

      # verify pad: every remaining slot must hold exactly `pad`.
      np.testing.assert_array_equal(output[n:], np.full(max_length - n, pad))
  @pytest.mark.level1
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_randperm_int8():
      """int8: permute 5 of 8 slots; the remaining 3 must equal the pad value -1."""
      randperm(max_length=8, pad=-1, dtype=mindspore.int8, n=5)
  @pytest.mark.level1
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_randperm_int16():
      """int16: n equals max_length (3), so every slot is permuted and none is padded."""
      randperm(max_length=3, pad=0, dtype=mindspore.int16, n=3)
  @pytest.mark.level1
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_randperm_int32():
      """int32: permute 2 of 4 slots; the remaining 2 must equal the pad value -6."""
      randperm(max_length=4, pad=-6, dtype=mindspore.int32, n=2)
  @pytest.mark.level1
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_randperm_int64():
      """int64: permute 4 of 12 slots; the remaining 8 must equal the pad value 128."""
      randperm(max_length=12, pad=128, dtype=mindspore.int64, n=4)
  @pytest.mark.level1
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_randperm_uint8():
      """uint8: permute 5 of 8 slots; the remaining 3 must equal the pad value 1."""
      randperm(max_length=8, pad=1, dtype=mindspore.uint8, n=5)
  @pytest.mark.level1
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_randperm_uint16():
      """uint16: n equals max_length (8), so every slot is permuted and none is padded."""
      randperm(max_length=8, pad=0, dtype=mindspore.uint16, n=8)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_randperm_uint32():
      """uint32: permute 3 of 4 slots; the last slot must equal the pad value 8."""
      randperm(max_length=4, pad=8, dtype=mindspore.uint32, n=3)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_randperm_uint64():
      """uint64: n equals max_length (5), so every slot is permuted and none is padded."""
      randperm(max_length=5, pad=4, dtype=mindspore.uint64, n=5)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_randperm_n_too_large():
      """Asking for n=2 with max_length=1 must raise a RuntimeError naming both values."""
      expected_msg = "n (2) cannot exceed max_length_ (1)"
      with pytest.raises(RuntimeError) as info:
          randperm(max_length=1, pad=0, dtype=mindspore.int32, n=2)
      # Inspect the message only after the with-block has captured the error.
      assert expected_msg in str(info.value)