You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_landscape.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test create landscape."""
  16. import os
  17. import shutil
  18. import tempfile
  19. import pytest
  20. from mindspore.common import set_seed
  21. from mindspore import nn
  22. from mindspore.nn.metrics import Loss
  23. from mindspore.train import Model
  24. from mindspore.train.callback import SummaryLandscape
  25. from tests.security_utils import security_off_wrap
  26. from tests.ut.python.train.dataset import create_mnist_dataset, LeNet5
  27. set_seed(1)
  28. _VALUE_CACHE = list()
  29. def get_value():
  30. """Get the value which is added by add_value function."""
  31. global _VALUE_CACHE
  32. value = _VALUE_CACHE
  33. _VALUE_CACHE = list()
  34. return value
  35. def callback_fn():
  36. """A python function job"""
  37. network = LeNet5()
  38. loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  39. metrics = {"Loss": Loss()}
  40. model = Model(network, loss, metrics=metrics)
  41. ds_train = create_mnist_dataset("train")
  42. return model, network, ds_train, metrics
  43. class TestLandscape:
  44. """Test the exception parameter for landscape."""
  45. base_summary_dir = ''
  46. def setup_class(self):
  47. """Run before test this class."""
  48. self.base_summary_dir = tempfile.mkdtemp(suffix='summary')
  49. def teardown_class(self):
  50. """Run after test this class."""
  51. if os.path.exists(self.base_summary_dir):
  52. shutil.rmtree(self.base_summary_dir)
  53. def teardown_method(self):
  54. """Run after each test function."""
  55. get_value()
  56. @security_off_wrap
  57. @pytest.mark.parametrize("collect_landscape", [
  58. {
  59. 'landscape_size': None
  60. },
  61. {
  62. 'create_landscape': None
  63. },
  64. {
  65. 'num_samples': None
  66. },
  67. {
  68. 'intervals': None
  69. },
  70. ])
  71. def test_params_gen_landscape_with_multi_process_value_type_error(self, collect_landscape):
  72. """Test the value of gen_landscape_with_multi_process param."""
  73. device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
  74. summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
  75. summary_landscape = SummaryLandscape(summary_dir)
  76. with pytest.raises(TypeError) as exc:
  77. summary_landscape.gen_landscapes_with_multi_process(
  78. callback_fn,
  79. collect_landscape=collect_landscape,
  80. device_ids=[device_id]
  81. )
  82. param_name = list(collect_landscape)[0]
  83. param_value = collect_landscape[param_name]
  84. if param_name in ['landscape_size', 'num_samples']:
  85. expected_type = "['int']"
  86. elif param_name == 'unit':
  87. expected_type = "['str']"
  88. elif param_name == 'create_landscape':
  89. expected_type = "['dict']"
  90. else:
  91. expected_type = "['list']"
  92. expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
  93. f'but got {type(param_value).__name__}.'
  94. assert expected_msg == str(exc.value)