# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UT for explainer.manager.explain_manager."""
import os
import threading
import time
from unittest.mock import patch

from mindinsight.explainer.manager.explain_loader import ExplainLoader
from mindinsight.explainer.manager.explain_loader import _LoaderStatus
from mindinsight.explainer.manager.explain_manager import ExplainManager
from mindinsight.explainer.manager.explain_manager import _ExplainManagerStatus

# Upper bound (seconds) for every wait on a background thread or status change,
# so a regression in the manager's stop logic fails the test instead of
# hanging the whole test session.
_THREAD_TIMEOUT = 10


class TestExplainManager:
    """Test explain manager class."""

    def test_stop_load_data_not_loading_status(self):
        """Test that stopping is a no-op on status when the manager is not loading."""
        manager = ExplainManager('./summary_dir')
        assert manager.status == _ExplainManagerStatus.INIT.value

        manager.status = _ExplainManagerStatus.DONE.value
        manager._stop_load_data()
        assert manager.status == _ExplainManagerStatus.DONE.value

    @patch.object(os, 'stat')
    def test_stop_load_data_with_loading_status(self, mock_stat):
        """Test stopping data loading while the manager status is LOADING."""

        class _MockStat:
            """Minimal stand-in for an os.stat result; the path argument is ignored."""

            def __init__(self, _):
                self.st_ctime = 1
                self.st_mtime = 1
                self.st_size = 1

        # os.stat(path) now returns _MockStat(path).
        mock_stat.side_effect = _MockStat
        manager = ExplainManager('./summary_dir')
        manager.status = _ExplainManagerStatus.LOADING.value
        loader_count = 3
        for i in range(loader_count):
            loader = ExplainLoader(f'./summary_dir{i}', f'./summary_dir{i}')
            loader.status = _LoaderStatus.LOADING.value
            manager._loader_pool[i] = loader

        def _wrapper(loader_manager):
            # Simulates the loading thread finishing shortly after a stop is
            # requested. Deliberately assertion-free: an AssertionError raised
            # in a worker thread is not reported as a test failure by pytest.
            time.sleep(0.01)
            loader_manager.status = _ExplainManagerStatus.DONE.value

        # Check the precondition in the main thread, where a failure is reported.
        assert manager.status == _ExplainManagerStatus.LOADING.value
        thread = threading.Thread(target=_wrapper, args=(manager,), daemon=True)
        thread.start()
        manager._stop_load_data()

        for loader in manager._loader_pool.values():
            assert loader.status == _LoaderStatus.STOP.value
        assert manager.status == _ExplainManagerStatus.DONE.value

        # Clean up the helper thread only after the assertions above, so they
        # still check the state right after _stop_load_data returns.
        thread.join(timeout=_THREAD_TIMEOUT)
        assert not thread.is_alive()

    def test_stop_load_data_with_after_cache_loaders(self):
        """
        Test stopping data loading while loaders are being cached.

        This mirrors a stop triggered by requesting a job that is not in the
        loader pool. _cache_loaders is mocked to take a short while, and
        _stop_load_data is invoked from another thread once the manager
        reports the LOADING status; the manager must then end up DONE.
        """
        manager = ExplainManager('./summary_dir')

        def _mock_cache_loaders():
            for _ in range(3):
                time.sleep(0.1)

        manager._cache_loaders = _mock_cache_loaders
        load_data_thread = threading.Thread(target=manager._load_data,
                                            name='manager_load_data',
                                            daemon=True)
        stop_thread = threading.Thread(target=manager._stop_load_data,
                                       name='stop_load_data',
                                       daemon=True)
        load_data_thread.start()

        # Bounded wait for the load thread to enter LOADING; the short sleep
        # yields the GIL instead of spinning a CPU core.
        deadline = time.monotonic() + _THREAD_TIMEOUT
        while manager.status != _ExplainManagerStatus.LOADING.value:
            assert time.monotonic() < deadline, 'manager never entered LOADING status'
            time.sleep(0.001)

        stop_thread.start()
        stop_thread.join(timeout=_THREAD_TIMEOUT)
        assert not stop_thread.is_alive(), '_stop_load_data did not return in time'
        assert manager.status == _ExplainManagerStatus.DONE.value