Merge pull request !440 from wenkai/pref_opt_0720_1cp1tags/v0.6.0-beta
| @@ -34,11 +34,11 @@ class DataLoader: | |||
| self._summary_dir = summary_dir | |||
| self._loader = None | |||
| def load(self, workers_count=1): | |||
| def load(self, computing_resource_mgr): | |||
| """Load the data when loader is exist. | |||
| Args: | |||
| workers_count (int): The count of workers. Default value is 1. | |||
| computing_resource_mgr (ComputingResourceManager): The ComputingResourceManager instance. | |||
| """ | |||
| if self._loader is None: | |||
| @@ -53,7 +53,7 @@ class DataLoader: | |||
| logger.warning("No valid files can be loaded, summary_dir: %s.", self._summary_dir) | |||
| raise exceptions.SummaryLogPathInvalid() | |||
| self._loader.load(workers_count) | |||
| self._loader.load(computing_resource_mgr) | |||
| def get_events_data(self): | |||
| """ | |||
| @@ -40,6 +40,7 @@ from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE | |||
| from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator | |||
| from mindinsight.utils.computing_resource_mgr import ComputingResourceManager | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from mindinsight.utils.exceptions import UnknownError | |||
| @@ -510,7 +511,7 @@ class _DetailCacheManager(_BaseCacheManager): | |||
| logger.debug("delete loader %s", loader_id) | |||
| self._loader_pool.pop(loader_id) | |||
| def _execute_loader(self, loader_id, workers_count): | |||
| def _execute_loader(self, loader_id, computing_resource_mgr): | |||
| """ | |||
| Load data form data_loader. | |||
| @@ -518,7 +519,7 @@ class _DetailCacheManager(_BaseCacheManager): | |||
| Args: | |||
| loader_id (str): An ID for `Loader`. | |||
| workers_count (int): The count of workers. | |||
| computing_resource_mgr (ComputingResourceManager): The ComputingResourceManager instance. | |||
| """ | |||
| try: | |||
| with self._loader_pool_mutex: | |||
| @@ -527,7 +528,7 @@ class _DetailCacheManager(_BaseCacheManager): | |||
| logger.debug("Loader %r has been deleted, will not load data.", loader_id) | |||
| return | |||
| loader.data_loader.load(workers_count) | |||
| loader.data_loader.load(computing_resource_mgr) | |||
| # Update loader cache status to CACHED. | |||
| # Loader with cache status CACHED should remain the same cache status. | |||
| @@ -580,13 +581,17 @@ class _DetailCacheManager(_BaseCacheManager): | |||
| logger.info("Start to execute load data. threads_count: %s.", threads_count) | |||
| with ThreadPoolExecutor(max_workers=threads_count) as executor: | |||
| futures = [] | |||
| loader_pool = self._get_snapshot_loader_pool() | |||
| for loader_id in loader_pool: | |||
| future = executor.submit(self._execute_loader, loader_id, threads_count) | |||
| futures.append(future) | |||
| wait(futures, return_when=ALL_COMPLETED) | |||
| with ComputingResourceManager( | |||
| executors_cnt=threads_count, | |||
| max_processes_cnt=settings.MAX_PROCESSES_COUNT) as computing_resource_mgr: | |||
| with ThreadPoolExecutor(max_workers=threads_count) as executor: | |||
| futures = [] | |||
| loader_pool = self._get_snapshot_loader_pool() | |||
| for loader_id in loader_pool: | |||
| future = executor.submit(self._execute_loader, loader_id, computing_resource_mgr) | |||
| futures.append(future) | |||
| wait(futures, return_when=ALL_COMPLETED) | |||
| def _get_threads_count(self): | |||
| """ | |||
| @@ -19,17 +19,12 @@ This module is used to load the MindSpore training log file. | |||
| Each instance will read an entire run, a run can contain one or | |||
| more log file. | |||
| """ | |||
| import concurrent.futures as futures | |||
| import math | |||
| import os | |||
| import re | |||
| import struct | |||
| import threading | |||
| from google.protobuf.message import DecodeError | |||
| from google.protobuf.text_format import ParseError | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.common import exceptions | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.datavisual.common.log import logger | |||
| @@ -84,14 +79,14 @@ class MSDataLoader: | |||
| "we will reload all files in path %s.", self._summary_dir) | |||
| self.__init__(self._summary_dir) | |||
| def load(self, workers_count=1): | |||
| def load(self, computing_resource_mgr): | |||
| """ | |||
| Load all log valid files. | |||
| When the file is reloaded, it will continue to load from where it left off. | |||
| Args: | |||
| workers_count (int): The count of workers. Default value is 1. | |||
| computing_resource_mgr (ComputingResourceManager): The ComputingResourceManager instance. | |||
| """ | |||
| logger.debug("Start to load data in ms data loader.") | |||
| filenames = self.filter_valid_files() | |||
| @@ -102,8 +97,9 @@ class MSDataLoader: | |||
| self._valid_filenames = filenames | |||
| self._check_files_deleted(filenames, old_filenames) | |||
| for parser in self._parser_list: | |||
| parser.parse_files(workers_count, filenames, events_data=self._events_data) | |||
| with computing_resource_mgr.get_executor() as executor: | |||
| for parser in self._parser_list: | |||
| parser.parse_files(executor, filenames, events_data=self._events_data) | |||
| def filter_valid_files(self): | |||
| """ | |||
| @@ -133,12 +129,12 @@ class _Parser: | |||
| self._latest_mtime = 0 | |||
| self._summary_dir = summary_dir | |||
| def parse_files(self, workers_count, filenames, events_data): | |||
| def parse_files(self, executor, filenames, events_data): | |||
| """ | |||
| Load files and parse files content. | |||
| Args: | |||
| workers_count (int): The count of workers. | |||
| executor (Executor): The executor instance. | |||
| filenames (list[str]): File name list. | |||
| events_data (EventsData): The container of event data. | |||
| """ | |||
| @@ -186,7 +182,7 @@ class _Parser: | |||
| class _PbParser(_Parser): | |||
| """This class is used to parse pb file.""" | |||
| def parse_files(self, workers_count, filenames, events_data): | |||
| def parse_files(self, executor, filenames, events_data): | |||
| pb_filenames = self.filter_files(filenames) | |||
| pb_filenames = self.sort_files(pb_filenames) | |||
| for filename in pb_filenames: | |||
| @@ -264,12 +260,12 @@ class _SummaryParser(_Parser): | |||
| self._summary_file_handler = None | |||
| self._events_data = None | |||
| def parse_files(self, workers_count, filenames, events_data): | |||
| def parse_files(self, executor, filenames, events_data): | |||
| """ | |||
| Load summary file and parse file content. | |||
| Args: | |||
| workers_count (int): The count of workers. | |||
| executor (Executor): The executor instance. | |||
| filenames (list[str]): File name list. | |||
| events_data (EventsData): The container of event data. | |||
| """ | |||
| @@ -295,7 +291,9 @@ class _SummaryParser(_Parser): | |||
| self._latest_file_size = new_size | |||
| try: | |||
| self._load_single_file(self._summary_file_handler, workers_count) | |||
| self._load_single_file(self._summary_file_handler, executor) | |||
| # Wait for data in this file to be processed to avoid loading multiple files at the same time. | |||
| executor.wait_all_tasks_finish() | |||
| except UnknownError as ex: | |||
| logger.warning("Parse summary file failed, detail: %r," | |||
| "file path: %s.", str(ex), file_path) | |||
| @@ -314,75 +312,57 @@ class _SummaryParser(_Parser): | |||
| lambda filename: (re.search(r'summary\.\d+', filename) | |||
| and not filename.endswith("_lineage")), filenames)) | |||
| def _load_single_file(self, file_handler, workers_count): | |||
| def _load_single_file(self, file_handler, executor): | |||
| """ | |||
| Load a log file data. | |||
| Args: | |||
| file_handler (FileHandler): A file handler. | |||
| workers_count (int): The count of workers. | |||
| executor (Executor): The executor instance. | |||
| """ | |||
| default_concurrency = 1 | |||
| cpu_count = os.cpu_count() | |||
| if cpu_count is None: | |||
| concurrency = default_concurrency | |||
| else: | |||
| concurrency = min(math.floor(cpu_count / workers_count), | |||
| math.floor(settings.MAX_PROCESSES_COUNT / workers_count)) | |||
| if concurrency <= 0: | |||
| concurrency = default_concurrency | |||
| logger.debug("Load single summary file, file path: %s, concurrency: %s.", file_handler.file_path, concurrency) | |||
| semaphore = threading.Semaphore(value=concurrency) | |||
| with futures.ProcessPoolExecutor(max_workers=concurrency) as executor: | |||
| while True: | |||
| start_offset = file_handler.offset | |||
| try: | |||
| event_str = self._event_load(file_handler) | |||
| if event_str is None: | |||
| file_handler.reset_offset(start_offset) | |||
| break | |||
| # Make sure we have at most concurrency tasks not finished to save memory. | |||
| semaphore.acquire() | |||
| future = executor.submit(self._event_parse, event_str, self._latest_filename) | |||
| def _add_tensor_event_callback(future_value): | |||
| try: | |||
| tensor_values = future_value.result() | |||
| for tensor_value in tensor_values: | |||
| if tensor_value.plugin_name == PluginNameEnum.GRAPH.value: | |||
| try: | |||
| graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value) | |||
| except KeyError: | |||
| graph_tags = [] | |||
| summary_tags = self.filter_files(graph_tags) | |||
| for tag in summary_tags: | |||
| self._events_data.delete_tensor_event(tag) | |||
| self._events_data.add_tensor_event(tensor_value) | |||
| except Exception as exc: | |||
| # Log exception for debugging. | |||
| logger.exception(exc) | |||
| raise | |||
| finally: | |||
| semaphore.release() | |||
| future.add_done_callback(_add_tensor_event_callback) | |||
| except exceptions.CRCFailedError: | |||
| while True: | |||
| start_offset = file_handler.offset | |||
| try: | |||
| event_str = self._event_load(file_handler) | |||
| if event_str is None: | |||
| file_handler.reset_offset(start_offset) | |||
| logger.warning("Check crc faild and ignore this file, file_path=%s, " | |||
| "offset=%s.", file_handler.file_path, file_handler.offset) | |||
| break | |||
| except (OSError, DecodeError, exceptions.MindInsightException) as ex: | |||
| logger.warning("Parse log file fail, and ignore this file, detail: %r," | |||
| "file path: %s.", str(ex), file_handler.file_path) | |||
| break | |||
| except Exception as ex: | |||
| logger.exception(ex) | |||
| raise UnknownError(str(ex)) | |||
| future = executor.submit(self._event_parse, event_str, self._latest_filename) | |||
| def _add_tensor_event_callback(future_value): | |||
| try: | |||
| tensor_values = future_value.result() | |||
| for tensor_value in tensor_values: | |||
| if tensor_value.plugin_name == PluginNameEnum.GRAPH.value: | |||
| try: | |||
| graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value) | |||
| except KeyError: | |||
| graph_tags = [] | |||
| summary_tags = self.filter_files(graph_tags) | |||
| for tag in summary_tags: | |||
| self._events_data.delete_tensor_event(tag) | |||
| self._events_data.add_tensor_event(tensor_value) | |||
| except Exception as exc: | |||
| # Log exception for debugging. | |||
| logger.exception(exc) | |||
| raise | |||
| future.add_done_callback(_add_tensor_event_callback) | |||
| except exceptions.CRCFailedError: | |||
| file_handler.reset_offset(start_offset) | |||
| logger.warning("Check crc faild and ignore this file, file_path=%s, " | |||
| "offset=%s.", file_handler.file_path, file_handler.offset) | |||
| break | |||
| except (OSError, DecodeError, exceptions.MindInsightException) as ex: | |||
| logger.warning("Parse log file fail, and ignore this file, detail: %r," | |||
| "file path: %s.", str(ex), file_handler.file_path) | |||
| break | |||
| except Exception as ex: | |||
| logger.exception(ex) | |||
| raise UnknownError(str(ex)) | |||
| def _event_load(self, file_handler): | |||
| """ | |||
| @@ -213,7 +213,7 @@ class HistogramReservoir(Reservoir): | |||
| visual_range.update(histogram_container.max, histogram_container.min) | |||
| if visual_range.max == visual_range.min and not max_count: | |||
| logger.info("Max equals to min. Count is zero.") | |||
| logger.debug("Max equals to min. Count is zero.") | |||
| bins = calc_histogram_bins(max_count) | |||
| @@ -0,0 +1,261 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Compute resource manager.""" | |||
| import fractions | |||
| import math | |||
| import threading | |||
| from concurrent import futures | |||
| from mindinsight.utils.log import utils_logger as logger | |||
| from mindinsight.utils.constant import GeneralErrors | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
| class ComputingResourceManager: | |||
| """ | |||
| Manager for computing resources. | |||
| This class provides executors for computing tasks. Executors can only be used once. | |||
| Args: | |||
| executors_cnt (int): Number of executors to be provided by this class. | |||
| max_processes_cnt (int): Max number of processes to be used for computing. | |||
| """ | |||
| def __init__(self, executors_cnt, max_processes_cnt): | |||
| self._max_processes_cnt = max_processes_cnt | |||
| self._executors_cnt = executors_cnt | |||
| self._lock = threading.Lock() | |||
| self._executors = { | |||
| ind: Executor( | |||
| self, executor_id=ind, | |||
| available_workers=fractions.Fraction(self._max_processes_cnt, self._executors_cnt)) | |||
| for ind in range(self._executors_cnt) | |||
| } | |||
| self._remaining_executors = len(self._executors) | |||
| self._backend = futures.ProcessPoolExecutor(max_workers=max_processes_cnt) | |||
| logger.info("Initialized ComputingResourceManager with executors_cnt=%s, max_processes_cnt=%s.", | |||
| executors_cnt, max_processes_cnt) | |||
| def __enter__(self): | |||
| """This method is not thread safe.""" | |||
| return self | |||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||
| """ | |||
| This should not block because every executor have waited. If it blocks, there may be some problem. | |||
| This method is not thread safe. | |||
| """ | |||
| self._backend.shutdown() | |||
| def get_executor(self): | |||
| """ | |||
| Get an executor. | |||
| Returns: | |||
| Executor, which can be used for submitting tasks. | |||
| Raises: | |||
| ComputeResourceManagerException: when no more executor is available. | |||
| """ | |||
| with self._lock: | |||
| self._remaining_executors -= 1 | |||
| if self._remaining_executors < 0: | |||
| raise ComputingResourceManagerException("No more executors.") | |||
| return self._executors[self._remaining_executors] | |||
| def destroy_executor(self, executor_id): | |||
| """ | |||
| Destroy an executor to reuse it's workers. | |||
| Args: | |||
| executor_id (int): Id of the executor to be destroyed. | |||
| """ | |||
| with self._lock: | |||
| released_workers = self._executors[executor_id].available_workers | |||
| self._executors.pop(executor_id) | |||
| remaining_executors = len(self._executors) | |||
| logger.info("Destroy executor %s. Will release %s worker(s). Remaining executors: %s.", | |||
| executor_id, released_workers, remaining_executors) | |||
| if not remaining_executors: | |||
| return | |||
| for executor in self._executors.values(): | |||
| executor.add_worker( | |||
| fractions.Fraction( | |||
| released_workers.numerator, | |||
| released_workers.denominator * remaining_executors)) | |||
| def submit(self, *args, **kwargs): | |||
| """ | |||
| Submit a task. | |||
| See concurrent.futures.Executor.submit() for details. | |||
| This method should only be called by Executor. Users should not call this method directly. | |||
| """ | |||
| with self._lock: | |||
| return self._backend.submit(*args, **kwargs) | |||
| class ComputingResourceManagerException(MindInsightException): | |||
| """ | |||
| Indicates a computing resource error has occurred. | |||
| This exception should not be presented to end users. | |||
| Args: | |||
| msg (str): Exception message. | |||
| """ | |||
| def __init__(self, msg): | |||
| super().__init__(error=GeneralErrors.COMPUTING_RESOURCE_ERROR, message=msg) | |||
| class WrappedFuture: | |||
| """ | |||
| Wrap Future objects with custom logics to release compute slots. | |||
| Args: | |||
| executor (Executor): The executor which generates this future. | |||
| original_future (futures.Future): Original future object. | |||
| """ | |||
| def __init__(self, executor, original_future: futures.Future): | |||
| self._original_future = original_future | |||
| self._executor = executor | |||
| def add_done_callback(self, callback): | |||
| """ | |||
| Add done callback. | |||
| See futures.Future.add_done_callback() for details. | |||
| """ | |||
| def _wrapped_callback(*args, **kwargs): | |||
| logger.debug("Future callback called.") | |||
| try: | |||
| return callback(*args, **kwargs) | |||
| finally: | |||
| self._executor.release_slot() | |||
| self._executor.remove_done_future(self._original_future) | |||
| self._original_future.add_done_callback(_wrapped_callback) | |||
| class Executor: | |||
| """ | |||
| Task executor. | |||
| Args: | |||
| mgr (ComputingResourceManager): The ComputingResourceManager that generates this executor. | |||
| executor_id (int): Executor id. | |||
| available_workers (fractions.Fraction): Available workers. | |||
| """ | |||
| def __init__(self, mgr: ComputingResourceManager, executor_id, available_workers): | |||
| self._mgr = mgr | |||
| self.closed = False | |||
| self._available_workers = available_workers | |||
| self._effective_workers = self._calc_effective_workers(self._available_workers) | |||
| self._slots = threading.Semaphore(value=self._effective_workers) | |||
| self._id = executor_id | |||
| self._futures = set() | |||
| self._lock = threading.Lock() | |||
| logger.debug("Available workers: %s.", available_workers) | |||
| def __enter__(self): | |||
| """This method is not thread safe.""" | |||
| if self.closed: | |||
| raise ComputingResourceManagerException("Can not reopen closed executor.") | |||
| return self | |||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||
| """This method is not thread safe.""" | |||
| self._close() | |||
| def submit(self, *args, **kwargs): | |||
| """ | |||
| Submit task. | |||
| See concurrent.futures.Executor.submit() for details. This method is not thread safe. | |||
| """ | |||
| logger.debug("Task submitted to executor %s.", self._id) | |||
| if self.closed: | |||
| raise ComputingResourceManagerException("Cannot submit task to a closed executor.") | |||
| # Thread will wait on acquire(). | |||
| self._slots.acquire() | |||
| future = self._mgr.submit(*args, **kwargs) | |||
| # set.add is atomic in c-python. | |||
| self._futures.add(future) | |||
| return WrappedFuture(self, future) | |||
| def release_slot(self): | |||
| """ | |||
| Release a slot for new tasks to be submitted. | |||
| Semaphore is itself thread safe, so no lock is needed. | |||
| This method should only be called by ExecutorFuture. | |||
| """ | |||
| self._slots.release() | |||
| def remove_done_future(self, future): | |||
| """ | |||
| Remove done futures so the executor will not track them. | |||
| This method should only be called by WrappedFuture. | |||
| """ | |||
| # set.remove is atomic in c-python so no lock is needed. | |||
| self._futures.remove(future) | |||
| @staticmethod | |||
| def _calc_effective_workers(available_workers): | |||
| return 1 if available_workers <= 1 else math.floor(available_workers) | |||
| def _close(self): | |||
| self.closed = True | |||
| logger.debug("Executor is being closed, futures to wait: %s", self._futures) | |||
| futures.wait(self._futures) | |||
| logger.debug("Executor wait futures completed.") | |||
| self._mgr.destroy_executor(self._id) | |||
| logger.debug("Executor is closed.") | |||
| @property | |||
| def available_workers(self): | |||
| """Get available workers.""" | |||
| with self._lock: | |||
| return self._available_workers | |||
| def add_worker(self, added_available_workers): | |||
| """This method should only be called by ComputeResourceManager.""" | |||
| logger.debug("Add worker: %s", added_available_workers) | |||
| with self._lock: | |||
| self._available_workers += added_available_workers | |||
| new_effective_workers = self._calc_effective_workers(self._available_workers) | |||
| if new_effective_workers > self._effective_workers: | |||
| for _ in range(new_effective_workers - self._effective_workers): | |||
| self._slots.release() | |||
| self._effective_workers = new_effective_workers | |||
| def wait_all_tasks_finish(self): | |||
| """ | |||
| Wait all tasks finish. | |||
| This method is not thread safe. | |||
| """ | |||
| futures.wait(self._futures) | |||
| @@ -43,6 +43,7 @@ class GeneralErrors(Enum): | |||
| FILE_SYSTEM_PERMISSION_ERROR = 8 | |||
| PORT_NOT_AVAILABLE_ERROR = 9 | |||
| URL_DECODE_ERROR = 10 | |||
| COMPUTING_RESOURCE_ERROR = 11 | |||
| class ProfilerMgrErrors(Enum): | |||
| @@ -224,3 +224,6 @@ def setup_logger(sub_module, log_name, **kwargs): | |||
| logger.addHandler(logfile_handler) | |||
| return logger | |||
| utils_logger = setup_logger("utils", "utils") | |||
| @@ -27,6 +27,7 @@ import pytest | |||
| from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid | |||
| from mindinsight.datavisual.data_transform import data_loader | |||
| from mindinsight.datavisual.data_transform.data_loader import DataLoader | |||
| from mindinsight.utils.computing_resource_mgr import ComputingResourceManager | |||
| from ..mock import MockLogger | |||
| @@ -57,7 +58,7 @@ class TestDataLoader: | |||
| """Test loading method with empty file list.""" | |||
| loader = DataLoader(self._summary_dir) | |||
| with pytest.raises(SummaryLogPathInvalid): | |||
| loader.load() | |||
| loader.load(ComputingResourceManager(1, 1)) | |||
| assert 'No valid files can be loaded' in str(MockLogger.log_msg['warning']) | |||
| def test_load_with_invalid_file_list(self): | |||
| @@ -66,7 +67,7 @@ class TestDataLoader: | |||
| self._generate_files(self._summary_dir, file_list) | |||
| loader = DataLoader(self._summary_dir) | |||
| with pytest.raises(SummaryLogPathInvalid): | |||
| loader.load() | |||
| loader.load(ComputingResourceManager(1, 1)) | |||
| assert 'No valid files can be loaded' in str(MockLogger.log_msg['warning']) | |||
| def test_load_success(self): | |||
| @@ -77,6 +78,6 @@ class TestDataLoader: | |||
| file_list = ['summary.001', 'summary.002'] | |||
| self._generate_files(dir_path, file_list) | |||
| dataloader = DataLoader(dir_path) | |||
| dataloader.load() | |||
| dataloader.load(ComputingResourceManager(1, 1)) | |||
| assert dataloader._loader is not None | |||
| shutil.rmtree(dir_path) | |||
| @@ -30,6 +30,7 @@ from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader | |||
| from mindinsight.datavisual.data_transform.ms_data_loader import _PbParser | |||
| from mindinsight.datavisual.data_transform.events_data import TensorEvent | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.utils.computing_resource_mgr import ComputingResourceManager | |||
| from ..mock import MockLogger | |||
| from ....utils.log_generators.graph_pb_generator import create_graph_pb_file | |||
| @@ -85,7 +86,7 @@ class TestMsDataLoader: | |||
| write_file(file1, SCALAR_RECORD) | |||
| ms_loader = MSDataLoader(summary_dir) | |||
| ms_loader._latest_summary_filename = 'summary.00' | |||
| ms_loader.load() | |||
| ms_loader.load(ComputingResourceManager(1, 1)) | |||
| shutil.rmtree(summary_dir) | |||
| tag = ms_loader.get_events_data().list_tags_by_plugin('scalar') | |||
| tensors = ms_loader.get_events_data().tensors(tag[0]) | |||
| @@ -98,7 +99,7 @@ class TestMsDataLoader: | |||
| file2 = os.path.join(summary_dir, 'summary.02') | |||
| write_file(file2, SCALAR_RECORD) | |||
| ms_loader = MSDataLoader(summary_dir) | |||
| ms_loader.load() | |||
| ms_loader.load(ComputingResourceManager(1, 1)) | |||
| shutil.rmtree(summary_dir) | |||
| assert 'Check crc faild and ignore this file' in str(MockLogger.log_msg['warning']) | |||
| @@ -124,7 +125,7 @@ class TestMsDataLoader: | |||
| summary_dir = tempfile.mkdtemp() | |||
| create_graph_pb_file(output_dir=summary_dir, filename=filename) | |||
| ms_loader = MSDataLoader(summary_dir) | |||
| ms_loader.load() | |||
| ms_loader.load(ComputingResourceManager(1, 1)) | |||
| events_data = ms_loader.get_events_data() | |||
| plugins = events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value) | |||
| shutil.rmtree(summary_dir) | |||