From: @ouwenchang Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -199,7 +199,7 @@ class ImageClassificationRunner: | |||
| """ | |||
| self._verify_data_n_settings(check_all=True) | |||
| with SummaryRecord(self._summary_dir) as summary: | |||
| with SummaryRecord(self._summary_dir, raise_exception=True) as summary: | |||
| print("Start running and writing......") | |||
| begin = time() | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -207,7 +207,9 @@ class SummaryCollector(Callback): | |||
| self._dataset_sink_mode = True | |||
| def __enter__(self): | |||
| self._record = SummaryRecord(log_dir=self._summary_dir, max_file_size=self._max_file_size) | |||
| self._record = SummaryRecord(log_dir=self._summary_dir, | |||
| max_file_size=self._max_file_size, | |||
| raise_exception=False) | |||
| self._first_step, self._dataset_sink_mode = True, True | |||
| return self | |||
| @@ -319,7 +321,14 @@ class SummaryCollector(Callback): | |||
| f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}') | |||
| if 'histogram_regular' in specified_data: | |||
| check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None))) | |||
| regular = specified_data.get('histogram_regular') | |||
| check_value_type('histogram_regular', regular, (str, type(None))) | |||
| if isinstance(regular, str): | |||
| try: | |||
| re.match(regular, '') | |||
| except re.error as exc: | |||
| raise ValueError(f'For `collect_specified_data`, the value of `histogram_regular` ' | |||
| f'is not a valid regular expression. Detail: {str(exc)}.') | |||
| bool_items = set(self._DEFAULT_SPECIFIED_DATA) - {'histogram_regular'} | |||
| for item in bool_items: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -58,14 +58,17 @@ class WriterPool(ctx.Process): | |||
| Args: | |||
| base_dir (str): The base directory to hold all the files. | |||
| max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes. | |||
| raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs | |||
| in recording data. Default: False, this means that error logs are printed and no exception is thrown. | |||
| filedict (dict): The mapping from plugin to filename. | |||
| """ | |||
| def __init__(self, base_dir, max_file_size, **filedict) -> None: | |||
| def __init__(self, base_dir, max_file_size, raise_exception=False, **filedict) -> None: | |||
| super().__init__() | |||
| self._base_dir, self._filedict = base_dir, filedict | |||
| self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None | |||
| self._max_file_size = max_file_size | |||
| self._raise_exception = raise_exception | |||
| self.start() | |||
| def run(self): | |||
| @@ -124,8 +127,14 @@ class WriterPool(ctx.Process): | |||
| for writer in self._writers[:]: | |||
| try: | |||
| writer.write(plugin, data) | |||
| except RuntimeError as e: | |||
| logger.warning(e.args[0]) | |||
| except RuntimeError as exc: | |||
| logger.error(str(exc)) | |||
| self._writers.remove(writer) | |||
| writer.close() | |||
| if self._raise_exception: | |||
| raise | |||
| except RuntimeWarning as exc: | |||
| logger.warning(str(exc)) | |||
| self._writers.remove(writer) | |||
| writer.close() | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -20,6 +20,7 @@ import threading | |||
| from collections import defaultdict | |||
| from mindspore import log as logger | |||
| from mindspore.nn import Cell | |||
| from ..._c_expression import Tensor | |||
| from ..._checkparam import Validator | |||
| @@ -29,7 +30,7 @@ from ._explain_adapter import check_explain_proto | |||
| from ._writer_pool import WriterPool | |||
| # for the moment, this lock is for caution's sake, | |||
| # there are actually no any concurrencies happening. | |||
| # there are actually no any concurrences happening. | |||
| _summary_lock = threading.Lock() | |||
| # cache the summary data | |||
| _summary_tensor_cache = {} | |||
| @@ -56,10 +57,6 @@ def _get_summary_tensor_data(): | |||
| return data | |||
| def _dictlist(): | |||
| return defaultdict(list) | |||
| class SummaryRecord: | |||
| """ | |||
| SummaryRecord is used to record the summary data and lineage data. | |||
| @@ -80,12 +77,13 @@ class SummaryRecord: | |||
| file_prefix (str): The prefix of file. Default: "events". | |||
| file_suffix (str): The suffix of file. Default: "_MS". | |||
| network (Cell): Obtain a pipeline through network for saving graph summary. Default: None. | |||
| max_file_size (Optional[int]): The maximum size of each file that can be written to disk (in bytes). \ | |||
| max_file_size (int, optional): The maximum size of each file that can be written to disk (in bytes). \ | |||
| Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`. | |||
| raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs | |||
| in recording data. Default: False, this means that error logs are printed and no exception is thrown. | |||
| Raises: | |||
| TypeError: If the type of `max_file_size` is not int, or the type of `file_prefix` or `file_suffix` is not str. | |||
| RuntimeError: If the log_dir is not a normalized absolute path name. | |||
| TypeError: If the parameter type is incorrect. | |||
| Examples: | |||
| >>> # use in with statement to auto close | |||
| @@ -100,10 +98,11 @@ class SummaryRecord: | |||
| ... summary_record.close() | |||
| """ | |||
| def __init__(self, log_dir, file_prefix="events", file_suffix="_MS", network=None, max_file_size=None): | |||
| def __init__(self, log_dir, file_prefix="events", file_suffix="_MS", | |||
| network=None, max_file_size=None, raise_exception=False): | |||
| self._closed, self._event_writer = False, None | |||
| self._mode, self._data_pool = 'train', _dictlist() | |||
| self._mode, self._data_pool = 'train', defaultdict(list) | |||
| Validator.check_str_by_regular(file_prefix) | |||
| Validator.check_str_by_regular(file_suffix) | |||
| @@ -120,6 +119,8 @@ class SummaryRecord: | |||
| logger.warning("The 'max_file_size' should be greater than 0.") | |||
| max_file_size = None | |||
| Validator.check_value_type(arg_name='raise_exception', arg_value=raise_exception, valid_types=bool) | |||
| self.prefix = file_prefix | |||
| self.suffix = file_suffix | |||
| self.network = network | |||
| @@ -127,16 +128,15 @@ class SummaryRecord: | |||
| # create the summary writer file | |||
| self.event_file_name = get_event_file_name(self.prefix, self.suffix) | |||
| try: | |||
| self.full_file_name = os.path.join(self.log_path, self.event_file_name) | |||
| except Exception as ex: | |||
| raise RuntimeError(ex) | |||
| self.full_file_name = os.path.join(self.log_path, self.event_file_name) | |||
| filename_dict = dict(summary=self.full_file_name, | |||
| lineage=get_event_file_name(self.prefix, '_lineage'), | |||
| explainer=get_event_file_name(self.prefix, '_explain')) | |||
| self._event_writer = WriterPool(log_dir, | |||
| max_file_size, | |||
| summary=self.full_file_name, | |||
| lineage=get_event_file_name(self.prefix, '_lineage'), | |||
| explainer=get_event_file_name(self.prefix, '_explain')) | |||
| raise_exception, | |||
| **filename_dict) | |||
| _get_summary_tensor_data() | |||
| atexit.register(self.close) | |||
| @@ -195,8 +195,8 @@ class SummaryRecord: | |||
| - The data type of value should be a 'Explain' object when the plugin is 'explainer', | |||
| see mindspore/ccsrc/summary.proto. | |||
| Raises: | |||
| ValueError: When the name is not valid. | |||
| TypeError: When the value is not a Tensor. | |||
| ValueError: If the parameter value is invalid. | |||
| TypeError: If the parameter type is error. | |||
| Examples: | |||
| >>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | |||
| @@ -238,6 +238,10 @@ class SummaryRecord: | |||
| Returns: | |||
| bool, whether the record process is successful or not. | |||
| Raises: | |||
| TypeError: If the parameter type is error. | |||
| RuntimeError: If the disk space is insufficient. | |||
| Examples: | |||
| >>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | |||
| ... summary_record.record(step=2) | |||
| @@ -245,11 +249,12 @@ class SummaryRecord: | |||
| True | |||
| """ | |||
| logger.debug("SummaryRecord step is %r.", step) | |||
| Validator.check_value_type(arg_name='step', arg_value=step, valid_types=int) | |||
| Validator.check_value_type(arg_name='train_network', arg_value=train_network, valid_types=[Cell, type(None)]) | |||
| if self._closed: | |||
| logger.error("The record writer is closed.") | |||
| return False | |||
| if not isinstance(step, int) or isinstance(step, bool): | |||
| raise ValueError("`step` should be int") | |||
| # Set the current summary of train step | |||
| if self.network is not None and not self.has_graph: | |||
| graph_proto = self.network.get_func_graph_proto() | |||
| @@ -294,7 +299,7 @@ class SummaryRecord: | |||
| value['step'] = step | |||
| return self._data_pool | |||
| finally: | |||
| self._data_pool = _dictlist() | |||
| self._data_pool = defaultdict(list) | |||
| @property | |||
| def log_dir(self): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -60,8 +60,8 @@ class BaseWriter: | |||
| self._max_file_size -= required_length | |||
| self.writer.Write(data) | |||
| else: | |||
| raise RuntimeError(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, " | |||
| f"but the '{self._filepath}' requires to write {required_length} bytes.") | |||
| raise RuntimeWarning(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, " | |||
| f"but the '{self._filepath}' requires to write {required_length} bytes.") | |||
| def flush(self): | |||
| """Flush the writer.""" | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -12,12 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| @File : test_image_summary.py | |||
| @Author: | |||
| @Date : 2019-07-4 | |||
| @Desc : test summary function | |||
| """ | |||
| """test_image_summary""" | |||
| import logging | |||
| import os | |||
| import numpy as np | |||
| @@ -70,23 +65,14 @@ def get_test_data(step): | |||
| # Test: call method on parse graph code | |||
| def test_image_summary_sample(): | |||
| """ test_image_summary_sample """ | |||
| log.debug("begin test_image_summary_sample") | |||
| # step 0: create the thread | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | |||
| # step 1: create the test data for summary | |||
| # step 2: create the Event | |||
| for i in range(1, 5): | |||
| test_data = get_test_data(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| test_writer.flush() | |||
| # step 3: send the event to mq | |||
| # step 4: accept the event and write the file | |||
| log.debug("finished test_image_summary_sample") | |||
| class Net(nn.Cell): | |||
| """ Net definition """ | |||
| @@ -175,23 +161,11 @@ class ImageSummaryCallback(Callback): | |||
| def test_image_summary_train(): | |||
| """ test_image_summary_train """ | |||
| dataset = get_dataset() | |||
| log.debug("begin test_image_summary_sample") | |||
| # step 0: create the thread | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | |||
| # step 1: create the test data for summary | |||
| # step 2: create the Event | |||
| model = get_model() | |||
| callback = ImageSummaryCallback(test_writer) | |||
| model.train(2, dataset, callbacks=[callback]) | |||
| # step 3: send the event to mq | |||
| # step 4: accept the event and write the file | |||
| log.debug("finished test_image_summary_sample") | |||
| def test_image_summary_data(): | |||
| """ test_image_summary_data """ | |||
| @@ -207,13 +181,6 @@ def test_image_summary_data(): | |||
| test_data_list.append(dct) | |||
| i += 1 | |||
| log.debug("begin test_image_summary_sample") | |||
| # step 0: create the thread | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | |||
| # step 1: create the test data for summary | |||
| # step 2: create the Event | |||
| _cache_summary_tensor_data(test_data_list) | |||
| test_writer.record(1) | |||
| log.debug("finished test_image_summary_sample") | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -12,17 +12,12 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| @File : test_summary.py | |||
| @Author: | |||
| @Date : 2019-07-4 | |||
| @Desc : test summary function | |||
| """ | |||
| import logging | |||
| """Test summary.""" | |||
| import os | |||
| import random | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore.common.tensor import Tensor | |||
| @@ -32,9 +27,6 @@ from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary | |||
| CUR_DIR = os.getcwd() | |||
| SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" | |||
| log = logging.getLogger("test") | |||
| log.setLevel(level=logging.ERROR) | |||
| def get_test_data(step): | |||
| """ get_test_data """ | |||
| @@ -58,26 +50,14 @@ def get_test_data(step): | |||
| return test_data_list | |||
| # Test 1: summary sample of scalar | |||
| def test_scalar_summary_sample(): | |||
| """ test_scalar_summary_sample """ | |||
| log.debug("begin test_scalar_summary_sample") | |||
| # step 0: create the thread | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | |||
| # step 1: create the test data for summary | |||
| # step 2: create the Event | |||
| for i in range(1, 500): | |||
| for i in range(1, 5): | |||
| test_data = get_test_data(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| # step 3: send the event to mq | |||
| # step 4: accept the event and write the file | |||
| log.debug("finished test_scalar_summary_sample") | |||
| def get_test_data_shape_1(step): | |||
| """ get_test_data_shape_1 """ | |||
| @@ -104,23 +84,12 @@ def get_test_data_shape_1(step): | |||
| # Test: shape = (1,) | |||
| def test_scalar_summary_sample_with_shape_1(): | |||
| """ test_scalar_summary_sample_with_shape_1 """ | |||
| log.debug("begin test_scalar_summary_sample_with_shape_1") | |||
| # step 0: create the thread | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | |||
| # step 1: create the test data for summary | |||
| # step 2: create the Event | |||
| for i in range(1, 100): | |||
| test_data = get_test_data_shape_1(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| # step 3: send the event to mq | |||
| # step 4: accept the event and write the file | |||
| log.debug("finished test_scalar_summary_sample") | |||
| # Test: test with ge | |||
| class SummaryDemo(nn.Cell): | |||
| @@ -143,13 +112,7 @@ class SummaryDemo(nn.Cell): | |||
| def test_scalar_summary_with_ge(): | |||
| """ test_scalar_summary_with_ge """ | |||
| log.debug("begin test_scalar_summary_with_ge") | |||
| # step 0: create the thread | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net = SummaryDemo() | |||
| net.set_train() | |||
| @@ -161,45 +124,17 @@ def test_scalar_summary_with_ge(): | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| log.debug("finished test_scalar_summary_with_ge") | |||
| # test the problem of two consecutive use cases going wrong | |||
| def test_scalar_summary_with_ge_2(): | |||
| """ test_scalar_summary_with_ge_2 """ | |||
| log.debug("begin test_scalar_summary_with_ge_2") | |||
| # step 0: create the thread | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net = SummaryDemo() | |||
| net.set_train() | |||
| # step 2: create the Event | |||
| steps = 100 | |||
| for i in range(1, steps): | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| log.debug("finished test_scalar_summary_with_ge_2") | |||
| def test_validate(): | |||
| with SummaryRecord(SUMMARY_DIR) as sr: | |||
| sr.record(1) | |||
| with pytest.raises(ValueError): | |||
| sr.record(False) | |||
| with pytest.raises(ValueError): | |||
| sr.record(2.0) | |||
| with pytest.raises(ValueError): | |||
| sr.record((1, 3)) | |||
| with pytest.raises(ValueError): | |||
| sr.record([2, 3]) | |||
| with pytest.raises(ValueError): | |||
| sr.record("str") | |||
| with pytest.raises(ValueError): | |||
| sr.record(sr) | |||
| @@ -1,133 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| @File : test_summary_abnormal_input.py | |||
| @Author: | |||
| @Date : 2019-08-5 | |||
| @Desc : test summary function of abnormal input | |||
| """ | |||
| import logging | |||
| import os | |||
| import numpy as np | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.train.summary.summary_record import SummaryRecord | |||
| CUR_DIR = os.getcwd() | |||
| SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" | |||
| log = logging.getLogger("test") | |||
| log.setLevel(level=logging.ERROR) | |||
| def get_test_data(step): | |||
| """ get_test_data """ | |||
| test_data_list = [] | |||
| tag1 = "x1[:Scalar]" | |||
| tag2 = "x2[:Scalar]" | |||
| np1 = np.array(step + 1).astype(np.float32) | |||
| np2 = np.array(step + 2).astype(np.float32) | |||
| dict1 = {} | |||
| dict1["name"] = tag1 | |||
| dict1["data"] = Tensor(np1) | |||
| dict2 = {} | |||
| dict2["name"] = tag2 | |||
| dict2["data"] = Tensor(np2) | |||
| test_data_list.append(dict1) | |||
| test_data_list.append(dict2) | |||
| return test_data_list | |||
| # Test: call method on parse graph code | |||
| def test_summaryrecord_input_null_string(): | |||
| log.debug("begin test_summaryrecord_input_null_string") | |||
| # step 0: create the thread | |||
| try: | |||
| with SummaryRecord(""): | |||
| pass | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("finished test_summaryrecord_input_null_string") | |||
| def test_summaryrecord_input_None(): | |||
| log.debug("begin test_summaryrecord_input_None") | |||
| # step 0: create the thread | |||
| try: | |||
| with SummaryRecord(None): | |||
| pass | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("finished test_summaryrecord_input_None") | |||
| def test_summaryrecord_input_relative_dir_1(): | |||
| log.debug("begin test_summaryrecord_input_relative_dir_1") | |||
| # step 0: create the thread | |||
| try: | |||
| with SummaryRecord("./test_temp_summary_event_file/"): | |||
| pass | |||
| except: | |||
| assert False | |||
| else: | |||
| assert True | |||
| log.debug("finished test_summaryrecord_input_relative_dir_1") | |||
| def test_summaryrecord_input_relative_dir_2(): | |||
| log.debug("begin test_summaryrecord_input_relative_dir_2") | |||
| # step 0: create the thread | |||
| try: | |||
| with SummaryRecord("../summary/"): | |||
| pass | |||
| except: | |||
| assert False | |||
| else: | |||
| assert True | |||
| log.debug("finished test_summaryrecord_input_relative_dir_2") | |||
| def test_summaryrecord_input_invalid_type_dir(): | |||
| log.debug("begin test_summaryrecord_input_invalid_type_dir") | |||
| # step 0: create the thread | |||
| try: | |||
| with SummaryRecord(32): | |||
| pass | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("finished test_summaryrecord_input_invalid_type_dir") | |||
| def test_mulit_layer_directory(): | |||
| log.debug("begin test_mulit_layer_directory") | |||
| # step 0: create the thread | |||
| try: | |||
| with SummaryRecord("./test_temp_summary_event_file/test/t1/"): | |||
| pass | |||
| except: | |||
| assert False | |||
| else: | |||
| assert True | |||
| log.debug("finished test_mulit_layer_directory") | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -187,6 +187,14 @@ class TestSummaryCollector: | |||
| assert expected_msg == str(exc.value) | |||
| def test_params_with_histogram_regular_value_error(self): | |||
| """Test histogram regular.""" | |||
| summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | |||
| with pytest.raises(ValueError) as exc: | |||
| SummaryCollector(summary_dir, collect_specified_data={'histogram_regular': '*'}) | |||
| assert 'For `collect_specified_data`, the value of `histogram_regular`' in str(exc.value) | |||
| def test_params_with_collect_specified_data_unexpected_key(self): | |||
| """Test the collect_specified_data parameter with unexpected key.""" | |||
| summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | |||
| @@ -260,7 +268,7 @@ class TestSummaryCollector: | |||
| cb_params.train_dataset_element = image_data | |||
| with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector: | |||
| summary_collector._collect_input_data(cb_params) | |||
| # Note Here need to asssert the result and expected data | |||
| # Note Here need to assert the result and expected data | |||
| @mock.patch.object(SummaryRecord, 'add_value') | |||
| def test_collect_dataset_graph_success(self, mock_add_value): | |||
| @@ -296,7 +304,6 @@ class TestSummaryCollector: | |||
| assert summary_collector._is_parse_loss_success | |||
| def test_get_optimizer_from_cb_params_success(self): | |||
| """Test get optimizer success from cb params.""" | |||
| cb_params = _InternalCallbackParam() | |||
| @@ -0,0 +1,80 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test_summary_abnormal_input""" | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.train.summary.summary_record import SummaryRecord | |||
| def get_test_data(step): | |||
| """ get_test_data """ | |||
| test_data_list = [] | |||
| tag1 = "x1[:Scalar]" | |||
| tag2 = "x2[:Scalar]" | |||
| np1 = np.array(step + 1).astype(np.float32) | |||
| np2 = np.array(step + 2).astype(np.float32) | |||
| dict1 = {} | |||
| dict1["name"] = tag1 | |||
| dict1["data"] = Tensor(np1) | |||
| dict2 = {} | |||
| dict2["name"] = tag2 | |||
| dict2["data"] = Tensor(np2) | |||
| test_data_list.append(dict1) | |||
| test_data_list.append(dict2) | |||
| return test_data_list | |||
| class TestSummaryRecord: | |||
| """Test SummaryRecord""" | |||
| def setup_class(self): | |||
| """Run before test this class.""" | |||
| self.base_summary_dir = tempfile.mkdtemp(suffix='summary') | |||
| def teardown_class(self): | |||
| """Run after test this class.""" | |||
| if os.path.exists(self.base_summary_dir): | |||
| shutil.rmtree(self.base_summary_dir) | |||
| @pytest.mark.parametrize("log_dir", ["", None, 32]) | |||
| def test_log_dir_with_type_error(self, log_dir): | |||
| with pytest.raises(TypeError): | |||
| with SummaryRecord(log_dir): | |||
| pass | |||
| @pytest.mark.parametrize("raise_exception", ["", None, 32]) | |||
| def test_raise_exception_with_type_error(self, raise_exception): | |||
| summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | |||
| with pytest.raises(TypeError) as exc: | |||
| with SummaryRecord(log_dir=summary_dir, raise_exception=raise_exception): | |||
| pass | |||
| assert "raise_exception" in str(exc.value) | |||
| @pytest.mark.parametrize("step", [False, 2.0, (1, 3), [2, 3], "str"]) | |||
| def test_step_of_record_with_type_error(self, step): | |||
| summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | |||
| with pytest.raises(TypeError): | |||
| with SummaryRecord(summary_dir) as sr: | |||
| sr.record(step) | |||