You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_fix_bug.py 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_fix_bug """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.common.api import _executor
class assignment1_Net(nn.Cell):
    """ assignment1_Net definition

    Cell used to exercise graph compilation of a plain assignment from a
    cell attribute (``y = self.number``) followed by a ``for`` loop over a
    list literal that contains the assigned name.

    Args:
        number: value stored on the cell and bound to ``y`` in ``construct``.
            In this file, ``assignment_operator_base`` builds this cell only
            when ``number`` is an ``int``.
    """

    def __init__(self, number):
        super().__init__()
        # Kept as a plain attribute; read back by construct().
        self.number = number
        self.relu = nn.ReLU()

    def construct(self, x):
        # The assignment under test: bind a cell attribute to a local name.
        y = self.number
        # The list literal always has two elements, so relu is applied twice
        # whatever the value of y; y only serves as a list element here.
        for _ in [1, y]:
            x = self.relu(x)
        return x
class assignment2_Net(nn.Cell):
    """ assignment2_Net definition

    Cell used to exercise graph compilation of tuple-unpacking assignment
    from a cell attribute (``a, b = self.number``) followed by a ``for`` loop
    over a list literal built from the unpacked names.

    Args:
        number: a two-element sequence; ``construct`` unpacks it into
            exactly two names, so any other length would fail to unpack.
            In this file, ``assignment_operator_base`` builds this cell
            for non-``int`` inputs such as ``(1, 3)``.
    """

    def __init__(self, number):
        super().__init__()
        # Kept as a plain attribute; unpacked in construct().
        self.number = number
        self.relu = nn.ReLU()

    def construct(self, x):
        # The assignment under test: unpack a cell attribute into two names.
        a, b = self.number
        # Two-element list literal, so relu is applied exactly twice.
        for _ in [a, b]:
            x = self.relu(x)
        return x
  43. def assignment_operator_base(number):
  44. """ assignment_operator_base """
  45. input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
  46. input_me = Tensor(input_np)
  47. x = number
  48. if isinstance(x, int):
  49. net = assignment1_Net(x)
  50. else:
  51. net = assignment2_Net(x)
  52. _executor.compile(net, input_me)
  53. def test_ME_assignment_operator_0010():
  54. """ test_ME_assignment_operator_0010 """
  55. assignment_operator_base(3)
  56. def test_ME_assignment_operator_0020():
  57. """ test_ME_assignment_operator_0020 """
  58. assignment_operator_base((1, 3))
class unsupported_method_net(nn.Cell):
    """ unsupported_method_net definition

    Cell whose ``construct`` deliberately uses file I/O (``with open(...)``),
    which the compile test in this file expects to be rejected with a
    ``RuntimeError``.
    """

    def __init__(self):
        super().__init__()
        # Not used by construct(); NOTE(review): appears to be leftover
        # boilerplate, kept so the cell's attributes stay unchanged.
        self.relu = nn.ReLU()

    def construct(self, x):
        # Intentionally unsupported syntax for graph compilation.
        # NOTE(review): the test only calls _executor.compile on this cell,
        # so "a.txt" presumably never needs to exist -- confirm.
        with open("a.txt") as f:
            f.read()
        return x
  68. def test_compile_unspported():
  69. """ test_compile_unspported """
  70. input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
  71. input_me = Tensor(input_np)
  72. net = unsupported_method_net()
  73. with pytest.raises(RuntimeError):
  74. _executor.compile(net, input_me)