diff --git a/mindspore/explainer/_image_classification_runner.py b/mindspore/explainer/_image_classification_runner.py index ba398ff65d..d3a173e34d 100644 --- a/mindspore/explainer/_image_classification_runner.py +++ b/mindspore/explainer/_image_classification_runner.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -199,7 +199,7 @@ class ImageClassificationRunner: """ self._verify_data_n_settings(check_all=True) - with SummaryRecord(self._summary_dir) as summary: + with SummaryRecord(self._summary_dir, raise_exception=True) as summary: print("Start running and writing......") begin = time() diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index 93c79bef73..84f35398f2 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -207,7 +207,9 @@ class SummaryCollector(Callback): self._dataset_sink_mode = True def __enter__(self): - self._record = SummaryRecord(log_dir=self._summary_dir, max_file_size=self._max_file_size) + self._record = SummaryRecord(log_dir=self._summary_dir, + max_file_size=self._max_file_size, + raise_exception=False) self._first_step, self._dataset_sink_mode = True, True return self @@ -319,7 +321,14 @@ class SummaryCollector(Callback): f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}') if 'histogram_regular' in specified_data: - check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None))) + regular = specified_data.get('histogram_regular') + check_value_type('histogram_regular', regular, (str, type(None))) + if isinstance(regular, str): + try: + re.match(regular, '') + except re.error as exc: + raise ValueError(f'For `collect_specified_data`, the value of `histogram_regular` ' + f'is not a valid regular expression. Detail: {str(exc)}.') bool_items = set(self._DEFAULT_SPECIFIED_DATA) - {'histogram_regular'} for item in bool_items: diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py index a583b17083..4fbfa8bcfa 100644 --- a/mindspore/train/summary/_writer_pool.py +++ b/mindspore/train/summary/_writer_pool.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -57,14 +57,17 @@ class WriterPool(ctx.Process): Args: base_dir (str): The base directory to hold all the files. max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes. + raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs + in recording data. Default: False, this means that error logs are printed and no exception is thrown. filedict (dict): The mapping from plugin to filename. """ - def __init__(self, base_dir, max_file_size, **filedict) -> None: + def __init__(self, base_dir, max_file_size, raise_exception=False, **filedict) -> None: super().__init__() self._base_dir, self._filedict = base_dir, filedict self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None self._max_file_size = max_file_size + self._raise_exception = raise_exception self.start() def run(self): @@ -119,8 +122,14 @@ class WriterPool(ctx.Process): for writer in self._writers[:]: try: writer.write(plugin, data) - except RuntimeError as e: - logger.warning(e.args[0]) + except RuntimeError as exc: + logger.error(str(exc)) + self._writers.remove(writer) + writer.close() + if self._raise_exception: + raise + except RuntimeWarning as exc: + logger.warning(str(exc)) self._writers.remove(writer) writer.close() diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index f7db1d857f..4c0e199723 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import threading from collections import defaultdict from mindspore import log as logger +from mindspore.nn import Cell from ..._c_expression import Tensor from ..._checkparam import Validator @@ -29,7 +30,7 @@ from ._explain_adapter import check_explain_proto from ._writer_pool import WriterPool # for the moment, this lock is for caution's sake, -# there are actually no any concurrencies happening. +# there are actually no any concurrences happening. _summary_lock = threading.Lock() # cache the summary data _summary_tensor_cache = {} @@ -56,10 +57,6 @@ def _get_summary_tensor_data(): return data -def _dictlist(): - return defaultdict(list) - - class SummaryRecord: """ SummaryRecord is used to record the summary data and lineage data. @@ -80,12 +77,13 @@ class SummaryRecord: file_prefix (str): The prefix of file. Default: "events". file_suffix (str): The suffix of file. Default: "_MS". network (Cell): Obtain a pipeline through network for saving graph summary. Default: None. - max_file_size (Optional[int]): The maximum size of each file that can be written to disk (in bytes). \ + max_file_size (int, optional): The maximum size of each file that can be written to disk (in bytes). \ Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`. + raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs + in recording data. Default: False, this means that error logs are printed and no exception is thrown. Raises: - TypeError: If the type of `max_file_size` is not int, or the type of `file_prefix` or `file_suffix` is not str. - RuntimeError: If the log_dir is not a normalized absolute path name. + TypeError: If the parameter type is incorrect. Examples: >>> # use in with statement to auto close @@ -100,10 +98,11 @@ class SummaryRecord: ... summary_record.close() """ - def __init__(self, log_dir, file_prefix="events", file_suffix="_MS", network=None, max_file_size=None): + def __init__(self, log_dir, file_prefix="events", file_suffix="_MS", + network=None, max_file_size=None, raise_exception=False): self._closed, self._event_writer = False, None - self._mode, self._data_pool = 'train', _dictlist() + self._mode, self._data_pool = 'train', defaultdict(list) Validator.check_str_by_regular(file_prefix) Validator.check_str_by_regular(file_suffix) @@ -120,6 +119,8 @@ class SummaryRecord: logger.warning("The 'max_file_size' should be greater than 0.") max_file_size = None + Validator.check_value_type(arg_name='raise_exception', arg_value=raise_exception, valid_types=bool) + self.prefix = file_prefix self.suffix = file_suffix self.network = network @@ -127,16 +128,15 @@ class SummaryRecord: # create the summary writer file self.event_file_name = get_event_file_name(self.prefix, self.suffix) - try: - self.full_file_name = os.path.join(self.log_path, self.event_file_name) - except Exception as ex: - raise RuntimeError(ex) + self.full_file_name = os.path.join(self.log_path, self.event_file_name) + filename_dict = dict(summary=self.full_file_name, + lineage=get_event_file_name(self.prefix, '_lineage'), + explainer=get_event_file_name(self.prefix, '_explain')) self._event_writer = WriterPool(log_dir, max_file_size, - summary=self.full_file_name, - lineage=get_event_file_name(self.prefix, '_lineage'), - explainer=get_event_file_name(self.prefix, '_explain')) + raise_exception, + **filename_dict) _get_summary_tensor_data() atexit.register(self.close) @@ -195,8 +195,8 @@ class SummaryRecord: - The data type of value should be a 'Explain' object when the plugin is 'explainer', see mindspore/ccsrc/summary.proto. Raises: - ValueError: When the name is not valid. - TypeError: When the value is not a Tensor. + ValueError: If the parameter value is invalid. + TypeError: If the parameter type is error. Examples: >>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record: @@ -238,6 +238,10 @@ class SummaryRecord: Returns: bool, whether the record process is successful or not. + Raises: + TypeError: If the parameter type is error. + RuntimeError: If the disk space is insufficient. + Examples: >>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record: ... summary_record.record(step=2) @@ -245,11 +249,12 @@ class SummaryRecord: True """ logger.debug("SummaryRecord step is %r.", step) + Validator.check_value_type(arg_name='step', arg_value=step, valid_types=int) + Validator.check_value_type(arg_name='train_network', arg_value=train_network, valid_types=[Cell, type(None)]) + if self._closed: logger.error("The record writer is closed.") return False - if not isinstance(step, int) or isinstance(step, bool): - raise ValueError("`step` should be int") # Set the current summary of train step if self.network is not None and not self.has_graph: graph_proto = self.network.get_func_graph_proto() @@ -294,7 +299,7 @@ class SummaryRecord: value['step'] = step return self._data_pool finally: - self._data_pool = _dictlist() + self._data_pool = defaultdict(list) @property def log_dir(self): diff --git a/mindspore/train/summary/writer.py b/mindspore/train/summary/writer.py index 1a8e424473..b74ac7587b 100644 --- a/mindspore/train/summary/writer.py +++ b/mindspore/train/summary/writer.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -60,8 +60,8 @@ class BaseWriter: self._max_file_size -= required_length self.writer.Write(data) else: - raise RuntimeError(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, " - f"but the '{self._filepath}' requires to write {required_length} bytes.") + raise RuntimeWarning(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, " + f"but the '{self._filepath}' requires to write {required_length} bytes.") def flush(self): """Flush the writer.""" diff --git a/tests/ut/python/train/summary/test_image_summary.py b/tests/ut/python/train/summary/test_image_summary.py index addeaec212..6801f3f1b4 100644 --- a/tests/ut/python/train/summary/test_image_summary.py +++ b/tests/ut/python/train/summary/test_image_summary.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -""" -@File : test_image_summary.py -@Author: -@Date : 2019-07-4 -@Desc : test summary function -""" +"""test_image_summary""" import logging import os import numpy as np @@ -70,23 +65,14 @@ def get_test_data(step): # Test: call method on parse graph code def test_image_summary_sample(): """ test_image_summary_sample """ - log.debug("begin test_image_summary_sample") - # step 0: create the thread with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: - # step 1: create the test data for summary - # step 2: create the Event for i in range(1, 5): test_data = get_test_data(i) _cache_summary_tensor_data(test_data) test_writer.record(i) test_writer.flush() - # step 3: send the event to mq - - # step 4: accept the event and write the file - log.debug("finished test_image_summary_sample") - class Net(nn.Cell): """ Net definition """ @@ -175,23 +161,11 @@ class ImageSummaryCallback(Callback): def test_image_summary_train(): """ test_image_summary_train """ dataset = get_dataset() - - log.debug("begin test_image_summary_sample") - # step 0: create the thread with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: - # step 1: create the test data for summary - - # step 2: create the Event - model = get_model() callback = ImageSummaryCallback(test_writer) model.train(2, dataset, callbacks=[callback]) - # step 3: send the event to mq - - # step 4: accept the event and write the file - log.debug("finished test_image_summary_sample") - def test_image_summary_data(): """ test_image_summary_data """ @@ -207,13 +181,6 @@ def test_image_summary_data(): test_data_list.append(dct) i += 1 - log.debug("begin test_image_summary_sample") - # step 0: create the thread with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: - # step 1: create the test data for summary - - # step 2: create the Event _cache_summary_tensor_data(test_data_list) test_writer.record(1) - - log.debug("finished test_image_summary_sample") diff --git a/tests/ut/python/train/summary/test_summary.py b/tests/ut/python/train/summary/test_summary.py index b069d89954..cf60780413 100644 --- a/tests/ut/python/train/summary/test_summary.py +++ b/tests/ut/python/train/summary/test_summary.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,17 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -""" -@File : test_summary.py -@Author: -@Date : 2019-07-4 -@Desc : test summary function -""" -import logging +"""Test summary.""" import os import random + import numpy as np -import pytest + import mindspore.nn as nn from mindspore.common.tensor import Tensor @@ -32,9 +27,6 @@ from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary CUR_DIR = os.getcwd() SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" -log = logging.getLogger("test") -log.setLevel(level=logging.ERROR) - def get_test_data(step): """ get_test_data """ @@ -58,26 +50,14 @@ def get_test_data(step): return test_data_list -# Test 1: summary sample of scalar def test_scalar_summary_sample(): """ test_scalar_summary_sample """ - log.debug("begin test_scalar_summary_sample") - # step 0: create the thread with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: - # step 1: create the test data for summary - - # step 2: create the Event - for i in range(1, 500): + for i in range(1, 5): test_data = get_test_data(i) _cache_summary_tensor_data(test_data) test_writer.record(i) - # step 3: send the event to mq - - # step 4: accept the event and write the file - - log.debug("finished test_scalar_summary_sample") - def get_test_data_shape_1(step): """ get_test_data_shape_1 """ @@ -104,23 +84,12 @@ def get_test_data_shape_1(step): # Test: shape = (1,) def test_scalar_summary_sample_with_shape_1(): """ test_scalar_summary_sample_with_shape_1 """ - log.debug("begin test_scalar_summary_sample_with_shape_1") - # step 0: create the thread with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: - # step 1: create the test data for summary - - # step 2: create the Event for i in range(1, 100): test_data = get_test_data_shape_1(i) _cache_summary_tensor_data(test_data) test_writer.record(i) - # step 3: send the event to mq - - # step 4: accept the event and write the file - - log.debug("finished test_scalar_summary_sample") - # Test: test with ge class SummaryDemo(nn.Cell): @@ -143,13 +112,7 @@ class SummaryDemo(nn.Cell): def test_scalar_summary_with_ge(): """ test_scalar_summary_with_ge """ - log.debug("begin test_scalar_summary_with_ge") - - # step 0: create the thread with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: - # step 1: create the network for summary - x = Tensor(np.array([1.1]).astype(np.float32)) - y = Tensor(np.array([1.2]).astype(np.float32)) net = SummaryDemo() net.set_train() @@ -161,45 +124,17 @@ def test_scalar_summary_with_ge(): net(x, y) test_writer.record(i) - log.debug("finished test_scalar_summary_with_ge") - # test the problem of two consecutive use cases going wrong def test_scalar_summary_with_ge_2(): """ test_scalar_summary_with_ge_2 """ - log.debug("begin test_scalar_summary_with_ge_2") - - # step 0: create the thread with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: - # step 1: create the network for summary - x = Tensor(np.array([1.1]).astype(np.float32)) - y = Tensor(np.array([1.2]).astype(np.float32)) net = SummaryDemo() net.set_train() - # step 2: create the Event steps = 100 for i in range(1, steps): x = Tensor(np.array([1.1]).astype(np.float32)) y = Tensor(np.array([1.2]).astype(np.float32)) net(x, y) test_writer.record(i) - - log.debug("finished test_scalar_summary_with_ge_2") - - -def test_validate(): - with SummaryRecord(SUMMARY_DIR) as sr: - sr.record(1) - with pytest.raises(ValueError): - sr.record(False) - with pytest.raises(ValueError): - sr.record(2.0) - with pytest.raises(ValueError): - sr.record((1, 3)) - with pytest.raises(ValueError): - sr.record([2, 3]) - with pytest.raises(ValueError): - sr.record("str") - with pytest.raises(ValueError): - sr.record(sr) diff --git a/tests/ut/python/train/summary/test_summary_abnormal_input.py b/tests/ut/python/train/summary/test_summary_abnormal_input.py deleted file mode 100644 index 388952feca..0000000000 --- a/tests/ut/python/train/summary/test_summary_abnormal_input.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -@File : test_summary_abnormal_input.py -@Author: -@Date : 2019-08-5 -@Desc : test summary function of abnormal input -""" -import logging -import os -import numpy as np - -from mindspore.common.tensor import Tensor -from mindspore.train.summary.summary_record import SummaryRecord - -CUR_DIR = os.getcwd() -SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" - -log = logging.getLogger("test") -log.setLevel(level=logging.ERROR) - - -def get_test_data(step): - """ get_test_data """ - test_data_list = [] - tag1 = "x1[:Scalar]" - tag2 = "x2[:Scalar]" - np1 = np.array(step + 1).astype(np.float32) - np2 = np.array(step + 2).astype(np.float32) - - dict1 = {} - dict1["name"] = tag1 - dict1["data"] = Tensor(np1) - - dict2 = {} - dict2["name"] = tag2 - dict2["data"] = Tensor(np2) - - test_data_list.append(dict1) - test_data_list.append(dict2) - - return test_data_list - - -# Test: call method on parse graph code -def test_summaryrecord_input_null_string(): - log.debug("begin test_summaryrecord_input_null_string") - # step 0: create the thread - try: - with SummaryRecord(""): - pass - except: - assert True - else: - assert False - log.debug("finished test_summaryrecord_input_null_string") - - -def test_summaryrecord_input_None(): - log.debug("begin test_summaryrecord_input_None") - # step 0: create the thread - try: - with SummaryRecord(None): - pass - except: - assert True - else: - assert False - log.debug("finished test_summaryrecord_input_None") - - -def test_summaryrecord_input_relative_dir_1(): - log.debug("begin test_summaryrecord_input_relative_dir_1") - # step 0: create the thread - try: - with SummaryRecord("./test_temp_summary_event_file/"): - pass - except: - assert False - else: - assert True - log.debug("finished test_summaryrecord_input_relative_dir_1") - - -def test_summaryrecord_input_relative_dir_2(): - log.debug("begin test_summaryrecord_input_relative_dir_2") - # step 0: create the thread - try: - with SummaryRecord("../summary/"): - pass - except: - assert False - else: - assert True - log.debug("finished test_summaryrecord_input_relative_dir_2") - - -def test_summaryrecord_input_invalid_type_dir(): - log.debug("begin test_summaryrecord_input_invalid_type_dir") - # step 0: create the thread - try: - with SummaryRecord(32): - pass - except: - assert True - else: - assert False - log.debug("finished test_summaryrecord_input_invalid_type_dir") - - -def test_mulit_layer_directory(): - log.debug("begin test_mulit_layer_directory") - # step 0: create the thread - try: - with SummaryRecord("./test_temp_summary_event_file/test/t1/"): - pass - except: - assert False - else: - assert True - log.debug("finished test_mulit_layer_directory") diff --git a/tests/ut/python/train/summary/test_summary_collector.py b/tests/ut/python/train/summary/test_summary_collector.py index 3349cf8287..36bde6f88f 100644 --- a/tests/ut/python/train/summary/test_summary_collector.py +++ b/tests/ut/python/train/summary/test_summary_collector.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -187,6 +187,14 @@ class TestSummaryCollector: assert expected_msg == str(exc.value) + def test_params_with_histogram_regular_value_error(self): + """Test histogram regular.""" + summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) + with pytest.raises(ValueError) as exc: + SummaryCollector(summary_dir, collect_specified_data={'histogram_regular': '*'}) + + assert 'For `collect_specified_data`, the value of `histogram_regular`' in str(exc.value) + def test_params_with_collect_specified_data_unexpected_key(self): """Test the collect_specified_data parameter with unexpected key.""" summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) @@ -260,7 +268,7 @@ class TestSummaryCollector: cb_params.train_dataset_element = image_data with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector: summary_collector._collect_input_data(cb_params) - # Note Here need to asssert the result and expected data + # Note Here need to assert the result and expected data @mock.patch.object(SummaryRecord, 'add_value') def test_collect_dataset_graph_success(self, mock_add_value): @@ -296,7 +304,6 @@ class TestSummaryCollector: assert summary_collector._is_parse_loss_success - def test_get_optimizer_from_cb_params_success(self): """Test get optimizer success from cb params.""" cb_params = _InternalCallbackParam() diff --git a/tests/ut/python/train/summary/test_summary_record.py b/tests/ut/python/train/summary/test_summary_record.py new file mode 100644 index 0000000000..4f317174b8 --- /dev/null +++ b/tests/ut/python/train/summary/test_summary_record.py @@ -0,0 +1,80 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test_summary_abnormal_input""" +import os +import shutil +import tempfile + +import numpy as np +import pytest + +from mindspore.common.tensor import Tensor +from mindspore.train.summary.summary_record import SummaryRecord + + +def get_test_data(step): + """ get_test_data """ + test_data_list = [] + tag1 = "x1[:Scalar]" + tag2 = "x2[:Scalar]" + np1 = np.array(step + 1).astype(np.float32) + np2 = np.array(step + 2).astype(np.float32) + + dict1 = {} + dict1["name"] = tag1 + dict1["data"] = Tensor(np1) + + dict2 = {} + dict2["name"] = tag2 + dict2["data"] = Tensor(np2) + + test_data_list.append(dict1) + test_data_list.append(dict2) + + return test_data_list + + +class TestSummaryRecord: + """Test SummaryRecord""" + def setup_class(self): + """Run before test this class.""" + self.base_summary_dir = tempfile.mkdtemp(suffix='summary') + + def teardown_class(self): + """Run after test this class.""" + if os.path.exists(self.base_summary_dir): + shutil.rmtree(self.base_summary_dir) + + @pytest.mark.parametrize("log_dir", ["", None, 32]) + def test_log_dir_with_type_error(self, log_dir): + with pytest.raises(TypeError): + with SummaryRecord(log_dir): + pass + + @pytest.mark.parametrize("raise_exception", ["", None, 32]) + def test_raise_exception_with_type_error(self, raise_exception): + summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) + with pytest.raises(TypeError) as exc: + with SummaryRecord(log_dir=summary_dir, raise_exception=raise_exception): + pass + + assert "raise_exception" in str(exc.value) + + @pytest.mark.parametrize("step", [False, 2.0, (1, 3), [2, 3], "str"]) + def test_step_of_record_with_type_error(self, step): + summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) + with pytest.raises(TypeError): + with SummaryRecord(summary_dir) as sr: + sr.record(step)