You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_ast_replacer.py 2.5 kB

(line numbers 1–73)
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import ast
  16. import re
  17. import inspect
  18. import astunparse
  19. from mindspore import nn
  20. from mindspore.ops import functional as F
  21. from mindspore.rewrite.ast_helpers import AstReplacer
  22. class SimpleNet2(nn.Cell):
  23. def construct(self, x):
  24. return F.add(x, x)
class SimpleNet(nn.Cell):
    """Fixture cell that refers to its own class name in many AST positions.

    test_replacer counts occurrences of the class name in this class's source
    text, so comments and docstrings here must not spell that name out.
    """
    def __init__(self):
        super(SimpleNet, self).__init__()
        # The class name appears below as a bare call statement, as an
        # assignment value, as a BinOp operand and as a call argument.
        SimpleNet._get_int()
        self.aaa = SimpleNet._get_int()
        self.bbb = SimpleNet._get_int() + 1
        self.ccc = F.add(SimpleNet._get_int(), 1)
        # Child cell whose name merely shares a prefix with this class's name.
        self.ddd = SimpleNet2()

    @staticmethod
    def _get_int():
        """Return the constant 1."""
        return 1

    def construct(self, x):
        """Repeat the same name-reference patterns, then run the child cell."""
        SimpleNet._get_int()
        aaa = SimpleNet._get_int()
        bbb = SimpleNet._get_int() + aaa
        ccc = F.add(SimpleNet._get_int(), bbb)
        x = self.ddd(ccc)
        return x
  43. def test_replacer():
  44. """
  45. Feature: Class AstReplacer in Package rewrite.
  46. Description:
  47. Use AstReplacer to replace all "SimpleNet" symbol to "SimpleNet2" symbol.
  48. Use AstReplacer to undo all replace.
  49. Expectation: AstReplacer can replace all "SimpleNet" symbol to "SimpleNet2" symbol and restore original ast node.
  50. """
  51. original_code = inspect.getsource(SimpleNet)
  52. assert len(re.findall("SimpleNet", original_code)) == 11
  53. assert len(re.findall("SimpleNet2", original_code)) == 1
  54. ast_root = ast.parse(original_code)
  55. replacer = AstReplacer(ast_root)
  56. replacer.replace_all("SimpleNet", "SimpleNet2")
  57. replaced_code = astunparse.unparse(ast_root)
  58. assert len(re.findall("SimpleNet", replaced_code)) == 11
  59. assert len(re.findall("SimpleNet2", replaced_code)) == 11
  60. replacer.undo_all()
  61. assert len(re.findall("SimpleNet", original_code)) == 11
  62. assert len(re.findall("SimpleNet2", original_code)) == 1