|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """test create landscape."""
- import os
- import shutil
- import tempfile
- import pytest
-
- from mindspore.common import set_seed
- from mindspore import nn
- from mindspore.nn.metrics import Loss
- from mindspore.train import Model
- from mindspore.train.callback import SummaryLandscape
- from tests.security_utils import security_off_wrap
- from tests.ut.python.train.dataset import create_mnist_dataset, LeNet5
-
- set_seed(1)
-
- _VALUE_CACHE = list()
-
- def get_value():
- """Get the value which is added by add_value function."""
- global _VALUE_CACHE
-
- value = _VALUE_CACHE
- _VALUE_CACHE = list()
- return value
-
-
- def callback_fn():
- """A python function job"""
- network = LeNet5()
- loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
- metrics = {"Loss": Loss()}
- model = Model(network, loss, metrics=metrics)
- ds_train = create_mnist_dataset("train")
- return model, network, ds_train, metrics
-
-
- class TestLandscape:
- """Test the exception parameter for landscape."""
- base_summary_dir = ''
-
- def setup_class(self):
- """Run before test this class."""
- self.base_summary_dir = tempfile.mkdtemp(suffix='summary')
-
- def teardown_class(self):
- """Run after test this class."""
- if os.path.exists(self.base_summary_dir):
- shutil.rmtree(self.base_summary_dir)
-
- def teardown_method(self):
- """Run after each test function."""
- get_value()
-
- @security_off_wrap
- @pytest.mark.parametrize("collect_landscape", [
- {
- 'landscape_size': None
- },
- {
- 'create_landscape': None
- },
- {
- 'num_samples': None
- },
- {
- 'intervals': None
- },
- ])
- def test_params_gen_landscape_with_multi_process_value_type_error(self, collect_landscape):
- """Test the value of gen_landscape_with_multi_process param."""
- device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
- summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
- summary_landscape = SummaryLandscape(summary_dir)
- with pytest.raises(TypeError) as exc:
- summary_landscape.gen_landscapes_with_multi_process(
- callback_fn,
- collect_landscape=collect_landscape,
- device_ids=[device_id]
- )
- param_name = list(collect_landscape)[0]
- param_value = collect_landscape[param_name]
- if param_name in ['landscape_size', 'num_samples']:
- expected_type = "['int']"
- elif param_name == 'unit':
- expected_type = "['str']"
- elif param_name == 'create_landscape':
- expected_type = "['dict']"
- else:
- expected_type = "['list']"
- expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
- f'but got {type(param_value).__name__}.'
- assert expected_msg == str(exc.value)
|