From 225bf22efb52ec372d07854dcee54315c0f0e85c Mon Sep 17 00:00:00 2001 From: YuhanShi53 Date: Sun, 25 Oct 2020 16:07:10 +0800 Subject: [PATCH] add ExplainJob, ExplainManager and EventParser for XAI backend --- mindinsight/explainer/manager/event_parse.py | 180 ++++++++ mindinsight/explainer/manager/explain_job.py | 395 ++++++++++++++++++ .../explainer/manager/explain_manager.py | 314 ++++++++++++++ .../explainer/manager/explain_parser.py | 33 +- 4 files changed, 909 insertions(+), 13 deletions(-) create mode 100644 mindinsight/explainer/manager/event_parse.py create mode 100644 mindinsight/explainer/manager/explain_job.py create mode 100644 mindinsight/explainer/manager/explain_manager.py diff --git a/mindinsight/explainer/manager/event_parse.py b/mindinsight/explainer/manager/event_parse.py new file mode 100644 index 00000000..6d0eeade --- /dev/null +++ b/mindinsight/explainer/manager/event_parse.py @@ -0,0 +1,180 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""EventParser for summary event.""" +from collections import namedtuple +from typing import Dict, Iterable, List, Optional, Tuple + +from mindinsight.explainer.common.enums import PluginNameEnum +from mindinsight.explainer.common.log import logger +from mindinsight.utils.exceptions import UnknownError + +_IMAGE_DATA_TAGS = { + 'image_data': PluginNameEnum.IMAGE_DATA.value, + 'ground_truth_label': PluginNameEnum.GROUND_TRUTH_LABEL.value, + 'inference': PluginNameEnum.INFERENCE.value, + 'explanation': PluginNameEnum.EXPLANATION.value +} + + +class EventParser: + """Parser for event data.""" + + def __init__(self, job): + self._job = job + self._sample_pool = {} + + def clear(self): + """Clear the loaded data.""" + self._sample_pool.clear() + + def parse_metadata(self, metadata) -> Tuple[List, List, List]: + """Parse the metadata event.""" + explainers = list(metadata.explain_method) + metrics = list(metadata.benchmark_method) + labels = list(metadata.label) + return explainers, metrics, labels + + def parse_benchmark(self, benchmark) -> Dict: + """Parse the benchmark event.""" + imported_benchmark = {} + for explainer_result in benchmark: + explainer = explainer_result.explain_method + total_score = explainer_result.total_score + label_score = explainer_result.label_score + + explainer_benchmark = { + 'explainer': explainer, + 'evaluations': EventParser._total_score_to_dict(total_score), + 'class_scores': EventParser._label_score_to_dict(label_score, self._job.labels) + } + imported_benchmark[explainer] = explainer_benchmark + return imported_benchmark + + def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]: + """Parse the sample event.""" + sample_id = sample.image_id + + if sample_id not in self._sample_pool: + self._sample_pool[sample_id] = sample + return None + + for tag in _IMAGE_DATA_TAGS: + try: + if tag == PluginNameEnum.INFERENCE.value: + self._parse_inference(sample, sample_id) + elif tag == PluginNameEnum.EXPLANATION.value: + self._parse_explanation(sample, sample_id) + else: + self._parse_sample_info(sample, sample_id, tag) + except UnknownError as ex: + logger.warning("Parse %s data failed within image related data," + " detail: %r", tag, str(ex)) + continue + + if EventParser._is_sample_data_complete(self._sample_pool[sample_id]): + return self._sample_pool.pop(sample_id) + if EventParser._is_ready_for_display(self._sample_pool[sample_id]): + return self._sample_pool[sample_id] + return None + + def _parse_inference(self, event, sample_id): + """Parse the inference event.""" + self._sample_pool[sample_id].inference.ground_truth_prob.extend( + event.inference.ground_truth_prob) + self._sample_pool[sample_id].inference.predicted_label.extend( + event.inference.predicted_label) + self._sample_pool[sample_id].inference.predicted_prob.extend( + event.inference.predicted_prob) + + def _parse_explanation(self, event, sample_id): + """Parse the explanation event.""" + if event.explanation: + for explanation_item in event.explanation: + new_explanation = self._sample_pool[sample_id].explanation.add() + new_explanation.explain_method = explanation_item.explain_method + new_explanation.label = explanation_item.label + new_explanation.heatmap = explanation_item.heatmap + + def _parse_sample_info(self, event, sample_id, tag): + """Parse the event containing image info.""" + if not getattr(self._sample_pool[sample_id], tag): + setattr(self._sample_pool[sample_id], tag, getattr(event, tag)) + + @staticmethod + def _total_score_to_dict(total_scores: Iterable): + """Transfer a list of benchmark score to a list of dict.""" + evaluation_info = [] + for total_score in total_scores: + metric_result = { + 'metric': total_score.benchmark_method, + 'score': total_score.score} + evaluation_info.append(metric_result) + return evaluation_info + + @staticmethod + def _label_score_to_dict(label_scores: Iterable, labels: List[str]): + """Transfer a list of benchmark score.""" + evaluation_info = [{'label': label, 'evaluations': []} + for label in labels] + for label_score in label_scores: + metric = label_score.benchmark_method + for i, score in enumerate(label_score.score): + label_metric_score = { + 'metric': metric, + 'score': score} + evaluation_info[i]['evaluations'].append(label_metric_score) + return evaluation_info + + @staticmethod + def _is_sample_data_complete(image_container: namedtuple) -> bool: + """Check whether sample data completely loaded.""" + required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference', 'explanation'] + for attr in required_attrs: + if not EventParser.is_attr_ready(image_container, attr): + return False + return True + + @staticmethod + def _is_ready_for_display(image_container: namedtuple) -> bool: + """ + Check whether the image_container is ready for frontend display. + + Args: + image_container (nametuple): container consists of sample data + + Return: + bool: whether the image_container if ready for display + """ + required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference'] + for attr in required_attrs: + if not EventParser.is_attr_ready(image_container, attr): + return False + return True + + @staticmethod + def is_attr_ready(image_container: namedtuple, attr: str) -> bool: + """ + Check whether the given attribute is ready in image_container. + + Args: + image_container (nametuple): container consist of sample data + attr (str): attribute to check + + Returns: + bool, whether the attr is ready + """ + if getattr(image_container, attr, False): + return True + return False diff --git a/mindinsight/explainer/manager/explain_job.py b/mindinsight/explainer/manager/explain_job.py new file mode 100644 index 00000000..5583628e --- /dev/null +++ b/mindinsight/explainer/manager/explain_job.py @@ -0,0 +1,395 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ExplainJob.""" + +import os +from datetime import datetime +from typing import List, Iterable, Union + +from mindinsight.explainer.common.enums import PluginNameEnum +from mindinsight.explainer.common.log import logger +from mindinsight.explainer.manager.explain_parser import _ExplainParser +from mindinsight.explainer.manager.event_parse import EventParser +from mindinsight.datavisual.data_access.file_handler import FileHandler +from mindinsight.datavisual.common.exceptions import TrainJobNotExistError + + +class ExplainJob: + """ExplainJob which manage the record in the summary file.""" + + def __init__(self, + job_id: str, + summary_dir: str, + create_time: float, + latest_update_time: float): + + self._job_id = job_id + self._summary_dir = summary_dir + self._parser = _ExplainParser(summary_dir) + + self._event_parser = EventParser(self) + self._latest_update_time = latest_update_time + self._create_time = create_time + self._labels = [] + self._metrics = [] + self._explainers = [] + self._samples_info = {} + self._labels_info = {} + self._benchmark = {} + self._overlay_dict = {} + self._image_dict = {} + + @property + def all_classes(self): + """ + Return a list of label info + + Returns: + class_objs (List[ClassObj]): a list of class_objects, each object + contains: + + - id (int): label id + - label (str): label name + - sample_count (int): number of samples for each label + """ + all_classes_return = [] + for label_id, label_info in self._labels_info.items(): + single_info = {'id': label_id, + 'label': label_info['label'], + 'sample_count': len(label_info['sample_ids'])} + all_classes_return.append(single_info) + return all_classes_return + + @property + def explainers(self): + """ + Return a list of explainer names + + Returns: + list(str), explainer names + """ + return self._explainers + + @property + def explainer_scores(self): + """Return evaluation results for every explainer.""" + return [score for score in self._benchmark.values()] + + @property + def sample_count(self): + """ + Return total number of samples in the job. + + Return: + int, total number of samples + + """ + return len(self._samples_info) + + @property + def train_id(self): + """ + Return ID of explain job + + Returns: + str, id of ExplainJob object + """ + return self._job_id + + @property + def metrics(self): + """ + Return a list of metric names + + Returns: + list(str), metric names + """ + return self._metrics + + @property + def min_confidence(self): + """ + Return minimum confidence + + Returns: + min_confidence (float): + """ + return None + + @property + def create_time(self): + """ + Return the create time of summary file + + Returns: + creation timestamp (float) + + """ + return self._create_time + + @property + def labels(self): + """Return the label contained in the job.""" + return self._labels + + @property + def latest_update_time(self): + """ + Return last modification time stamp of summary file. + + Returns: + float, last_modification_time stamp + """ + return self._latest_update_time + + @latest_update_time.setter + def latest_update_time(self, new_time: Union[float, datetime]): + """ + Update the latest_update_time timestamp manually. + + Args: + new_time stamp (union[float, datetime]): updated time for the job + """ + if isinstance(new_time, datetime): + self._latest_update_time = new_time.timestamp() + elif isinstance(new_time, str): + self._latest_update_time = new_time + else: + raise TypeError('new_time should have type of str or datetime') + + @property + def loader_id(self): + """Return the job id.""" + return self._job_id + + @property + def samples(self): + """Return the information of all samples in the job.""" + return self._samples_info + + @staticmethod + def get_create_time(file_path: str) -> float: + """Return timestamp of create time of specific path.""" + create_time = os.stat(file_path).st_ctime + return create_time + + @staticmethod + def get_update_time(file_path: str) -> float: + """Return timestamp of update time of specific path.""" + update_time = os.stat(file_path).st_mtime + return update_time + + @staticmethod + def _total_score_to_dict(total_scores: Iterable): + """Transfer a list of benchmark score to a list of dict.""" + evaluation_info = [] + for total_score in total_scores: + metric_result = {'metric': total_score.benchmark_method, + 'score': total_score.score} + evaluation_info.append(metric_result) + return evaluation_info + + @staticmethod + def _label_score_to_dict(label_scores: Iterable, labels: List[str]): + """Transfer a list of benchmark score.""" + evaluation_info = [{'label': label, 'evaluations': []} + for label in labels] + for label_score in label_scores: + metric = label_score.benchmark_method + for i, score in enumerate(label_score.score): + label_metric_score = dict() + label_metric_score['metric'] = metric + label_metric_score['score'] = score + evaluation_info[i]['evaluations'].append(label_metric_score) + return evaluation_info + + def _initialize_labels_info(self): + """Initialize a dict for labels in the job.""" + if self._labels is None: + logger.warning('No labels is provided in job %s', self._job_id) + return + + for label_id, label in enumerate(self._labels): + self._labels_info[label_id] = {'label': label, + 'sample_ids': set()} + + def _explanation_to_dict(self, explanation, sample_id): + """Transfer the explanation from event to dict storage.""" + explainer_name = explanation.explain_method + explain_label = explanation.label + saliency = explanation.heatmap + saliency_id = '{}_{}_{}'.format( + sample_id, explain_label, explainer_name) + explain_info = { + 'explainer': explainer_name, + 'overlay': saliency_id, + } + self._overlay_dict[saliency_id] = saliency + return explain_info + + def _image_container_to_dict(self, sample_data): + """Transfer the image container to dict storage.""" + sample_id = sample_data.image_id + + sample_info = { + 'id': sample_id, + 'name': sample_id, + 'labels': [self._labels_info[x]['label'] + for x in sample_data.ground_truth_label], + 'inferences': []} + self._image_dict[sample_id] = sample_data.image_data + + ground_truth_labels = list(sample_data.ground_truth_label) + ground_truth_probs = list(sample_data.inference.ground_truth_prob) + predicted_labels = list(sample_data.inference.predicted_label) + predicted_probs = list(sample_data.inference.predicted_prob) + + inference_info = {} + for label, prob in zip( + ground_truth_labels + predicted_labels, + ground_truth_probs + predicted_probs): + inference_info[label] = { + 'label': self._labels_info[label]['label'], + 'confidence': prob, + 'saliency_maps': []} + + if EventParser.is_attr_ready(sample_data, 'explanation'): + for explanation in sample_data.explanation: + explanation_dict = self._explanation_to_dict( + explanation, sample_id) + inference_info[explanation.label]['saliency_maps'].append(explanation_dict) + + sample_info['inferences'] = list(inference_info.values()) + return sample_info + + def _import_sample(self, sample): + """Add sample object of given sample id.""" + for label_id in sample.ground_truth_label: + self._labels_info[label_id]['sample_ids'].add(sample.image_id) + + sample_info = self._image_container_to_dict(sample) + self._samples_info.update({sample_info['id']: sample_info}) + + def retrieve_image(self, image_id: str): + """ + Retrieve image data from the job given image_id. + + Return: + string, image data in base64 byte + + """ + return self._image_dict.get(image_id, None) + + def retrieve_overlay(self, overlay_id: str): + """ + Retrieve sample map from the job given overlay_id. + + Return: + string, saliency_map data in base64 byte + """ + return self._overlay_dict.get(overlay_id, None) + + def get_all_samples(self): + """ + Return a list of sample information cachced in the explain job + + Returns: + sample_list (List[SampleObj]): a list of sample objects, each object + consists of: + + - id (int): sample id + - name (str): basename of image + - labels (list[str]): list of labels + - inferences list[dict]) + """ + samples_in_list = list(self._samples_info.values()) + return samples_in_list + + def _is_metadata_empty(self): + """Check whether metadata is loaded first.""" + if not self._explainers or not self._metrics or not self._labels: + return True + return False + + def _import_data_from_event(self, event): + """Parse and import data from the event data.""" + tags = { + 'image_id': PluginNameEnum.IMAGE_ID, + 'benchmark': PluginNameEnum.BENCHMARK, + 'metadata': PluginNameEnum.METADATA + } + + if 'metadata' not in event and self._is_metadata_empty(): + raise ValueError('metadata is empty, should write metadata first' + 'in the summary.') + for tag in tags: + if tag not in event: + continue + + if tag == PluginNameEnum.IMAGE_ID.value: + sample_event = event[tag] + sample_data = self._event_parser.parse_sample(sample_event) + if sample_data is not None: + self._import_sample(sample_data) + continue + + if tag == PluginNameEnum.BENCHMARK.value: + benchmark_event = event[tag].benchmark + benchmark = self._event_parser.parse_benchmark(benchmark_event) + self._benchmark = benchmark + + elif tag == PluginNameEnum.METADATA.value: + metadata_event = event[tag].metadata + metadata = self._event_parser.parse_metadata(metadata_event) + self._explainers, self._metrics, self._labels = metadata + self._initialize_labels_info() + + def load(self): + """ + Start loading data from parser. + """ + valid_file_names = [] + for filename in FileHandler.list_dir(self._summary_dir): + if FileHandler.is_file( + FileHandler.join(self._summary_dir, filename)): + valid_file_names.append(filename) + + if not valid_file_names: + raise TrainJobNotExistError('No summary file found in %s, explain job will be delete.' % self._summary_dir) + + is_end = False + while not is_end: + is_clean, is_end, event = self._parser.parse_explain(valid_file_names) + + if is_clean: + logger.info('Summary file in %s update, reload the clean the loaded data.', self._summary_dir) + self._clean_job() + + if event: + self._import_data_from_event(event) + + def _clean_job(self): + """Clean the cached data in job.""" + self._latest_update_time = ExplainJob.get_update_time(self._summary_dir) + self._create_time = ExplainJob.get_update_time(self._summary_dir) + self._labels.clear() + self._metrics.clear() + self._explainers.clear() + self._samples_info.clear() + self._labels_info.clear() + self._benchmark.clear() + self._overlay_dict.clear() + self._image_dict.clear() + self._event_parser.clear() diff --git a/mindinsight/explainer/manager/explain_manager.py b/mindinsight/explainer/manager/explain_manager.py new file mode 100644 index 00000000..24477848 --- /dev/null +++ b/mindinsight/explainer/manager/explain_manager.py @@ -0,0 +1,314 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ExplainManager.""" + +import os +import threading +import time + +from mindinsight.datavisual.common import exceptions +from mindinsight.datavisual.common.enums import BaseEnum +from mindinsight.explainer.common.log import logger +from mindinsight.explainer.manager.explain_job import ExplainJob +from mindinsight.datavisual.data_access.file_handler import FileHandler +from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher +from mindinsight.utils.exceptions import MindInsightException, ParamValueError + +_MAX_LOADER_NUM = 3 +_MAX_INTERVAL = 3 + + +class _ExplainManagerStatus(BaseEnum): + """Manager status.""" + INIT = 'INIT' + LOADING = 'LOADING' + DONE = 'DONE' + INVALID = 'INVALID' + + +class ExplainManager: + """ExplainManager.""" + + def __init__(self, summary_base_dir: str): + self._summary_base_dir = summary_base_dir + self._loader_pool = {} + self._deleted_ids = [] + self._status = _ExplainManagerStatus.INIT.value + self._status_mutex = threading.Lock() + self._loader_pool_mutex = threading.Lock() + self._max_loader_num = _MAX_LOADER_NUM + self._reload_interval = None + + def _reload_data(self): + """periodically load summary from file.""" + while True: + self._load_data() + + if not self._reload_interval: + break + time.sleep(self._reload_interval) + + def _load_data(self): + """Loading the summary in the given base directory.""" + logger.info( + 'Start to load data, reload interval: %r.', self._reload_interval) + + with self._status_mutex: + if self._status == _ExplainManagerStatus.LOADING.value: + logger.info('Current status is %s, will ignore to load data.', + self._status) + return + + self._status = _ExplainManagerStatus.LOADING.value + + self._generate_loaders() + self._execute_load_data() + + if not self._loader_pool: + self._status = _ExplainManagerStatus.INVALID.value + else: + self._status = _ExplainManagerStatus.DONE.value + + logger.info('Load event data end, status: %r, ' + 'and loader pool size is %r', + self._status, len(self._loader_pool)) + + def _update_loader_latest_update_time(self, loader_id, latest_update_time=None): + """update the update time of loader of given id.""" + if latest_update_time is None: + latest_update_time = time.time() + self._loader_pool[loader_id].latest_update_time = latest_update_time + + def _delete_loader(self, loader_id): + """delete loader given loader_id""" + if self._loader_pool.get(loader_id, None) is not None: + self._loader_pool.pop(loader_id) + logger.debug('delete loader %s', loader_id) + + def _add_loader(self, loader): + """add loader to the loader_pool.""" + if len(self._loader_pool) >= _MAX_LOADER_NUM: + delete_num = len(self._loader_pool) - _MAX_LOADER_NUM + 1 + sorted_loaders = sorted( + self._loader_pool.items(), + key=lambda x: x[1].latest_update_time) + + for index in range(delete_num): + delete_loader_id = sorted_loaders[index][0] + self._delete_loader(delete_loader_id) + self._loader_pool.update({loader.loader_id: loader}) + + def _deal_loaders(self, latest_loaders): + """"update the loader pool.""" + with self._loader_pool_mutex: + for loader_id, loader in latest_loaders: + if self._loader_pool.get(loader_id, None) is None: + self._add_loader(loader) + continue + + if (self._loader_pool[loader_id].latest_update_time + < loader.latest_update_time): + self._update_loader_latest_update_time( + loader_id, loader.latest_update_time) + + @staticmethod + def _generate_loader_id(relative_path): + """Generate loader id for given path""" + loader_id = relative_path + return loader_id + + @staticmethod + def _generate_loader_name(relative_path): + """Generate_loader name for given path.""" + loader_name = relative_path + return loader_name + + def _generate_loader_by_relative_path(self, relative_path: str) -> ExplainJob: + """Generate explain job from given relative path.""" + current_dir = os.path.realpath(FileHandler.join( + self._summary_base_dir, relative_path + )) + loader_id = self._generate_loader_id(relative_path) + loader = ExplainJob( + job_id=loader_id, + summary_dir=current_dir, + create_time=ExplainJob.get_create_time(current_dir), + latest_update_time=ExplainJob.get_update_time(current_dir)) + return loader + + def _generate_loaders(self): + """Generate job loaders from the summary watcher.""" + dir_map_mtime_dict = {} + loader_dict = {} + min_modify_time = None + _, summaries = SummaryWatcher().list_explain_directories( + self._summary_base_dir) + + for item in summaries: + relative_path = item.get('relative_path') + modify_time = item.get('update_time').timestamp() + loader_id = self._generate_loader_id(relative_path) + + loader = self._loader_pool.get(loader_id, None) + if loader is not None and loader.latest_update_time > modify_time: + modify_time = loader.latest_update_time + + if min_modify_time is None: + min_modify_time = modify_time + + if len(dir_map_mtime_dict) < _MAX_LOADER_NUM: + if modify_time < min_modify_time: + min_modify_time = modify_time + dir_map_mtime_dict.update({relative_path: modify_time}) + else: + if modify_time >= min_modify_time: + dir_map_mtime_dict.update({relative_path: modify_time}) + + sorted_dir_tuple = sorted(dir_map_mtime_dict.items(), + key=lambda d: d[1])[-_MAX_LOADER_NUM:] + + for relative_path, modify_time in sorted_dir_tuple: + loader_id = self._generate_loader_id(relative_path) + loader = self._generate_loader_by_relative_path(relative_path) + loader_dict.update({loader_id: loader}) + + sorted_loaders = sorted(loader_dict.items(), + key=lambda x: x[1].latest_update_time) + latest_loaders = sorted_loaders[-_MAX_LOADER_NUM:] + self._deal_loaders(latest_loaders) + + def _execute_loader(self, loader_id): + """Execute the data loading.""" + try: + with self._loader_pool_mutex: + loader = self._loader_pool.get(loader_id, None) + if loader is None: + logger.debug('Loader %r has been deleted, will not load' + 'data', loader_id) + return + loader.load() + + except MindInsightException as e: + logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, e) + with self._loader_pool_mutex: + self._delete_loader(loader_id) + + def _execute_load_data(self): + """Execute the loader in the pool to load data.""" + loader_pool = self._get_snapshot_loader_pool() + for loader_id in loader_pool: + self._execute_loader(loader_id) + + def _get_snapshot_loader_pool(self): + """Get snapshot of loader_pool.""" + with self._loader_pool_mutex: + return dict(self._loader_pool) + + def _check_status_valid(self): + """Check manager status.""" + if self._status == _ExplainManagerStatus.INIT.value: + raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._status) + + @staticmethod + def _check_train_id_valid(train_id: str): + """Verify the train_id is valid.""" + if not train_id.startswith('./'): + logger.warning('train_id does not start with "./"') + return False + + if len(train_id.split('/')) > 2: + logger.warning('train_id contains multiple "/"') + return False + return True + + def _check_train_job_exist(self, train_id): + """Verify thee train_job is existed given train_id.""" + if train_id in self._loader_pool: + return + self._check_train_id_valid(train_id) + if SummaryWatcher().is_summary_directory(self._summary_base_dir, train_id): + return + raise ParamValueError('Can not find the train job in the manager, train_id: %s' % train_id) + + def _reload_data_again(self): + """Reload the data one more time.""" + logger.debug('Start to reload data again.') + thread = threading.Thread(target=self._load_data, + name='reload_data_thread') + thread.daemon = False + thread.start() + + def _get_job(self, train_id): + """Retrieve train_job given train_id.""" + is_reload = False + with self._loader_pool_mutex: + loader = self._loader_pool.get(train_id, None) + + if loader is None: + relative_path = train_id + temp_loader = self._generate_loader_by_relative_path( + relative_path) + + if temp_loader is None: + return None + + self._add_loader(temp_loader) + is_reload = True + + if is_reload: + self._reload_data_again() + return loader + + @property + def summary_base_dir(self): + """Return the base directory for summary records.""" + return self._summary_base_dir + + def get_job(self, train_id): + """ + Return ExplainJob given train_id. + + If explain job w.r.t given train_id is not found, None will be returned. + + Args: + train_id (str): The id of expected ExplainJob + + Return: + explain_job + """ + self._check_status_valid() + self._check_train_job_exist(train_id) + + loader = self._get_job(train_id) + if loader is None: + return None + return loader + + def start_load_data(self, + reload_interval=_MAX_INTERVAL): + """ + Start threads for loading data. + + Args: + reload_interval (int): interval to reload the summary from file + """ + self._reload_interval = reload_interval + + thread = threading.Thread(target=self._reload_data, name='start_load_data_thread') + thread.daemon = True + thread.start() + + # wait for data loading + time.sleep(1) diff --git a/mindinsight/explainer/manager/explain_parser.py b/mindinsight/explainer/manager/explain_parser.py index b08c0883..93d94b64 100644 --- a/mindinsight/explainer/manager/explain_parser.py +++ b/mindinsight/explainer/manager/explain_parser.py @@ -28,20 +28,36 @@ from mindinsight.explainer.common.log import logger from mindinsight.datavisual.data_access.file_handler import FileHandler from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 +from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Explain from mindinsight.utils.exceptions import UnknownError HEADER_SIZE = 8 CRC_STR_SIZE = 4 MAX_EVENT_STRING = 500000000 -ImageDataContainer = collections.namedtuple('ImageDataContainer', - ['image_id', 'image_data', 'ground_truth_label', - 'inference', 'explanation', 'status']) BenchmarkContainer = collections.namedtuple('BenchmarkContainer', ['benchmark', 'status']) MetadataContainer = collections.namedtuple('MetadataContainer', ['metadata', 'status']) +class ImageDataContainer: + """ + Container for image data to allow pickling. + + Args: + explain_message (Explain): Explain proto buffer message. + """ + + def __init__(self, explain_message: Explain): + self.image_id = explain_message.image_id + self.image_data = explain_message.image_data + self.ground_truth_label = explain_message.ground_truth_label + self.inference = explain_message.inference + self.explanation = explain_message.explanation + self.status = explain_message.status + + class _ExplainParser(_SummaryParser): """The summary file parser.""" + def __init__(self, summary_dir): super(_ExplainParser, self).__init__(summary_dir) self._latest_filename = '' @@ -165,7 +181,6 @@ class _ExplainParser(_SummaryParser): tensor_value_list.append(tensor_value) return field_list, tensor_value_list - @staticmethod def _add_image_data(tensor_event_value): """ @@ -174,17 +189,9 @@ class _ExplainParser(_SummaryParser): Args: tensor_event_value: the object of Explain message """ - image_data = ImageDataContainer( - image_id=tensor_event_value.image_id, - image_data=tensor_event_value.image_data, - ground_truth_label=tensor_event_value.ground_truth_label, - inference=tensor_event_value.inference, - explanation=tensor_event_value.explanation, - status=tensor_event_value.status - ) + image_data = ImageDataContainer(tensor_event_value) return image_data - @staticmethod def _add_benchmark(tensor_event_value): """