You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_assert.py 4.2 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test_assert: tests for ``assert`` statements inside ``nn.Cell.construct`` under MindSpore graph mode."""
import pytest
from mindspore import nn, context, Tensor

# Module-level side effect, applied at import time: every test in this file
# runs with MindSpore's execution mode set to GRAPH_MODE.
context.set_context(mode=context.GRAPH_MODE)
  19. def test_assert1():
  20. """
  21. Feature: support assert
  22. Description: test assert
  23. Expectation: AssertionError
  24. """
  25. class Net(nn.Cell):
  26. def construct(self):
  27. x = 1
  28. assert x == 2
  29. return x
  30. net = Net()
  31. with pytest.raises(AssertionError)as excinfo:
  32. net()
  33. assert "assert x == 2" in str(excinfo.value)
  34. def test_assert2():
  35. """
  36. Feature: support assert
  37. Description: test assert
  38. Expectation: no error
  39. """
  40. class Net(nn.Cell):
  41. def construct(self):
  42. x = 1
  43. assert True
  44. return x
  45. net = Net()
  46. out = net()
  47. assert out == 1
  48. def test_assert3():
  49. """
  50. Feature: support assert
  51. Description: test assert
  52. Expectation: no error
  53. """
  54. class Net(nn.Cell):
  55. def construct(self):
  56. x = 1
  57. assert x in [2, 3, 4]
  58. return x
  59. net = Net()
  60. with pytest.raises(AssertionError) as excinfo:
  61. net()
  62. assert "assert x in [2, 3, 4]" in str(excinfo.value)
  63. def test_assert4():
  64. """
  65. Feature: support assert
  66. Description: test assert
  67. Expectation: no error
  68. """
  69. class Net(nn.Cell):
  70. def construct(self):
  71. x = 1
  72. assert x in [2, 3, 4], "x not in [2, 3, 4]"
  73. return x
  74. net = Net()
  75. with pytest.raises(AssertionError) as excinfo:
  76. net()
  77. assert "x not in [2, 3, 4]" in str(excinfo.value)
  78. assert "assert x in [2, 3, 4]" in str(excinfo.value)
  79. def test_assert5():
  80. """
  81. Feature: support assert
  82. Description: test assert
  83. Expectation: no error
  84. """
  85. class Net(nn.Cell):
  86. def construct(self):
  87. x = 1
  88. assert x in [2, 3, 4], f"%d not in [2, 3, 4]" % x
  89. return x
  90. net = Net()
  91. with pytest.raises(AssertionError) as excinfo:
  92. net()
  93. assert "1 not in [2, 3, 4]" in str(excinfo.value)
  94. assert "assert x in [2, 3, 4]" in str(excinfo.value)
  95. def test_assert6():
  96. """
  97. Feature: support assert
  98. Description: test assert
  99. Expectation: no error
  100. """
  101. class Net(nn.Cell):
  102. def construct(self):
  103. x = 1
  104. assert x in [2, 3, 4], f"{x} not in [2, 3, 4]"
  105. return x
  106. net = Net()
  107. with pytest.raises(AssertionError) as excinfo:
  108. net()
  109. assert "1 not in [2, 3, 4]" in str(excinfo.value)
  110. assert "assert x in [2, 3, 4]" in str(excinfo.value)
  111. def test_assert7():
  112. """
  113. Feature: support assert
  114. Description: test assert
  115. Expectation: no error
  116. """
  117. class Net(nn.Cell):
  118. def construct(self):
  119. x = 1
  120. assert x in [2, 3, 4], "{} not in [2, 3, 4]".format(x)
  121. return x
  122. net = Net()
  123. with pytest.raises(AssertionError) as excinfo:
  124. net()
  125. assert "1 not in [2, 3, 4]" in str(excinfo.value)
  126. assert "assert x in [2, 3, 4]" in str(excinfo.value)
  127. def test_assert8():
  128. """
  129. Feature: support assert
  130. Description: test assert with variable in condition
  131. Expectation: no error
  132. """
  133. class Net(nn.Cell):
  134. def construct(self, x):
  135. assert x == 1
  136. return x
  137. net = Net()
  138. a = Tensor(1)
  139. with pytest.raises(RuntimeError) as excinfo:
  140. net(a)
  141. assert "Currently only supports raise in constant scenarios." in str(excinfo.value)