You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_map.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import pytest
  16. import numpy as np
  17. from mindspore import Tensor, nn, Parameter
  18. from mindspore.nn import Cell
  19. import mindspore as ms
  20. def test_map_args_size():
  21. """
  22. Feature: Check the size of inputs of map.
  23. Description: The size of inputs of map must be greater than 1.
  24. Expectation: The size of inputs of map must be greater than 1.
  25. """
  26. class MapNet(Cell):
  27. def __init__(self):
  28. super().__init__()
  29. self.relu = nn.ReLU()
  30. def mul(self, x=2, y=4):
  31. return x * y
  32. def construct(self, x):
  33. if map(self.mul) == 8:
  34. x = self.relu(x)
  35. return x
  36. input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32)
  37. input_me_x = Tensor(input_np_x)
  38. net = MapNet()
  39. with pytest.raises(Exception, match="The Map operator must have at least one argument."):
  40. ret = net(input_me_x)
  41. print("ret:", ret)
  42. def test_map_args_type():
  43. """
  44. Feature: Check the type of inputs of Map().
  45. Description: The type of inputs of Map() must be list, tuple or class.
  46. Expectation: The type of inputs of Map() must be list, tuple or class.
  47. """
  48. class MapNet(Cell):
  49. def __init__(self):
  50. super().__init__()
  51. self.relu = nn.ReLU()
  52. def mul(self, x=2, y=4):
  53. return x * y
  54. def construct(self, x):
  55. if map(self.mul, 3, 4) == 8:
  56. x = self.relu(x)
  57. return x
  58. input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32)
  59. input_me_x = Tensor(input_np_x)
  60. net = MapNet()
  61. with pytest.raises(Exception, match="Map can only be applied to list, tuple and class"):
  62. ret = net(input_me_x)
  63. print("ret:", ret)
  64. def test_map_args_full_make_list():
  65. """
  66. Feature: Check the types of all inputs in Map.
  67. Description: The types of all inputs in Map must be same.
  68. Expectation: The types of all inputs in Map must be same.
  69. """
  70. class MapNet(Cell):
  71. def mul(self, x=2, y=4):
  72. return x * y
  73. def construct(self, x, y):
  74. if map(self.mul, x, y) == [8]:
  75. x = y
  76. return x
  77. input_me_x = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
  78. input_me_y = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
  79. net = MapNet()
  80. with pytest.raises(Exception, match="The types of arguments in Map must be consistent"):
  81. ret = net([input_me_x], (input_me_y))
  82. print("ret:", ret)
  83. def test_map_args_full_make_list_same_length():
  84. """
  85. Feature: Check the length of list input Map.
  86. Description: The list in Map should have same length.
  87. Expectation: The list in Map should have same length.
  88. """
  89. class MapNet(Cell):
  90. def mul(self, x=2, y=4):
  91. return x * y
  92. def construct(self, x, y):
  93. if map(self.mul, x, y) == [8]:
  94. x = y
  95. return x
  96. input_me_x = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
  97. input_me_y = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
  98. net = MapNet()
  99. with pytest.raises(Exception, match="The length of lists in Map must be the same"):
  100. ret = net([input_me_x], [input_me_y, input_me_y])
  101. print("ret:", ret)
  102. def test_map_args_full_make_tuple_same_length():
  103. """
  104. Feature: Check the length of tuple input Map.
  105. Description: The tuple in Map should have same length.
  106. Expectation: The tuple in Map should have same length.
  107. """
  108. class MapNet(Cell):
  109. def mul(self, x=2, y=4):
  110. return x * y
  111. def construct(self, x, y):
  112. if map(self.mul, x, y) == [8]:
  113. x = y
  114. return x
  115. input_me_x = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
  116. input_me_y = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
  117. net = MapNet()
  118. with pytest.raises(Exception, match="The length of tuples in Map must the same"):
  119. ret = net((input_me_x, input_me_x), (input_me_y, input_me_y, input_me_y))
  120. print("ret:", ret)
  121. def test_map_param_cast():
  122. """
  123. Feature: Check the ref type when insert auto cast.
  124. Description: Check the ref type when insert auto cast.
  125. Expectation: Check the ref type when insert auto cast.
  126. """
  127. class MapNet(Cell):
  128. def __init__(self):
  129. super().__init__()
  130. self.param = Parameter(Tensor(5, ms.float32), name="param_b")
  131. def construct(self, x):
  132. self.param = x
  133. return self.param
  134. input_me_x = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float64))
  135. net = MapNet()
  136. with pytest.raises(Exception, match="For 'S-Prim-Assign' operator, "
  137. "the type of writable argument is 'float32'"):
  138. ret = net(input_me_x)
  139. print("ret:", ret)