| @@ -21,14 +21,24 @@ Usage: | |||
| pytest lineagemgr | |||
| """ | |||
| import os | |||
| from unittest import TestCase | |||
| from unittest import TestCase, mock | |||
| import numpy as np | |||
| import pytest | |||
| from mindinsight.lineagemgr.model import filter_summary_lineage, get_summary_lineage | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotFoundError, LineageParamSummaryPathError, | |||
| LineageParamTypeError, LineageParamValueError, | |||
| LineageSearchConditionParamError) | |||
| from mindinsight.datavisual.data_transform import data_manager | |||
| from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater | |||
| from mindinsight.lineagemgr.model import get_flattened_lineage | |||
| from mindspore.application.model_zoo.resnet import ResNet | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.dataset.engine import MindDataset | |||
| from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits | |||
| from mindspore.train.callback import RunContext | |||
| from ....utils.lineage_writer.model_lineage import AnalyzeObject, TrainLineage | |||
| from .conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 | |||
| from ....ut.lineagemgr.querier import event_data | |||
| from ....utils.tools import assert_equal_lineages | |||
| @@ -814,3 +824,47 @@ class TestModelApi(TestCase): | |||
| BASE_SUMMARY_DIR, | |||
| search_condition | |||
| ) | |||
| class TestLineageTable: | |||
| """Test lineage table .""" | |||
| @classmethod | |||
| def setup_class(cls): | |||
| """Setup method""" | |||
| cls.run_context = dict( | |||
| train_network=ResNet(), | |||
| loss_fn=SoftmaxCrossEntropyWithLogits(), | |||
| net_outputs=Tensor(np.array([0.03])), | |||
| optimizer=Momentum(Tensor(0.12)), | |||
| train_dataset=MindDataset(dataset_size=32), | |||
| epoch_num=10, | |||
| cur_step_num=320, | |||
| parallel_mode="stand_alone", | |||
| device_number=2, | |||
| batch_num=32 | |||
| ) | |||
| cls.user_defined_info = {"info": "info1", "version": "v1"} | |||
| @pytest.mark.scene_train(2) | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascned_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| @mock.patch.object(AnalyzeObject, 'get_file_size') | |||
| def test_training_end(self): | |||
| """Test the function of get_flattened_lineage""" | |||
| train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info) | |||
| train_callback.initial_learning_rate = 0.12 | |||
| train_callback.begin(RunContext(self.run_context)) | |||
| train_callback.end(RunContext(self.run_context)) | |||
| summary_base_dir = SUMMARY_DIR | |||
| datamanager = data_manager.DataManager(summary_base_dir) | |||
| datamanager.register_brief_cache_item_updater(LineageCacheItemUpdater()) | |||
| datamanager.start_load_data().join() | |||
| data = get_flattened_lineage(datamanager) | |||
| assert data.get('[U]info') == ['info1'] | |||
| @@ -16,12 +16,14 @@ | |||
| from unittest import TestCase, mock | |||
| from unittest.mock import MagicMock | |||
| from mindinsight.lineagemgr.model import get_summary_lineage, filter_summary_lineage, _convert_relative_path_to_abspath | |||
| from mindinsight.lineagemgr.model import get_summary_lineage, filter_summary_lineage, \ | |||
| _convert_relative_path_to_abspath, get_flattened_lineage | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamSummaryPathError, \ | |||
| LineageFileNotFoundError, LineageSummaryParseException, LineageQuerierParamException, \ | |||
| LineageQuerySummaryDataError, LineageSearchConditionParamError, LineageParamTypeError, \ | |||
| LineageParamValueError | |||
| from mindinsight.lineagemgr.common.path_parser import SummaryPathParser | |||
| from ...st.func.lineagemgr.test_model import LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2 | |||
| class TestModel(TestCase): | |||
| @@ -242,3 +244,15 @@ class TestFilterAPI(TestCase): | |||
| None, | |||
| '/path/to/summary/dir' | |||
| ) | |||
| @mock.patch('mindinsight.lineagemgr.model.filter_summary_lineage') | |||
| def test_get_lineage_table(self, mock_filter_summary_lineage): | |||
| """Test get_flattened_lineage with valid param.""" | |||
| mock_data = { | |||
| 'object': [LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2] | |||
| } | |||
| mock_datamanager = MagicMock() | |||
| mock_datamanager.summary_base_dir = '/tmp/' | |||
| mock_filter_summary_lineage.return_value = mock_data | |||
| result = get_flattened_lineage(mock_datamanager, None) | |||
| assert result.get('[U]info') == ['info1', None] | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -0,0 +1,35 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test the model module.""" | |||
| import numpy as np | |||
| import pytest | |||
| from mindinsight.optimizer.utils.utils import is_simple_numpy_number, calc_histogram | |||
| def test_is_simple_numpy_number(): | |||
| assert is_simple_numpy_number(np.int8) | |||
| assert is_simple_numpy_number(np.int16) | |||
| assert is_simple_numpy_number(np.float) | |||
| assert not is_simple_numpy_number(str) | |||
| def test_calc_histogram(): | |||
| """Test calc_histogram function""" | |||
| data = np.array([2, 2, 3, 4, 5]) | |||
| output = calc_histogram(data) | |||
| assert output[0][1] == pytest.approx(0.6, 1e-6) | |||
| assert output[1][1] == pytest.approx(0.6, 1e-6) | |||
| assert output[0][2] == pytest.approx(2.0, 1e-6) | |||