You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_fix_bug.py 4.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_fix_bug """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.ops import composite as C
  21. from mindspore.ops import operations as P
  22. from mindspore.common import dtype as ms
  23. from mindspore.common.api import _executor
class assignment1_Net(nn.Cell):
    """ assignment1_Net definition.

    Test network whose ``construct`` copies a cell attribute into a local and
    loops over a list literal containing it. ``assignment_operator_base``
    compiles it (no assertion), so the test checks only that this assignment
    pattern is accepted by graph compilation.

    Args:
        number: Value stored on the cell and read in ``construct``
            (``assignment_operator_base`` passes an int here).
    """

    def __init__(self, number):
        super().__init__()
        self.number = number
        self.relu = nn.ReLU()

    def construct(self, x):
        # Plain assignment from a cell attribute -- the syntax under test.
        y = self.number
        # NOTE(review): this iterates over the two-element list [1, y], so ReLU
        # is applied exactly twice whatever y is (it is not range(1, y)).
        # Left unchanged because the test only checks that the graph compiles.
        for _ in [1, y]:
            x = self.relu(x)
        return x
class assignment2_Net(nn.Cell):
    """ assignment2_Net definition.

    Test network whose ``construct`` tuple-unpacks a cell attribute into two
    locals and loops over them. ``assignment_operator_base`` compiles it (no
    assertion), so the test checks only that this destructuring assignment is
    accepted by graph compilation.

    Args:
        number: Sequence of exactly two elements stored on the cell
            (``test_ME_assignment_operator_0020`` supplies ``(1, 3)``).
    """

    def __init__(self, number):
        super().__init__()
        self.number = number
        self.relu = nn.ReLU()

    def construct(self, x):
        # Destructuring assignment from a cell attribute -- the syntax under test.
        a, b = self.number
        # Loops over the two-element list [a, b]: ReLU is applied twice.
        for _ in [a, b]:
            x = self.relu(x)
        return x
  46. def assignment_operator_base(number):
  47. """ assignment_operator_base """
  48. input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
  49. input_me = Tensor(input_np)
  50. x = number
  51. if isinstance(x, int):
  52. net = assignment1_Net(x)
  53. else:
  54. net = assignment2_Net(x)
  55. _executor.compile(net, input_me)
  56. def test_ME_assignment_operator_0010():
  57. """ test_ME_assignment_operator_0010 """
  58. assignment_operator_base(3)
  59. def test_ME_assignment_operator_0020():
  60. """ test_ME_assignment_operator_0020 """
  61. assignment_operator_base((1, 3))
class unsupported_method_net(nn.Cell):
    """ unsupported_method_net definition.

    Test network whose ``construct`` contains a ``with open(...)`` statement.
    ``test_compile_unspported`` expects compiling this cell to raise
    RuntimeError.
    """

    def __init__(self):
        super().__init__()
        # Not referenced by construct.
        self.relu = nn.ReLU()

    def construct(self, x):
        # Deliberately unsupported syntax inside construct: file I/O via a
        # context manager. Must stay exactly as written for the test to be valid.
        with open("a.txt") as f:
            f.read()
        return x
  71. def test_compile_unspported():
  72. """ test_compile_unspported """
  73. input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
  74. input_me = Tensor(input_np)
  75. net = unsupported_method_net()
  76. with pytest.raises(RuntimeError):
  77. _executor.compile(net, input_me)
def test_parser_map_0002():
    """ test_parser_map_0002

    Calling a cell whose ``construct`` invokes ``map`` with only a function
    argument (no sequences) and compares the result to an int is expected to
    raise TypeError.
    """
    class NetMap0002(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = nn.ReLU()
            # Not referenced by construct.
            self.hypermap = C.Map()

        def mul(self, x=2, y=4):
            return x * y

        def construct(self, x):
            # Intentionally invalid: map() receives a function but no
            # sequences to iterate. Must stay as written for the test.
            if map(self.mul) == 8:
                x = self.relu(x)
            return x
    input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me_x = Tensor(input_np_x)
    net = NetMap0002()
    # The invalid map() call above should surface as a TypeError on execution.
    with pytest.raises(TypeError):
        net(input_me_x)
  95. def test_fix_expanddims_loss_scale():
  96. class ControlOneIfOneScaleOneScale(nn.Cell):
  97. def __init__(self):
  98. super().__init__()
  99. self.op = P.ExpandDims()
  100. def construct(self, x, y, data):
  101. if x > y:
  102. out = 1
  103. else:
  104. out = 2
  105. if x > y:
  106. out = self.op(data, out)
  107. else:
  108. out = self.op(data, out)
  109. return out
  110. net = ControlOneIfOneScaleOneScale()
  111. x = Tensor(1, ms.float32)
  112. y = Tensor(0, ms.float32)
  113. input_shape = (1024, 512, 7, 7)
  114. input_data = np.random.randn(*input_shape).astype(np.float32)
  115. net = ControlOneIfOneScaleOneScale()
  116. net(x, y, Tensor(input_data))