You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Test util functions used in distribution classes.
  17. """
  18. import numpy as np
  19. import pytest
  20. from mindspore.nn.cell import Cell
  21. from mindspore import context
  22. from mindspore import dtype
  23. from mindspore import Tensor
  24. from mindspore.common.parameter import Parameter
  25. from mindspore.nn.probability.distribution._utils.utils import set_param_type, \
  26. cast_to_tensor, CheckTuple, CheckTensor
  27. def test_set_param_type():
  28. """
  29. Test set_param_type function.
  30. """
  31. tensor_fp16 = Tensor(0.1, dtype=dtype.float16)
  32. tensor_fp32 = Tensor(0.1, dtype=dtype.float32)
  33. tensor_fp64 = Tensor(0.1, dtype=dtype.float64)
  34. tensor_int32 = Tensor(0.1, dtype=dtype.int32)
  35. array_fp32 = np.array(1.0).astype(np.float32)
  36. array_fp64 = np.array(1.0).astype(np.float64)
  37. array_int32 = np.array(1.0).astype(np.int32)
  38. dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32}
  39. dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64}
  40. dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32}
  41. dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32}
  42. dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64}
  43. dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32}
  44. dict7 = {'a': 1.0}
  45. dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0}
  46. dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16}
  47. dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64}
  48. dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64}
  49. ans1 = set_param_type(dict1, dtype.float16)
  50. assert ans1 == dtype.float32
  51. with pytest.raises(TypeError):
  52. set_param_type(dict2, dtype.float32)
  53. ans3 = set_param_type(dict3, dtype.float16)
  54. assert ans3 == dtype.float32
  55. ans4 = set_param_type(dict4, dtype.float16)
  56. assert ans4 == dtype.float32
  57. with pytest.raises(TypeError):
  58. set_param_type(dict5, dtype.float32)
  59. with pytest.raises(TypeError):
  60. set_param_type(dict6, dtype.float32)
  61. ans7 = set_param_type(dict7, dtype.float32)
  62. assert ans7 == dtype.float32
  63. ans8 = set_param_type(dict8, dtype.float32)
  64. assert ans8 == dtype.float32
  65. ans9 = set_param_type(dict9, dtype.float32)
  66. assert ans9 == dtype.float16
  67. ans10 = set_param_type(dict10, dtype.float32)
  68. assert ans10 == dtype.float32
  69. ans11 = set_param_type(dict11, dtype.float32)
  70. assert ans11 == dtype.float32
  71. def test_cast_to_tensor():
  72. """
  73. Test cast_to_tensor.
  74. """
  75. with pytest.raises(ValueError):
  76. cast_to_tensor(None, dtype.float32)
  77. with pytest.raises(TypeError):
  78. cast_to_tensor(True, dtype.float32)
  79. with pytest.raises(TypeError):
  80. cast_to_tensor({'a': 1, 'b': 2}, dtype.float32)
  81. with pytest.raises(TypeError):
  82. cast_to_tensor('tensor', dtype.float32)
  83. ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param'))
  84. assert isinstance(ans1, Parameter)
  85. ans2 = cast_to_tensor(np.array(1.0).astype(np.float32))
  86. assert isinstance(ans2, Tensor)
  87. ans3 = cast_to_tensor([1.0, 2.0])
  88. assert isinstance(ans3, Tensor)
  89. ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32)
  90. assert isinstance(ans4, Tensor)
  91. ans5 = cast_to_tensor(0.1, dtype.float32)
  92. assert isinstance(ans5, Tensor)
  93. ans6 = cast_to_tensor(1, dtype.float32)
  94. assert isinstance(ans6, Tensor)
  95. class Net(Cell):
  96. """
  97. Test class: CheckTuple.
  98. """
  99. def __init__(self, value):
  100. super(Net, self).__init__()
  101. self.checktuple = CheckTuple()
  102. self.value = value
  103. def construct(self, value=None):
  104. if value is None:
  105. return self.checktuple(self.value, 'input')
  106. return self.checktuple(value, 'input')
  107. def test_check_tuple():
  108. """
  109. Test CheckTuple.
  110. """
  111. net1 = Net((1, 2, 3))
  112. ans1 = net1()
  113. assert isinstance(ans1, tuple)
  114. with pytest.raises(TypeError):
  115. net2 = Net('tuple')
  116. net2()
  117. context.set_context(mode=context.GRAPH_MODE)
  118. net3 = Net((1, 2, 3))
  119. ans3 = net3()
  120. assert isinstance(ans3, tuple)
  121. with pytest.raises(TypeError):
  122. net4 = Net('tuple')
  123. net4()
  124. class Net1(Cell):
  125. """
  126. Test class: CheckTensor.
  127. """
  128. def __init__(self, value):
  129. super(Net1, self).__init__()
  130. self.checktensor = CheckTensor()
  131. self.value = value
  132. self.context = context.get_context('mode')
  133. def construct(self, value=None):
  134. value = self.value if value is None else value
  135. if self.context == 0:
  136. self.checktensor(value, 'input')
  137. return value
  138. return self.checktensor(value, 'input')
  139. def test_check_tensor():
  140. """
  141. Test CheckTensor.
  142. """
  143. value = Tensor(0.1, dtype=dtype.float32)
  144. net1 = Net1(value)
  145. ans1 = net1()
  146. assert isinstance(ans1, Tensor)
  147. ans1 = net1(value)
  148. assert isinstance(ans1, Tensor)
  149. with pytest.raises(TypeError):
  150. net2 = Net1('tuple')
  151. net2()
  152. context.set_context(mode=context.GRAPH_MODE)
  153. net3 = Net1(value)
  154. ans3 = net3()
  155. assert isinstance(ans3, Tensor)
  156. ans3 = net3(value)
  157. assert isinstance(ans3, Tensor)
  158. with pytest.raises(TypeError):
  159. net4 = Net1('tuple')
  160. net4()