You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_flatten_recursive_stmt.py 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import ast
  16. import inspect
  17. from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
  18. from mindspore.rewrite.ast_transformers.flatten_recursive_stmt import FlattenRecursiveStmt
class Network(Cell):
    # Fixture network for test_flatten: its source text is parsed by _get_ast()
    # and then fed to FlattenRecursiveStmt.
    # NOTE: keep all documentation in this class as `#` comments. A docstring is
    # an AST statement, so it would change the body lengths that test_flatten
    # asserts (2 class members, 6 __init__ statements, 17 construct statements).
    def __init__(self):
        super().__init__()
        # Six top-level statements in total (the super() call plus five
        # assignments); test_flatten expects __init__ to still have 6 afterwards.
        self.conv = Conv2d(16, 16, 3)
        self.bn = BatchNorm2d(16)
        self.relu1 = ReLU()
        self.relu2 = ReLU()
        self.relu3 = ReLU()

    def construct(self, x):
        # 7 statements in source. Most of them nest sub-expressions (operators
        # inside call arguments, chained arithmetic, a bool op, an expression in
        # `return`), giving FlattenRecursiveStmt something to flatten; test_flatten
        # expects 17 statements after the transform.
        x = self.conv(x + 1)
        x = x + 1 * 5 + 4 / 2 + self.bn(x)
        # Bare expression statement: the call's result is discarded.
        self.relu1(x * 5)
        x = self.relu2(x + 1)
        # Bool-op input; at runtime this always evaluates to x.
        x = True and x or x
        x = self.relu3(x)
        return x + 3
  35. def _get_ast():
  36. source = inspect.getsource(Network)
  37. return ast.parse(source)
  38. def test_flatten():
  39. """
  40. Feature: Class FlattenRecursiveStmt.
  41. Description: Apply FlattenRecursiveStmt on a simple network.
  42. Expectation: Success.
  43. """
  44. ast_node = _get_ast()
  45. frs = FlattenRecursiveStmt()
  46. frs.transform(ast_node)
  47. assert len(ast_node.body) == 1
  48. ast_class = ast_node.body[0]
  49. assert isinstance(ast_class, ast.ClassDef)
  50. assert len(ast_class.body) == 2
  51. ast_init_func = ast_class.body[0]
  52. assert isinstance(ast_init_func, ast.FunctionDef)
  53. assert len(ast_init_func.body) == 6
  54. ast_construct_func = ast_class.body[1]
  55. assert isinstance(ast_construct_func, ast.FunctionDef)
  56. assert len(ast_construct_func.body) == 17