diff --git a/mindinsight/explainer/manager/explain_loader.py b/mindinsight/explainer/manager/explain_loader.py index 61ad80f3..84f8a618 100644 --- a/mindinsight/explainer/manager/explain_loader.py +++ b/mindinsight/explainer/manager/explain_loader.py @@ -20,6 +20,8 @@ import re from collections import defaultdict from datetime import datetime from typing import Dict, Iterable, List, Optional, Union +from enum import Enum +import threading from mindinsight.explainer.common.enums import ExplainFieldsEnum from mindinsight.explainer.common.log import logger @@ -44,6 +46,11 @@ _SAMPLE_FIELD_NAMES = [ ] +class _LoaderStatus(Enum): + STOP = 'STOP' + LOADING = 'LOADING' + + def _round(score): """Take round of a number to given precision.""" try: @@ -73,6 +80,9 @@ class ExplainLoader: self._metadata = {'explainers': [], 'metrics': [], 'labels': []} self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)} + self._status = _LoaderStatus.STOP.value + self._status_mutex = threading.Lock() + @property def all_classes(self) -> List[Dict]: """ @@ -263,6 +273,7 @@ class ExplainLoader: def load(self): """Start loading data from the latest summary file to the loader.""" + self.status = _LoaderStatus.LOADING.value filenames = [] for filename in FileHandler.list_dir(self._loader_info['summary_dir']): if FileHandler.is_file(FileHandler.join(self._loader_info['summary_dir'], filename)): @@ -274,16 +285,32 @@ class ExplainLoader: % self._loader_info['summary_dir']) is_end = False - while not is_end: - is_clean, is_end, event_dict = self._parser.parse_explain(filenames) + while not is_end and self.status != _LoaderStatus.STOP.value: + file_changed, is_end, event_dict = self._parser.parse_explain(filenames) - if is_clean: + if file_changed: logger.info('Summary file in %s update, reload the data in the summary.', self._loader_info['summary_dir']) self._clear_job() if event_dict: self._import_data_from_event(event_dict) + @property + def status(self): + """Get the status of this class with lock.""" + with self._status_mutex: + return self._status + + @status.setter + def status(self, status): + """Set the status of this class with lock.""" + with self._status_mutex: + self._status = status + + def stop(self): + """Stop load data.""" + self.status = _LoaderStatus.STOP.value + def get_all_samples(self) -> List[Dict]: """ Return a list of sample information cachced in the explain job diff --git a/mindinsight/explainer/manager/explain_manager.py b/mindinsight/explainer/manager/explain_manager.py index 1cd41331..1a6dd17f 100644 --- a/mindinsight/explainer/manager/explain_manager.py +++ b/mindinsight/explainer/manager/explain_manager.py @@ -14,20 +14,21 @@ # ============================================================================ """ExplainManager.""" +from collections import OrderedDict + import os import threading import time -from collections import OrderedDict from datetime import datetime from typing import Optional from mindinsight.conf import settings from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common.enums import BaseEnum -from mindinsight.explainer.common.log import logger -from mindinsight.explainer.manager.explain_loader import ExplainLoader from mindinsight.datavisual.data_access.file_handler import FileHandler from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher +from mindinsight.explainer.common.log import logger +from mindinsight.explainer.manager.explain_loader import ExplainLoader from mindinsight.utils.exceptions import MindInsightException, ParamValueError, UnknownError _MAX_LOADERS_NUM = 3 @@ -37,8 +38,8 @@ class _ExplainManagerStatus(BaseEnum): """Manager status.""" INIT = 'INIT' LOADING = 'LOADING' + STOPPING = 'STOPPING' DONE = 'DONE' - INVALID = 'INVALID' class ExplainManager: @@ -49,6 +50,7 @@ class ExplainManager: self._loader_pool = OrderedDict() self._loading_status = _ExplainManagerStatus.INIT.value self._status_mutex = threading.Lock() + self._load_data_mutex = threading.Lock() self._loader_pool_mutex = threading.Lock() self._max_loaders_num = _MAX_LOADERS_NUM self._summary_watcher = SummaryWatcher() @@ -67,7 +69,7 @@ class ExplainManager: once. Default: 0. """ thread = threading.Thread(target=self._repeat_loading, - name='start_load_thread', + name='explainer.start_load_thread', args=(reload_interval,), daemon=True) time.sleep(1) @@ -127,48 +129,52 @@ class ExplainManager: """Periodically loading summary.""" while True: try: - logger.info('Start to load data, repeat interval: %r.', repeat_interval) - self._load_data() - if not repeat_interval: - return + if self.status == _ExplainManagerStatus.STOPPING.value: + logger.debug('Current loading status is %s, we will not trigger repeat loading.', + _ExplainManagerStatus.STOPPING.value) + else: + logger.info('Starts triggering repeat loading, repeat interval: %r.', repeat_interval) + self._load_data() + if not repeat_interval: + return time.sleep(repeat_interval) except UnknownError as ex: logger.error('Unexpected error happens when loading data. Loading status: %s, loading pool size: %d' - 'Detail: %s', self._loading_status, len(self._loader_pool), str(ex)) + 'Detail: %s', self.status, len(self._loader_pool), str(ex)) def _load_data(self): """ Prepare loaders in cache and start loading the data from summaries. Only a limited number of loaders will be cached in terms of updated_time or query_time. The size of cache - pool is determined by _MAX_LOADERS_NUM. When the manager start loading data, only the lastest _MAX_LOADER_NUM + pool is determined by _MAX_LOADERS_NUM. When the manager start loading data, only the latest _MAX_LOADER_NUM summaries will be loaded in cache. If a cached loader if queries by 'get_job', the query_time of the loader will be updated as well as the the loader moved to the end of cache. If an uncached summary is queried, a new loader instance will be generated and put to the end cache. """ try: - with self._status_mutex: - if self._loading_status == _ExplainManagerStatus.LOADING.value: - logger.info('Current status is %s, will ignore to load data.', self._loading_status) + with self._load_data_mutex: + if self.status == _ExplainManagerStatus.LOADING.value: + logger.info('Current status is %s, will ignore to load data.', self.status) return - self._loading_status = _ExplainManagerStatus.LOADING.value - + logger.info('Start to load data, and status change to %s.', _ExplainManagerStatus.LOADING.value) + self.status = _ExplainManagerStatus.LOADING.value self._cache_loaders() - self._execute_loading() - if not self._loader_pool: - self._loading_status = _ExplainManagerStatus.INVALID.value - else: - self._loading_status = _ExplainManagerStatus.DONE.value + if self.status == _ExplainManagerStatus.STOPPING.value: + logger.info('The manager status has been %s, will not execute loading.', self.status) + return + self._execute_loading() - logger.info('Load event data end, status: %s, and loader pool size: %d', - self._loading_status, len(self._loader_pool)) + logger.info('Load event data end, current status: %s, next status: %s, loader pool size: %d.', + self.status, _ExplainManagerStatus.DONE.value, len(self._loader_pool)) except Exception as ex: - self._loading_status = _ExplainManagerStatus.INVALID.value logger.exception(ex) raise UnknownError(str(ex)) + finally: + self.status = _ExplainManagerStatus.DONE.value def _cache_loaders(self): """Cache explain loader in cache pool.""" @@ -217,30 +223,36 @@ class ExplainManager: def _execute_loading(self): """Execute the data loading.""" - for loader_id in list(self._loader_pool.keys()): + # We will load the newest loader first. + for loader_id in list(self._loader_pool.keys())[::-1]: try: with self._loader_pool_mutex: loader = self._loader_pool.get(loader_id, None) if loader is None: - logger.debug('Loader %r has been deleted, will not load data', loader_id) - return + logger.debug('Loader %r has been deleted, will not load data.', loader_id) + continue + + if self.status == _ExplainManagerStatus.STOPPING.value: + logger.info('Loader %s status is %s, will return.', loader_id, loader.status) + return + loader.load() except MindInsightException as ex: - logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, ex) + logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s.', loader_id, ex) with self._loader_pool_mutex: self._delete_loader(loader_id) def _delete_loader(self, loader_id): - """delete loader given loader_id""" + """Delete loader given loader_id.""" if loader_id in self._loader_pool: self._loader_pool.pop(loader_id) - logger.debug('delete loader %s', loader_id) + logger.debug('delete loader %s, and stop this loader loading data.', loader_id) def _check_status_valid(self): """Check manager status.""" - if self._loading_status == _ExplainManagerStatus.INIT.value: - raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._loading_status) + if self.status == _ExplainManagerStatus.INIT.value: + raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self.status) def _check_summary_exist(self, loader_id): """Verify thee train_job is existed given loader_id.""" @@ -250,10 +262,44 @@ class ExplainManager: def _reload_data_again(self): """Reload the data one more time.""" logger.debug('Start to reload data again.') - thread = threading.Thread(target=self._load_data, name='reload_data_thread') + + def _wrapper(): + if self.status == _ExplainManagerStatus.STOPPING.value: + return + self._stop_load_data() + self._load_data() + + thread = threading.Thread(target=_wrapper, name='explainer.reload_data_thread') thread.daemon = False thread.start() + def _stop_load_data(self): + """Stop loading data, status changes to Stopping.""" + if self.status != _ExplainManagerStatus.LOADING.value: + return + + logger.info('Start to stop loading data, set status to %s.', _ExplainManagerStatus.STOPPING.value) + self.status = _ExplainManagerStatus.STOPPING.value + + for loader in self._loader_pool.values(): + loader.stop() + + while self.status != _ExplainManagerStatus.DONE.value: + continue + logger.info('Stop loading data end.') + + @property + def status(self): + """Get the status of this manager with lock.""" + with self._status_mutex: + return self._loading_status + + @status.setter + def status(self, status): + """Set the status of this manager with lock.""" + with self._status_mutex: + self._loading_status = status + @staticmethod def _generate_loader_id(relative_path): """Generate loader id for given path""" diff --git a/mindinsight/explainer/manager/explain_parser.py b/mindinsight/explainer/manager/explain_parser.py index 545550a8..f795e7a8 100644 --- a/mindinsight/explainer/manager/explain_parser.py +++ b/mindinsight/explainer/manager/explain_parser.py @@ -98,8 +98,8 @@ class ExplainParser(_SummaryParser): field_list, tensor_value_list = self._event_decode(event_str) for field, tensor_value in zip(field_list, tensor_value_list): event_data[field] = tensor_value - logger.info("Parse summary file offset %d, file path: %s.", self._summary_file_handler.offset, - file_path) + logger.debug("Parse summary file offset %d, file path: %s.", + self._summary_file_handler.offset, file_path) return is_clean, is_end, event_data except (exceptions.CRCFailedError, exceptions.CRCLengthFailedError) as ex: diff --git a/tests/ut/__init__.py b/tests/ut/__init__.py index 1b0b1781..d0c454b1 100644 --- a/tests/ut/__init__.py +++ b/tests/ut/__init__.py @@ -14,6 +14,6 @@ # ============================================================================ """Import the mocked mindspore.""" import sys -from ..utils import mindspore +from tests.utils import mindspore sys.modules['mindspore'] = mindspore diff --git a/tests/ut/explainer/__init__.py b/tests/ut/explainer/__init__.py new file mode 100644 index 00000000..9c3f0132 --- /dev/null +++ b/tests/ut/explainer/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""UT for explainer.""" diff --git a/tests/ut/explainer/encapsulator/__init__.py b/tests/ut/explainer/encapsulator/__init__.py index dce42252..d6da6ea1 100644 --- a/tests/ut/explainer/encapsulator/__init__.py +++ b/tests/ut/explainer/encapsulator/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/ut/explainer/manager/__init__.py b/tests/ut/explainer/manager/__init__.py new file mode 100644 index 00000000..d67cdd73 --- /dev/null +++ b/tests/ut/explainer/manager/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""UT for explainer.manager.""" diff --git a/tests/ut/explainer/manager/test_explain_loader.py b/tests/ut/explainer/manager/test_explain_loader.py new file mode 100644 index 00000000..724b6d6c --- /dev/null +++ b/tests/ut/explainer/manager/test_explain_loader.py @@ -0,0 +1,64 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""UT for explainer.manager.explain_manager""" +import os +import threading +import time +from unittest.mock import patch + +from mindinsight.datavisual.data_access.file_handler import FileHandler +from mindinsight.explainer.manager.explain_loader import ExplainLoader +from mindinsight.explainer.manager.explain_loader import _LoaderStatus +from mindinsight.explainer.manager.explain_parser import ExplainParser + + +def abc(): + FileHandler.is_file('aaa') + print('after') + +class TestExplainLoader: + """Test explain loader class.""" + @patch.object(ExplainParser, 'parse_explain') + @patch.object(FileHandler, 'list_dir') + @patch.object(FileHandler, 'is_file') + @patch.object(os, 'stat') + def test_stop(self, mock_stat, mock_is_file, mock_list_dir, mock_parse_explain): + """Test stop function.""" + mock_is_file.return_value = True + mock_list_dir.return_value = ['events.summary.123.host_explain'] + mock_parse_explain.return_value = (True, False, None) + + class _MockStat: + def __init__(self, _): + self.st_ctime = 1 + self.st_mtime = 1 + self.st_size = 1 + + mock_stat.side_effect = _MockStat + + loader = ExplainLoader( + loader_id='./summary_dir', + summary_dir='./summary_dir') + + def _stop_loader(explain_loader): + time.sleep(0.01) + assert explain_loader.status == _LoaderStatus.LOADING.value + explain_loader.stop() + + thread = threading.Thread(target=_stop_loader, args=[loader], daemon=True) + thread.start() + + loader.load() + assert loader.status == _LoaderStatus.STOP.value diff --git a/tests/ut/explainer/manager/test_explain_manager.py b/tests/ut/explainer/manager/test_explain_manager.py new file mode 100644 index 00000000..e3858ebe --- /dev/null +++ b/tests/ut/explainer/manager/test_explain_manager.py @@ -0,0 +1,88 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""UT for explainer.manager.explain_loader.""" +import os +import threading +import time +from unittest.mock import patch + +from mindinsight.explainer.manager.explain_loader import ExplainLoader +from mindinsight.explainer.manager.explain_loader import _LoaderStatus +from mindinsight.explainer.manager.explain_manager import ExplainManager +from mindinsight.explainer.manager.explain_manager import _ExplainManagerStatus + + +class TestExplainManager: + """Test explain manager class.""" + + def test_stop_load_data_not_loading_status(self): + """Test stop load data when the status is not loading.""" + manager = ExplainManager('./summary_dir') + assert manager.status == _ExplainManagerStatus.INIT.value + + manager.status = _ExplainManagerStatus.DONE.value + manager._stop_load_data() + assert manager.status == _ExplainManagerStatus.DONE.value + + @patch.object(os, 'stat') + def test_stop_load_data_with_loading_status(self, mock_stat): + """Test stop load data with status is loading.""" + class _MockStat: + def __init__(self, _): + self.st_ctime = 1 + self.st_mtime = 1 + self.st_size = 1 + + mock_stat.side_effect = _MockStat + + manager = ExplainManager('./summary_dir') + manager.status = _ExplainManagerStatus.LOADING.value + loader_count = 3 + for i in range(loader_count): + loader = ExplainLoader(f'./summary_dir{i}', f'./summary_dir{i}') + loader.status = _LoaderStatus.LOADING.value + manager._loader_pool[i] = loader + + def _wrapper(loader_manager): + assert loader_manager.status == _ExplainManagerStatus.LOADING.value + time.sleep(0.01) + loader_manager.status = _ExplainManagerStatus.DONE.value + thread = threading.Thread(target=_wrapper, args=(manager,), daemon=True) + thread.start() + manager._stop_load_data() + for loader in manager._loader_pool.values(): + assert loader.status == _LoaderStatus.STOP.value + assert manager.status == _ExplainManagerStatus.DONE.value + + def test_stop_load_data_with_after_cache_loaders(self): + """ + Test stop load data that is triggered by get a not in loader pool job. + + In this case, we will mock the cache_loader function, and set status to STOP by other thread. + """ + manager = ExplainManager('./summary_dir') + + def _mock_cache_loaders(): + for _ in range(3): + time.sleep(0.1) + manager._cache_loaders = _mock_cache_loaders + load_data_thread = threading.Thread(target=manager._load_data, name='manager_load_data', daemon=True) + stop_thread = threading.Thread(target=manager._stop_load_data, name='stop_load_data', daemon=True) + load_data_thread.start() + while manager.status != _ExplainManagerStatus.LOADING.value: + continue + stop_thread.start() + stop_thread.join() + assert manager.status == _ExplainManagerStatus.DONE.value