|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- import ast
- import inspect
-
- from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
- from mindspore.rewrite.ast_transformers.flatten_recursive_stmt import FlattenRecursiveStmt
-
-
- class Network(Cell):
- def __init__(self):
- super().__init__()
- self.conv = Conv2d(16, 16, 3)
- self.bn = BatchNorm2d(16)
- self.relu1 = ReLU()
- self.relu2 = ReLU()
- self.relu3 = ReLU()
-
- def construct(self, x):
- x = self.conv(x + 1)
- x = x + 1 * 5 + 4 / 2 + self.bn(x)
- self.relu1(x * 5)
- x = self.relu2(x + 1)
- x = True and x or x
- x = self.relu3(x)
- return x + 3
-
-
- def _get_ast():
- source = inspect.getsource(Network)
- return ast.parse(source)
-
-
- def test_flatten():
- """
- Feature: Class FlattenRecursiveStmt.
- Description: Apply FlattenRecursiveStmt on a simple network.
- Expectation: Success.
- """
- ast_node = _get_ast()
- frs = FlattenRecursiveStmt()
- frs.transform(ast_node)
- assert len(ast_node.body) == 1
- ast_class = ast_node.body[0]
- assert isinstance(ast_class, ast.ClassDef)
- assert len(ast_class.body) == 2
- ast_init_func = ast_class.body[0]
- assert isinstance(ast_init_func, ast.FunctionDef)
- assert len(ast_init_func.body) == 6
- ast_construct_func = ast_class.body[1]
- assert isinstance(ast_construct_func, ast.FunctionDef)
- assert len(ast_construct_func.body) == 17
|