Browse Source

add UT for optimizer and the get_flattened_lineage function

tags/v1.0.0
shenghong 5 years ago
parent
commit
517f7c1ff6
4 changed files with 120 additions and 3 deletions
  1. +56
    -2
      tests/st/func/lineagemgr/test_model.py
  2. +15
    -1
      tests/ut/lineagemgr/test_model.py
  3. +14
    -0
      tests/ut/optimizer/__init__.py
  4. +35
    -0
      tests/ut/optimizer/test_utils.py

+ 56
- 2
tests/st/func/lineagemgr/test_model.py View File

@@ -21,14 +21,24 @@ Usage:
pytest lineagemgr
"""
import os
from unittest import TestCase
from unittest import TestCase, mock
import numpy as np
import pytest

from mindinsight.lineagemgr.model import filter_summary_lineage, get_summary_lineage
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotFoundError, LineageParamSummaryPathError,
LineageParamTypeError, LineageParamValueError,
LineageSearchConditionParamError)
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
from mindinsight.lineagemgr.model import get_flattened_lineage
from mindspore.application.model_zoo.resnet import ResNet
from mindspore.common.tensor import Tensor
from mindspore.dataset.engine import MindDataset
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
from mindspore.train.callback import RunContext
from ....utils.lineage_writer.model_lineage import AnalyzeObject, TrainLineage

from .conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
from ....ut.lineagemgr.querier import event_data
from ....utils.tools import assert_equal_lineages
@@ -814,3 +824,47 @@ class TestModelApi(TestCase):
BASE_SUMMARY_DIR,
search_condition
)


class TestLineageTable:
"""Test lineage table ."""

@classmethod
def setup_class(cls):
"""Setup method"""
cls.run_context = dict(
train_network=ResNet(),
loss_fn=SoftmaxCrossEntropyWithLogits(),
net_outputs=Tensor(np.array([0.03])),
optimizer=Momentum(Tensor(0.12)),
train_dataset=MindDataset(dataset_size=32),
epoch_num=10,
cur_step_num=320,
parallel_mode="stand_alone",
device_number=2,
batch_num=32
)
cls.user_defined_info = {"info": "info1", "version": "v1"}

@pytest.mark.scene_train(2)
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascned_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
@mock.patch.object(AnalyzeObject, 'get_file_size')
def test_training_end(self):
"""Test the function of get_flattened_lineage"""
train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)

train_callback.initial_learning_rate = 0.12
train_callback.begin(RunContext(self.run_context))
train_callback.end(RunContext(self.run_context))

summary_base_dir = SUMMARY_DIR
datamanager = data_manager.DataManager(summary_base_dir)
datamanager.register_brief_cache_item_updater(LineageCacheItemUpdater())
datamanager.start_load_data().join()

data = get_flattened_lineage(datamanager)
assert data.get('[U]info') == ['info1']

+ 15
- 1
tests/ut/lineagemgr/test_model.py View File

@@ -16,12 +16,14 @@
from unittest import TestCase, mock
from unittest.mock import MagicMock

from mindinsight.lineagemgr.model import get_summary_lineage, filter_summary_lineage, _convert_relative_path_to_abspath
from mindinsight.lineagemgr.model import get_summary_lineage, filter_summary_lineage, \
_convert_relative_path_to_abspath, get_flattened_lineage
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamSummaryPathError, \
LineageFileNotFoundError, LineageSummaryParseException, LineageQuerierParamException, \
LineageQuerySummaryDataError, LineageSearchConditionParamError, LineageParamTypeError, \
LineageParamValueError
from mindinsight.lineagemgr.common.path_parser import SummaryPathParser
from ...st.func.lineagemgr.test_model import LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2


class TestModel(TestCase):
@@ -242,3 +244,15 @@ class TestFilterAPI(TestCase):
None,
'/path/to/summary/dir'
)

@mock.patch('mindinsight.lineagemgr.model.filter_summary_lineage')
def test_get_lineage_table(self, mock_filter_summary_lineage):
"""Test get_flattened_lineage with valid param."""
mock_data = {
'object': [LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2]
}
mock_datamanager = MagicMock()
mock_datamanager.summary_base_dir = '/tmp/'
mock_filter_summary_lineage.return_value = mock_data
result = get_flattened_lineage(mock_datamanager, None)
assert result.get('[U]info') == ['info1', None]

+ 14
- 0
tests/ut/optimizer/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

+ 35
- 0
tests/ut/optimizer/test_utils.py View File

@@ -0,0 +1,35 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the model module."""
import numpy as np
import pytest

from mindinsight.optimizer.utils.utils import is_simple_numpy_number, calc_histogram


def test_is_simple_numpy_number():
assert is_simple_numpy_number(np.int8)
assert is_simple_numpy_number(np.int16)
assert is_simple_numpy_number(np.float)
assert not is_simple_numpy_number(str)


def test_calc_histogram():
"""Test calc_histogram function"""
data = np.array([2, 2, 3, 4, 5])
output = calc_histogram(data)
assert output[0][1] == pytest.approx(0.6, 1e-6)
assert output[1][1] == pytest.approx(0.6, 1e-6)
assert output[0][2] == pytest.approx(2.0, 1e-6)

Loading…
Cancel
Save