1. Support explainer raise an RuntimeError exception 2. fix the ut of SummaryRecordtags/v1.2.0-rc1
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -199,7 +199,7 @@ class ImageClassificationRunner: | |||||
| """ | """ | ||||
| self._verify_data_n_settings(check_all=True) | self._verify_data_n_settings(check_all=True) | ||||
| with SummaryRecord(self._summary_dir) as summary: | |||||
| with SummaryRecord(self._summary_dir, raise_exception=True) as summary: | |||||
| print("Start running and writing......") | print("Start running and writing......") | ||||
| begin = time() | begin = time() | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -207,7 +207,9 @@ class SummaryCollector(Callback): | |||||
| self._dataset_sink_mode = True | self._dataset_sink_mode = True | ||||
| def __enter__(self): | def __enter__(self): | ||||
| self._record = SummaryRecord(log_dir=self._summary_dir, max_file_size=self._max_file_size) | |||||
| self._record = SummaryRecord(log_dir=self._summary_dir, | |||||
| max_file_size=self._max_file_size, | |||||
| raise_exception=False) | |||||
| self._first_step, self._dataset_sink_mode = True, True | self._first_step, self._dataset_sink_mode = True, True | ||||
| return self | return self | ||||
| @@ -319,7 +321,14 @@ class SummaryCollector(Callback): | |||||
| f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}') | f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}') | ||||
| if 'histogram_regular' in specified_data: | if 'histogram_regular' in specified_data: | ||||
| check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None))) | |||||
| regular = specified_data.get('histogram_regular') | |||||
| check_value_type('histogram_regular', regular, (str, type(None))) | |||||
| if isinstance(regular, str): | |||||
| try: | |||||
| re.match(regular, '') | |||||
| except re.error as exc: | |||||
| raise ValueError(f'For `collect_specified_data`, the value of `histogram_regular` ' | |||||
| f'is not a valid regular expression. Detail: {str(exc)}.') | |||||
| bool_items = set(self._DEFAULT_SPECIFIED_DATA) - {'histogram_regular'} | bool_items = set(self._DEFAULT_SPECIFIED_DATA) - {'histogram_regular'} | ||||
| for item in bool_items: | for item in bool_items: | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -57,14 +57,17 @@ class WriterPool(ctx.Process): | |||||
| Args: | Args: | ||||
| base_dir (str): The base directory to hold all the files. | base_dir (str): The base directory to hold all the files. | ||||
| max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes. | max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes. | ||||
| raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs | |||||
| in recording data. Default: False, this means that error logs are printed and no exception is thrown. | |||||
| filedict (dict): The mapping from plugin to filename. | filedict (dict): The mapping from plugin to filename. | ||||
| """ | """ | ||||
| def __init__(self, base_dir, max_file_size, **filedict) -> None: | |||||
| def __init__(self, base_dir, max_file_size, raise_exception=False, **filedict) -> None: | |||||
| super().__init__() | super().__init__() | ||||
| self._base_dir, self._filedict = base_dir, filedict | self._base_dir, self._filedict = base_dir, filedict | ||||
| self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None | self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None | ||||
| self._max_file_size = max_file_size | self._max_file_size = max_file_size | ||||
| self._raise_exception = raise_exception | |||||
| self.start() | self.start() | ||||
| def run(self): | def run(self): | ||||
| @@ -119,8 +122,14 @@ class WriterPool(ctx.Process): | |||||
| for writer in self._writers[:]: | for writer in self._writers[:]: | ||||
| try: | try: | ||||
| writer.write(plugin, data) | writer.write(plugin, data) | ||||
| except RuntimeError as e: | |||||
| logger.warning(e.args[0]) | |||||
| except RuntimeError as exc: | |||||
| logger.error(str(exc)) | |||||
| self._writers.remove(writer) | |||||
| writer.close() | |||||
| if self._raise_exception: | |||||
| raise | |||||
| except RuntimeWarning as exc: | |||||
| logger.warning(str(exc)) | |||||
| self._writers.remove(writer) | self._writers.remove(writer) | ||||
| writer.close() | writer.close() | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -20,6 +20,7 @@ import threading | |||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.nn import Cell | |||||
| from ..._c_expression import Tensor | from ..._c_expression import Tensor | ||||
| from ..._checkparam import Validator | from ..._checkparam import Validator | ||||
| @@ -29,7 +30,7 @@ from ._explain_adapter import check_explain_proto | |||||
| from ._writer_pool import WriterPool | from ._writer_pool import WriterPool | ||||
| # for the moment, this lock is for caution's sake, | # for the moment, this lock is for caution's sake, | ||||
| # there are actually no any concurrencies happening. | |||||
| # there are actually no any concurrences happening. | |||||
| _summary_lock = threading.Lock() | _summary_lock = threading.Lock() | ||||
| # cache the summary data | # cache the summary data | ||||
| _summary_tensor_cache = {} | _summary_tensor_cache = {} | ||||
| @@ -56,10 +57,6 @@ def _get_summary_tensor_data(): | |||||
| return data | return data | ||||
| def _dictlist(): | |||||
| return defaultdict(list) | |||||
| class SummaryRecord: | class SummaryRecord: | ||||
| """ | """ | ||||
| SummaryRecord is used to record the summary data and lineage data. | SummaryRecord is used to record the summary data and lineage data. | ||||
| @@ -80,12 +77,13 @@ class SummaryRecord: | |||||
| file_prefix (str): The prefix of file. Default: "events". | file_prefix (str): The prefix of file. Default: "events". | ||||
| file_suffix (str): The suffix of file. Default: "_MS". | file_suffix (str): The suffix of file. Default: "_MS". | ||||
| network (Cell): Obtain a pipeline through network for saving graph summary. Default: None. | network (Cell): Obtain a pipeline through network for saving graph summary. Default: None. | ||||
| max_file_size (Optional[int]): The maximum size of each file that can be written to disk (in bytes). \ | |||||
| max_file_size (int, optional): The maximum size of each file that can be written to disk (in bytes). \ | |||||
| Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`. | Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`. | ||||
| raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs | |||||
| in recording data. Default: False, this means that error logs are printed and no exception is thrown. | |||||
| Raises: | Raises: | ||||
| TypeError: If the type of `max_file_size` is not int, or the type of `file_prefix` or `file_suffix` is not str. | |||||
| RuntimeError: If the log_dir is not a normalized absolute path name. | |||||
| TypeError: If the parameter type is incorrect. | |||||
| Examples: | Examples: | ||||
| >>> # use in with statement to auto close | >>> # use in with statement to auto close | ||||
| @@ -100,10 +98,11 @@ class SummaryRecord: | |||||
| ... summary_record.close() | ... summary_record.close() | ||||
| """ | """ | ||||
| def __init__(self, log_dir, file_prefix="events", file_suffix="_MS", network=None, max_file_size=None): | |||||
| def __init__(self, log_dir, file_prefix="events", file_suffix="_MS", | |||||
| network=None, max_file_size=None, raise_exception=False): | |||||
| self._closed, self._event_writer = False, None | self._closed, self._event_writer = False, None | ||||
| self._mode, self._data_pool = 'train', _dictlist() | |||||
| self._mode, self._data_pool = 'train', defaultdict(list) | |||||
| Validator.check_str_by_regular(file_prefix) | Validator.check_str_by_regular(file_prefix) | ||||
| Validator.check_str_by_regular(file_suffix) | Validator.check_str_by_regular(file_suffix) | ||||
| @@ -120,6 +119,8 @@ class SummaryRecord: | |||||
| logger.warning("The 'max_file_size' should be greater than 0.") | logger.warning("The 'max_file_size' should be greater than 0.") | ||||
| max_file_size = None | max_file_size = None | ||||
| Validator.check_value_type(arg_name='raise_exception', arg_value=raise_exception, valid_types=bool) | |||||
| self.prefix = file_prefix | self.prefix = file_prefix | ||||
| self.suffix = file_suffix | self.suffix = file_suffix | ||||
| self.network = network | self.network = network | ||||
| @@ -127,16 +128,15 @@ class SummaryRecord: | |||||
| # create the summary writer file | # create the summary writer file | ||||
| self.event_file_name = get_event_file_name(self.prefix, self.suffix) | self.event_file_name = get_event_file_name(self.prefix, self.suffix) | ||||
| try: | |||||
| self.full_file_name = os.path.join(self.log_path, self.event_file_name) | |||||
| except Exception as ex: | |||||
| raise RuntimeError(ex) | |||||
| self.full_file_name = os.path.join(self.log_path, self.event_file_name) | |||||
| filename_dict = dict(summary=self.full_file_name, | |||||
| lineage=get_event_file_name(self.prefix, '_lineage'), | |||||
| explainer=get_event_file_name(self.prefix, '_explain')) | |||||
| self._event_writer = WriterPool(log_dir, | self._event_writer = WriterPool(log_dir, | ||||
| max_file_size, | max_file_size, | ||||
| summary=self.full_file_name, | |||||
| lineage=get_event_file_name(self.prefix, '_lineage'), | |||||
| explainer=get_event_file_name(self.prefix, '_explain')) | |||||
| raise_exception, | |||||
| **filename_dict) | |||||
| _get_summary_tensor_data() | _get_summary_tensor_data() | ||||
| atexit.register(self.close) | atexit.register(self.close) | ||||
| @@ -195,8 +195,8 @@ class SummaryRecord: | |||||
| - The data type of value should be a 'Explain' object when the plugin is 'explainer', | - The data type of value should be a 'Explain' object when the plugin is 'explainer', | ||||
| see mindspore/ccsrc/summary.proto. | see mindspore/ccsrc/summary.proto. | ||||
| Raises: | Raises: | ||||
| ValueError: When the name is not valid. | |||||
| TypeError: When the value is not a Tensor. | |||||
| ValueError: If the parameter value is invalid. | |||||
| TypeError: If the parameter type is error. | |||||
| Examples: | Examples: | ||||
| >>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | >>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | ||||
| @@ -238,6 +238,10 @@ class SummaryRecord: | |||||
| Returns: | Returns: | ||||
| bool, whether the record process is successful or not. | bool, whether the record process is successful or not. | ||||
| Raises: | |||||
| TypeError: If the parameter type is error. | |||||
| RuntimeError: If the disk space is insufficient. | |||||
| Examples: | Examples: | ||||
| >>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | >>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | ||||
| ... summary_record.record(step=2) | ... summary_record.record(step=2) | ||||
| @@ -245,11 +249,12 @@ class SummaryRecord: | |||||
| True | True | ||||
| """ | """ | ||||
| logger.debug("SummaryRecord step is %r.", step) | logger.debug("SummaryRecord step is %r.", step) | ||||
| Validator.check_value_type(arg_name='step', arg_value=step, valid_types=int) | |||||
| Validator.check_value_type(arg_name='train_network', arg_value=train_network, valid_types=[Cell, type(None)]) | |||||
| if self._closed: | if self._closed: | ||||
| logger.error("The record writer is closed.") | logger.error("The record writer is closed.") | ||||
| return False | return False | ||||
| if not isinstance(step, int) or isinstance(step, bool): | |||||
| raise ValueError("`step` should be int") | |||||
| # Set the current summary of train step | # Set the current summary of train step | ||||
| if self.network is not None and not self.has_graph: | if self.network is not None and not self.has_graph: | ||||
| graph_proto = self.network.get_func_graph_proto() | graph_proto = self.network.get_func_graph_proto() | ||||
| @@ -294,7 +299,7 @@ class SummaryRecord: | |||||
| value['step'] = step | value['step'] = step | ||||
| return self._data_pool | return self._data_pool | ||||
| finally: | finally: | ||||
| self._data_pool = _dictlist() | |||||
| self._data_pool = defaultdict(list) | |||||
| @property | @property | ||||
| def log_dir(self): | def log_dir(self): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -60,8 +60,8 @@ class BaseWriter: | |||||
| self._max_file_size -= required_length | self._max_file_size -= required_length | ||||
| self.writer.Write(data) | self.writer.Write(data) | ||||
| else: | else: | ||||
| raise RuntimeError(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, " | |||||
| f"but the '{self._filepath}' requires to write {required_length} bytes.") | |||||
| raise RuntimeWarning(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, " | |||||
| f"but the '{self._filepath}' requires to write {required_length} bytes.") | |||||
| def flush(self): | def flush(self): | ||||
| """Flush the writer.""" | """Flush the writer.""" | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -12,12 +12,7 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """ | |||||
| @File : test_image_summary.py | |||||
| @Author: | |||||
| @Date : 2019-07-4 | |||||
| @Desc : test summary function | |||||
| """ | |||||
| """test_image_summary""" | |||||
| import logging | import logging | ||||
| import os | import os | ||||
| import numpy as np | import numpy as np | ||||
| @@ -70,23 +65,14 @@ def get_test_data(step): | |||||
| # Test: call method on parse graph code | # Test: call method on parse graph code | ||||
| def test_image_summary_sample(): | def test_image_summary_sample(): | ||||
| """ test_image_summary_sample """ | """ test_image_summary_sample """ | ||||
| log.debug("begin test_image_summary_sample") | |||||
| # step 0: create the thread | |||||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | ||||
| # step 1: create the test data for summary | |||||
| # step 2: create the Event | |||||
| for i in range(1, 5): | for i in range(1, 5): | ||||
| test_data = get_test_data(i) | test_data = get_test_data(i) | ||||
| _cache_summary_tensor_data(test_data) | _cache_summary_tensor_data(test_data) | ||||
| test_writer.record(i) | test_writer.record(i) | ||||
| test_writer.flush() | test_writer.flush() | ||||
| # step 3: send the event to mq | |||||
| # step 4: accept the event and write the file | |||||
| log.debug("finished test_image_summary_sample") | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| """ Net definition """ | """ Net definition """ | ||||
| @@ -175,23 +161,11 @@ class ImageSummaryCallback(Callback): | |||||
| def test_image_summary_train(): | def test_image_summary_train(): | ||||
| """ test_image_summary_train """ | """ test_image_summary_train """ | ||||
| dataset = get_dataset() | dataset = get_dataset() | ||||
| log.debug("begin test_image_summary_sample") | |||||
| # step 0: create the thread | |||||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | ||||
| # step 1: create the test data for summary | |||||
| # step 2: create the Event | |||||
| model = get_model() | model = get_model() | ||||
| callback = ImageSummaryCallback(test_writer) | callback = ImageSummaryCallback(test_writer) | ||||
| model.train(2, dataset, callbacks=[callback]) | model.train(2, dataset, callbacks=[callback]) | ||||
| # step 3: send the event to mq | |||||
| # step 4: accept the event and write the file | |||||
| log.debug("finished test_image_summary_sample") | |||||
| def test_image_summary_data(): | def test_image_summary_data(): | ||||
| """ test_image_summary_data """ | """ test_image_summary_data """ | ||||
| @@ -207,13 +181,6 @@ def test_image_summary_data(): | |||||
| test_data_list.append(dct) | test_data_list.append(dct) | ||||
| i += 1 | i += 1 | ||||
| log.debug("begin test_image_summary_sample") | |||||
| # step 0: create the thread | |||||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | ||||
| # step 1: create the test data for summary | |||||
| # step 2: create the Event | |||||
| _cache_summary_tensor_data(test_data_list) | _cache_summary_tensor_data(test_data_list) | ||||
| test_writer.record(1) | test_writer.record(1) | ||||
| log.debug("finished test_image_summary_sample") | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -12,17 +12,12 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """ | |||||
| @File : test_summary.py | |||||
| @Author: | |||||
| @Date : 2019-07-4 | |||||
| @Desc : test summary function | |||||
| """ | |||||
| import logging | |||||
| """Test summary.""" | |||||
| import os | import os | ||||
| import random | import random | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| @@ -32,9 +27,6 @@ from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary | |||||
| CUR_DIR = os.getcwd() | CUR_DIR = os.getcwd() | ||||
| SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" | SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" | ||||
| log = logging.getLogger("test") | |||||
| log.setLevel(level=logging.ERROR) | |||||
| def get_test_data(step): | def get_test_data(step): | ||||
| """ get_test_data """ | """ get_test_data """ | ||||
| @@ -58,26 +50,14 @@ def get_test_data(step): | |||||
| return test_data_list | return test_data_list | ||||
| # Test 1: summary sample of scalar | |||||
| def test_scalar_summary_sample(): | def test_scalar_summary_sample(): | ||||
| """ test_scalar_summary_sample """ | """ test_scalar_summary_sample """ | ||||
| log.debug("begin test_scalar_summary_sample") | |||||
| # step 0: create the thread | |||||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | ||||
| # step 1: create the test data for summary | |||||
| # step 2: create the Event | |||||
| for i in range(1, 500): | |||||
| for i in range(1, 5): | |||||
| test_data = get_test_data(i) | test_data = get_test_data(i) | ||||
| _cache_summary_tensor_data(test_data) | _cache_summary_tensor_data(test_data) | ||||
| test_writer.record(i) | test_writer.record(i) | ||||
| # step 3: send the event to mq | |||||
| # step 4: accept the event and write the file | |||||
| log.debug("finished test_scalar_summary_sample") | |||||
| def get_test_data_shape_1(step): | def get_test_data_shape_1(step): | ||||
| """ get_test_data_shape_1 """ | """ get_test_data_shape_1 """ | ||||
| @@ -104,23 +84,12 @@ def get_test_data_shape_1(step): | |||||
| # Test: shape = (1,) | # Test: shape = (1,) | ||||
| def test_scalar_summary_sample_with_shape_1(): | def test_scalar_summary_sample_with_shape_1(): | ||||
| """ test_scalar_summary_sample_with_shape_1 """ | """ test_scalar_summary_sample_with_shape_1 """ | ||||
| log.debug("begin test_scalar_summary_sample_with_shape_1") | |||||
| # step 0: create the thread | |||||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | ||||
| # step 1: create the test data for summary | |||||
| # step 2: create the Event | |||||
| for i in range(1, 100): | for i in range(1, 100): | ||||
| test_data = get_test_data_shape_1(i) | test_data = get_test_data_shape_1(i) | ||||
| _cache_summary_tensor_data(test_data) | _cache_summary_tensor_data(test_data) | ||||
| test_writer.record(i) | test_writer.record(i) | ||||
| # step 3: send the event to mq | |||||
| # step 4: accept the event and write the file | |||||
| log.debug("finished test_scalar_summary_sample") | |||||
| # Test: test with ge | # Test: test with ge | ||||
| class SummaryDemo(nn.Cell): | class SummaryDemo(nn.Cell): | ||||
| @@ -143,13 +112,7 @@ class SummaryDemo(nn.Cell): | |||||
| def test_scalar_summary_with_ge(): | def test_scalar_summary_with_ge(): | ||||
| """ test_scalar_summary_with_ge """ | """ test_scalar_summary_with_ge """ | ||||
| log.debug("begin test_scalar_summary_with_ge") | |||||
| # step 0: create the thread | |||||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | ||||
| # step 1: create the network for summary | |||||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||||
| net = SummaryDemo() | net = SummaryDemo() | ||||
| net.set_train() | net.set_train() | ||||
| @@ -161,45 +124,17 @@ def test_scalar_summary_with_ge(): | |||||
| net(x, y) | net(x, y) | ||||
| test_writer.record(i) | test_writer.record(i) | ||||
| log.debug("finished test_scalar_summary_with_ge") | |||||
| # test the problem of two consecutive use cases going wrong | # test the problem of two consecutive use cases going wrong | ||||
| def test_scalar_summary_with_ge_2(): | def test_scalar_summary_with_ge_2(): | ||||
| """ test_scalar_summary_with_ge_2 """ | """ test_scalar_summary_with_ge_2 """ | ||||
| log.debug("begin test_scalar_summary_with_ge_2") | |||||
| # step 0: create the thread | |||||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | ||||
| # step 1: create the network for summary | |||||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||||
| net = SummaryDemo() | net = SummaryDemo() | ||||
| net.set_train() | net.set_train() | ||||
| # step 2: create the Event | |||||
| steps = 100 | steps = 100 | ||||
| for i in range(1, steps): | for i in range(1, steps): | ||||
| x = Tensor(np.array([1.1]).astype(np.float32)) | x = Tensor(np.array([1.1]).astype(np.float32)) | ||||
| y = Tensor(np.array([1.2]).astype(np.float32)) | y = Tensor(np.array([1.2]).astype(np.float32)) | ||||
| net(x, y) | net(x, y) | ||||
| test_writer.record(i) | test_writer.record(i) | ||||
| log.debug("finished test_scalar_summary_with_ge_2") | |||||
| def test_validate(): | |||||
| with SummaryRecord(SUMMARY_DIR) as sr: | |||||
| sr.record(1) | |||||
| with pytest.raises(ValueError): | |||||
| sr.record(False) | |||||
| with pytest.raises(ValueError): | |||||
| sr.record(2.0) | |||||
| with pytest.raises(ValueError): | |||||
| sr.record((1, 3)) | |||||
| with pytest.raises(ValueError): | |||||
| sr.record([2, 3]) | |||||
| with pytest.raises(ValueError): | |||||
| sr.record("str") | |||||
| with pytest.raises(ValueError): | |||||
| sr.record(sr) | |||||
| @@ -1,133 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| @File : test_summary_abnormal_input.py | |||||
| @Author: | |||||
| @Date : 2019-08-5 | |||||
| @Desc : test summary function of abnormal input | |||||
| """ | |||||
| import logging | |||||
| import os | |||||
| import numpy as np | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.train.summary.summary_record import SummaryRecord | |||||
| CUR_DIR = os.getcwd() | |||||
| SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" | |||||
| log = logging.getLogger("test") | |||||
| log.setLevel(level=logging.ERROR) | |||||
| def get_test_data(step): | |||||
| """ get_test_data """ | |||||
| test_data_list = [] | |||||
| tag1 = "x1[:Scalar]" | |||||
| tag2 = "x2[:Scalar]" | |||||
| np1 = np.array(step + 1).astype(np.float32) | |||||
| np2 = np.array(step + 2).astype(np.float32) | |||||
| dict1 = {} | |||||
| dict1["name"] = tag1 | |||||
| dict1["data"] = Tensor(np1) | |||||
| dict2 = {} | |||||
| dict2["name"] = tag2 | |||||
| dict2["data"] = Tensor(np2) | |||||
| test_data_list.append(dict1) | |||||
| test_data_list.append(dict2) | |||||
| return test_data_list | |||||
| # Test: call method on parse graph code | |||||
| def test_summaryrecord_input_null_string(): | |||||
| log.debug("begin test_summaryrecord_input_null_string") | |||||
| # step 0: create the thread | |||||
| try: | |||||
| with SummaryRecord(""): | |||||
| pass | |||||
| except: | |||||
| assert True | |||||
| else: | |||||
| assert False | |||||
| log.debug("finished test_summaryrecord_input_null_string") | |||||
| def test_summaryrecord_input_None(): | |||||
| log.debug("begin test_summaryrecord_input_None") | |||||
| # step 0: create the thread | |||||
| try: | |||||
| with SummaryRecord(None): | |||||
| pass | |||||
| except: | |||||
| assert True | |||||
| else: | |||||
| assert False | |||||
| log.debug("finished test_summaryrecord_input_None") | |||||
| def test_summaryrecord_input_relative_dir_1(): | |||||
| log.debug("begin test_summaryrecord_input_relative_dir_1") | |||||
| # step 0: create the thread | |||||
| try: | |||||
| with SummaryRecord("./test_temp_summary_event_file/"): | |||||
| pass | |||||
| except: | |||||
| assert False | |||||
| else: | |||||
| assert True | |||||
| log.debug("finished test_summaryrecord_input_relative_dir_1") | |||||
| def test_summaryrecord_input_relative_dir_2(): | |||||
| log.debug("begin test_summaryrecord_input_relative_dir_2") | |||||
| # step 0: create the thread | |||||
| try: | |||||
| with SummaryRecord("../summary/"): | |||||
| pass | |||||
| except: | |||||
| assert False | |||||
| else: | |||||
| assert True | |||||
| log.debug("finished test_summaryrecord_input_relative_dir_2") | |||||
| def test_summaryrecord_input_invalid_type_dir(): | |||||
| log.debug("begin test_summaryrecord_input_invalid_type_dir") | |||||
| # step 0: create the thread | |||||
| try: | |||||
| with SummaryRecord(32): | |||||
| pass | |||||
| except: | |||||
| assert True | |||||
| else: | |||||
| assert False | |||||
| log.debug("finished test_summaryrecord_input_invalid_type_dir") | |||||
| def test_mulit_layer_directory(): | |||||
| log.debug("begin test_mulit_layer_directory") | |||||
| # step 0: create the thread | |||||
| try: | |||||
| with SummaryRecord("./test_temp_summary_event_file/test/t1/"): | |||||
| pass | |||||
| except: | |||||
| assert False | |||||
| else: | |||||
| assert True | |||||
| log.debug("finished test_mulit_layer_directory") | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -187,6 +187,14 @@ class TestSummaryCollector: | |||||
| assert expected_msg == str(exc.value) | assert expected_msg == str(exc.value) | ||||
| def test_params_with_histogram_regular_value_error(self): | |||||
| """Test histogram regular.""" | |||||
| summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | |||||
| with pytest.raises(ValueError) as exc: | |||||
| SummaryCollector(summary_dir, collect_specified_data={'histogram_regular': '*'}) | |||||
| assert 'For `collect_specified_data`, the value of `histogram_regular`' in str(exc.value) | |||||
| def test_params_with_collect_specified_data_unexpected_key(self): | def test_params_with_collect_specified_data_unexpected_key(self): | ||||
| """Test the collect_specified_data parameter with unexpected key.""" | """Test the collect_specified_data parameter with unexpected key.""" | ||||
| summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | ||||
| @@ -260,7 +268,7 @@ class TestSummaryCollector: | |||||
| cb_params.train_dataset_element = image_data | cb_params.train_dataset_element = image_data | ||||
| with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector: | with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector: | ||||
| summary_collector._collect_input_data(cb_params) | summary_collector._collect_input_data(cb_params) | ||||
| # Note Here need to asssert the result and expected data | |||||
| # Note Here need to assert the result and expected data | |||||
| @mock.patch.object(SummaryRecord, 'add_value') | @mock.patch.object(SummaryRecord, 'add_value') | ||||
| def test_collect_dataset_graph_success(self, mock_add_value): | def test_collect_dataset_graph_success(self, mock_add_value): | ||||
| @@ -296,7 +304,6 @@ class TestSummaryCollector: | |||||
| assert summary_collector._is_parse_loss_success | assert summary_collector._is_parse_loss_success | ||||
| def test_get_optimizer_from_cb_params_success(self): | def test_get_optimizer_from_cb_params_success(self): | ||||
| """Test get optimizer success from cb params.""" | """Test get optimizer success from cb params.""" | ||||
| cb_params = _InternalCallbackParam() | cb_params = _InternalCallbackParam() | ||||
| @@ -0,0 +1,80 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """test_summary_abnormal_input""" | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import numpy as np | |||||
| import pytest | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.train.summary.summary_record import SummaryRecord | |||||
| def get_test_data(step): | |||||
| """ get_test_data """ | |||||
| test_data_list = [] | |||||
| tag1 = "x1[:Scalar]" | |||||
| tag2 = "x2[:Scalar]" | |||||
| np1 = np.array(step + 1).astype(np.float32) | |||||
| np2 = np.array(step + 2).astype(np.float32) | |||||
| dict1 = {} | |||||
| dict1["name"] = tag1 | |||||
| dict1["data"] = Tensor(np1) | |||||
| dict2 = {} | |||||
| dict2["name"] = tag2 | |||||
| dict2["data"] = Tensor(np2) | |||||
| test_data_list.append(dict1) | |||||
| test_data_list.append(dict2) | |||||
| return test_data_list | |||||
| class TestSummaryRecord: | |||||
| """Test SummaryRecord""" | |||||
| def setup_class(self): | |||||
| """Run before test this class.""" | |||||
| self.base_summary_dir = tempfile.mkdtemp(suffix='summary') | |||||
| def teardown_class(self): | |||||
| """Run after test this class.""" | |||||
| if os.path.exists(self.base_summary_dir): | |||||
| shutil.rmtree(self.base_summary_dir) | |||||
| @pytest.mark.parametrize("log_dir", ["", None, 32]) | |||||
| def test_log_dir_with_type_error(self, log_dir): | |||||
| with pytest.raises(TypeError): | |||||
| with SummaryRecord(log_dir): | |||||
| pass | |||||
| @pytest.mark.parametrize("raise_exception", ["", None, 32]) | |||||
| def test_raise_exception_with_type_error(self, raise_exception): | |||||
| summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | |||||
| with pytest.raises(TypeError) as exc: | |||||
| with SummaryRecord(log_dir=summary_dir, raise_exception=raise_exception): | |||||
| pass | |||||
| assert "raise_exception" in str(exc.value) | |||||
| @pytest.mark.parametrize("step", [False, 2.0, (1, 3), [2, 3], "str"]) | |||||
| def test_step_of_record_with_type_error(self, step): | |||||
| summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | |||||
| with pytest.raises(TypeError): | |||||
| with SummaryRecord(summary_dir) as sr: | |||||
| sr.record(step) | |||||