|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- import ast
- import re
- import inspect
- import astunparse
- from mindspore import nn
- from mindspore.ops import functional as F
- from mindspore.rewrite.ast_helpers import AstReplacer
-
-
class SimpleNet2(nn.Cell):
    """Minimal cell whose forward pass adds its input to itself."""

    def construct(self, x):
        """Return ``F.add(x, x)``."""
        doubled = F.add(x, x)
        return doubled
-
-
class SimpleNet(nn.Cell):
    """Fixture cell for the AstReplacer test.

    The class's own name is used in several distinct AST positions (the class
    definition, the ``super`` argument, a bare expression statement, an
    assignment value, a binary operation and a call argument) so that symbol
    replacement is exercised on each of them.

    NOTE: the test counts textual occurrences of the class names in this
    class's source, so comments and docstrings here must not spell them out.
    """

    def __init__(self):
        super(SimpleNet, self).__init__()
        # Result is discarded: this line places the name in a bare expression statement.
        SimpleNet._get_int()
        # Name used as the value of an attribute assignment.
        self.aaa = SimpleNet._get_int()
        # Name used inside a binary operation.
        self.bbb = SimpleNet._get_int() + 1
        # Name used as an argument of another call.
        self.ccc = F.add(SimpleNet._get_int(), 1)
        # Nested sibling cell; its name is the only pre-existing occurrence of the replacement target.
        self.ddd = SimpleNet2()

    @staticmethod
    def _get_int():
        """Return the constant 1."""
        return 1

    def construct(self, x):
        """Use the static helper in the same expression positions as ``__init__``, then run the nested cell.

        Note that the input ``x`` is overwritten by the nested cell's output on ``ccc``
        and is otherwise unused.
        """
        SimpleNet._get_int()
        aaa = SimpleNet._get_int()
        bbb = SimpleNet._get_int() + aaa
        ccc = F.add(SimpleNet._get_int(), bbb)
        x = self.ddd(ccc)
        return x
-
-
def test_replacer():
    """
    Feature: Class AstReplacer in Package rewrite.
    Description:
        Use AstReplacer to replace all "SimpleNet" symbols with "SimpleNet2".
        Use AstReplacer to undo all replacements.
    Expectation: AstReplacer replaces all "SimpleNet" symbols with "SimpleNet2" and restores the original ast nodes.
    """

    def count_of(code, symbol):
        # Number of non-overlapping textual occurrences of `symbol` in `code`.
        return len(re.findall(symbol, code))

    original_code = inspect.getsource(SimpleNet)
    # Every "SimpleNet2" match is also a "SimpleNet" match: 11 occurrences in total,
    # exactly one of which (the nested-cell construction in __init__) is "SimpleNet2".
    assert count_of(original_code, "SimpleNet") == 11
    assert count_of(original_code, "SimpleNet2") == 1

    ast_root = ast.parse(original_code)
    replacer = AstReplacer(ast_root)
    replacer.replace_all("SimpleNet", "SimpleNet2")
    replaced_code = astunparse.unparse(ast_root)
    # Every occurrence now reads "SimpleNet2"; the pre-existing "SimpleNet2" is not renamed again.
    assert count_of(replaced_code, "SimpleNet") == 11
    assert count_of(replaced_code, "SimpleNet2") == 11

    replacer.undo_all()
    # Unparse the tree again after undoing; re-checking `original_code` (an unchanged
    # string) would not verify that the ast nodes were actually restored.
    restored_code = astunparse.unparse(ast_root)
    assert count_of(restored_code, "SimpleNet") == 11
    assert count_of(restored_code, "SimpleNet2") == 1
|