Browse Source

add UT for mindconverter

tags/v1.0.0
lilongfei 5 years ago
parent
commit
b1d81dab57
2 changed files with 61 additions and 2 deletions
  1. +38
    -0
      tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py
  2. +23
    -2
      tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py

+ 38
- 0
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_hierarchical_tree.py View File

@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test Name manager module."""
from unittest import mock, TestCase
from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.hierarchical_tree import HierarchicalTree
from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_node import PyTorchGraphNode


class TestHierarchicalTree(TestCase):
"""Test the class of HierarchicalTree."""

def test_tree_identifier(self):
"""Test tree_identifier"""
tree = HierarchicalTree()
self.assertIsInstance(tree.tree_identifier, str)

@mock.patch(
'mindinsight.mindconverter.graph_based_converter.' \
'third_party_graph.pytorch_graph_node.PyTorchGraphNode._get_raw_params')
def test_insert(self, get_raw_params):
"""Test insert"""
get_raw_params.return_value = []
tree = HierarchicalTree()
pt_node = PyTorchGraphNode()
tree.insert(pt_node, 'ResNet', (1, 3, 224, 224), (1, 64, 112, 112))
self.assertEqual(tree.root, 'ResNet')

+ 23
- 2
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py View File

@@ -14,13 +14,34 @@
# ==============================================================================
"""Test name manager module."""
from unittest import TestCase
from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import GlobalVarNameMgr
from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import NameMgr, GlobalVarNameMgr, \
global_op_namespace


class TestNameMgr(TestCase):
"""Tester of name mgr."""

def test_global_name_mgr(self):
def test_global_get_name_not_in_record(self):
"""Test global name mgr."""
name = GlobalVarNameMgr().get_name("onnx::Conv")
assert isinstance(name, str)

def test_global_get_name_in_record(self):
"""Test global name mgr."""
global_op_namespace['abc'] = 0
name_mgr = GlobalVarNameMgr()
name = name_mgr.get_name('abc')
assert isinstance(name, str)

def test_get_name_not_in_record(self):
"""Test get_name old_name not in self.record"""
name_mgr = NameMgr()
name = name_mgr.get_name('abc')
assert isinstance(name, str)

def test_get_name_in_record(self):
"""Test get_name old_name in self.record"""
name_mgr = NameMgr()
name_mgr.record = {'abc': ['123']}
name = name_mgr.get_name('abc')
assert isinstance(name, str)

Loading…
Cancel
Save