| @@ -0,0 +1,180 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """EventParser for summary event.""" | |||
| from collections import namedtuple | |||
| from typing import Dict, Iterable, List, Optional, Tuple | |||
| from mindinsight.explainer.common.enums import PluginNameEnum | |||
| from mindinsight.explainer.common.log import logger | |||
| from mindinsight.utils.exceptions import UnknownError | |||
| _IMAGE_DATA_TAGS = { | |||
| 'image_data': PluginNameEnum.IMAGE_DATA.value, | |||
| 'ground_truth_label': PluginNameEnum.GROUND_TRUTH_LABEL.value, | |||
| 'inference': PluginNameEnum.INFERENCE.value, | |||
| 'explanation': PluginNameEnum.EXPLANATION.value | |||
| } | |||
| class EventParser: | |||
| """Parser for event data.""" | |||
| def __init__(self, job): | |||
| self._job = job | |||
| self._sample_pool = {} | |||
| def clear(self): | |||
| """Clear the loaded data.""" | |||
| self._sample_pool.clear() | |||
| def parse_metadata(self, metadata) -> Tuple[List, List, List]: | |||
| """Parse the metadata event.""" | |||
| explainers = list(metadata.explain_method) | |||
| metrics = list(metadata.benchmark_method) | |||
| labels = list(metadata.label) | |||
| return explainers, metrics, labels | |||
| def parse_benchmark(self, benchmark) -> Dict: | |||
| """Parse the benchmark event.""" | |||
| imported_benchmark = {} | |||
| for explainer_result in benchmark: | |||
| explainer = explainer_result.explain_method | |||
| total_score = explainer_result.total_score | |||
| label_score = explainer_result.label_score | |||
| explainer_benchmark = { | |||
| 'explainer': explainer, | |||
| 'evaluations': EventParser._total_score_to_dict(total_score), | |||
| 'class_scores': EventParser._label_score_to_dict(label_score, self._job.labels) | |||
| } | |||
| imported_benchmark[explainer] = explainer_benchmark | |||
| return imported_benchmark | |||
| def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]: | |||
| """Parse the sample event.""" | |||
| sample_id = sample.image_id | |||
| if sample_id not in self._sample_pool: | |||
| self._sample_pool[sample_id] = sample | |||
| return None | |||
| for tag in _IMAGE_DATA_TAGS: | |||
| try: | |||
| if tag == PluginNameEnum.INFERENCE.value: | |||
| self._parse_inference(sample, sample_id) | |||
| elif tag == PluginNameEnum.EXPLANATION.value: | |||
| self._parse_explanation(sample, sample_id) | |||
| else: | |||
| self._parse_sample_info(sample, sample_id, tag) | |||
| except UnknownError as ex: | |||
| logger.warning("Parse %s data failed within image related data," | |||
| " detail: %r", tag, str(ex)) | |||
| continue | |||
| if EventParser._is_sample_data_complete(self._sample_pool[sample_id]): | |||
| return self._sample_pool.pop(sample_id) | |||
| if EventParser._is_ready_for_display(self._sample_pool[sample_id]): | |||
| return self._sample_pool[sample_id] | |||
| return None | |||
| def _parse_inference(self, event, sample_id): | |||
| """Parse the inference event.""" | |||
| self._sample_pool[sample_id].inference.ground_truth_prob.extend( | |||
| event.inference.ground_truth_prob) | |||
| self._sample_pool[sample_id].inference.predicted_label.extend( | |||
| event.inference.predicted_label) | |||
| self._sample_pool[sample_id].inference.predicted_prob.extend( | |||
| event.inference.predicted_prob) | |||
| def _parse_explanation(self, event, sample_id): | |||
| """Parse the explanation event.""" | |||
| if event.explanation: | |||
| for explanation_item in event.explanation: | |||
| new_explanation = self._sample_pool[sample_id].explanation.add() | |||
| new_explanation.explain_method = explanation_item.explain_method | |||
| new_explanation.label = explanation_item.label | |||
| new_explanation.heatmap = explanation_item.heatmap | |||
| def _parse_sample_info(self, event, sample_id, tag): | |||
| """Parse the event containing image info.""" | |||
| if not getattr(self._sample_pool[sample_id], tag): | |||
| setattr(self._sample_pool[sample_id], tag, getattr(event, tag)) | |||
| @staticmethod | |||
| def _total_score_to_dict(total_scores: Iterable): | |||
| """Transfer a list of benchmark score to a list of dict.""" | |||
| evaluation_info = [] | |||
| for total_score in total_scores: | |||
| metric_result = { | |||
| 'metric': total_score.benchmark_method, | |||
| 'score': total_score.score} | |||
| evaluation_info.append(metric_result) | |||
| return evaluation_info | |||
| @staticmethod | |||
| def _label_score_to_dict(label_scores: Iterable, labels: List[str]): | |||
| """Transfer a list of benchmark score.""" | |||
| evaluation_info = [{'label': label, 'evaluations': []} | |||
| for label in labels] | |||
| for label_score in label_scores: | |||
| metric = label_score.benchmark_method | |||
| for i, score in enumerate(label_score.score): | |||
| label_metric_score = { | |||
| 'metric': metric, | |||
| 'score': score} | |||
| evaluation_info[i]['evaluations'].append(label_metric_score) | |||
| return evaluation_info | |||
| @staticmethod | |||
| def _is_sample_data_complete(image_container: namedtuple) -> bool: | |||
| """Check whether sample data completely loaded.""" | |||
| required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference', 'explanation'] | |||
| for attr in required_attrs: | |||
| if not EventParser.is_attr_ready(image_container, attr): | |||
| return False | |||
| return True | |||
| @staticmethod | |||
| def _is_ready_for_display(image_container: namedtuple) -> bool: | |||
| """ | |||
| Check whether the image_container is ready for frontend display. | |||
| Args: | |||
| image_container (nametuple): container consists of sample data | |||
| Return: | |||
| bool: whether the image_container if ready for display | |||
| """ | |||
| required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference'] | |||
| for attr in required_attrs: | |||
| if not EventParser.is_attr_ready(image_container, attr): | |||
| return False | |||
| return True | |||
| @staticmethod | |||
| def is_attr_ready(image_container: namedtuple, attr: str) -> bool: | |||
| """ | |||
| Check whether the given attribute is ready in image_container. | |||
| Args: | |||
| image_container (nametuple): container consist of sample data | |||
| attr (str): attribute to check | |||
| Returns: | |||
| bool, whether the attr is ready | |||
| """ | |||
| if getattr(image_container, attr, False): | |||
| return True | |||
| return False | |||
| @@ -0,0 +1,395 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ExplainJob.""" | |||
| import os | |||
| from datetime import datetime | |||
| from typing import List, Iterable, Union | |||
| from mindinsight.explainer.common.enums import PluginNameEnum | |||
| from mindinsight.explainer.common.log import logger | |||
| from mindinsight.explainer.manager.explain_parser import _ExplainParser | |||
| from mindinsight.explainer.manager.event_parse import EventParser | |||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| class ExplainJob: | |||
| """ExplainJob which manage the record in the summary file.""" | |||
| def __init__(self, | |||
| job_id: str, | |||
| summary_dir: str, | |||
| create_time: float, | |||
| latest_update_time: float): | |||
| self._job_id = job_id | |||
| self._summary_dir = summary_dir | |||
| self._parser = _ExplainParser(summary_dir) | |||
| self._event_parser = EventParser(self) | |||
| self._latest_update_time = latest_update_time | |||
| self._create_time = create_time | |||
| self._labels = [] | |||
| self._metrics = [] | |||
| self._explainers = [] | |||
| self._samples_info = {} | |||
| self._labels_info = {} | |||
| self._benchmark = {} | |||
| self._overlay_dict = {} | |||
| self._image_dict = {} | |||
| @property | |||
| def all_classes(self): | |||
| """ | |||
| Return a list of label info | |||
| Returns: | |||
| class_objs (List[ClassObj]): a list of class_objects, each object | |||
| contains: | |||
| - id (int): label id | |||
| - label (str): label name | |||
| - sample_count (int): number of samples for each label | |||
| """ | |||
| all_classes_return = [] | |||
| for label_id, label_info in self._labels_info.items(): | |||
| single_info = {'id': label_id, | |||
| 'label': label_info['label'], | |||
| 'sample_count': len(label_info['sample_ids'])} | |||
| all_classes_return.append(single_info) | |||
| return all_classes_return | |||
| @property | |||
| def explainers(self): | |||
| """ | |||
| Return a list of explainer names | |||
| Returns: | |||
| list(str), explainer names | |||
| """ | |||
| return self._explainers | |||
| @property | |||
| def explainer_scores(self): | |||
| """Return evaluation results for every explainer.""" | |||
| return [score for score in self._benchmark.values()] | |||
| @property | |||
| def sample_count(self): | |||
| """ | |||
| Return total number of samples in the job. | |||
| Return: | |||
| int, total number of samples | |||
| """ | |||
| return len(self._samples_info) | |||
| @property | |||
| def train_id(self): | |||
| """ | |||
| Return ID of explain job | |||
| Returns: | |||
| str, id of ExplainJob object | |||
| """ | |||
| return self._job_id | |||
| @property | |||
| def metrics(self): | |||
| """ | |||
| Return a list of metric names | |||
| Returns: | |||
| list(str), metric names | |||
| """ | |||
| return self._metrics | |||
| @property | |||
| def min_confidence(self): | |||
| """ | |||
| Return minimum confidence | |||
| Returns: | |||
| min_confidence (float): | |||
| """ | |||
| return None | |||
| @property | |||
| def create_time(self): | |||
| """ | |||
| Return the create time of summary file | |||
| Returns: | |||
| creation timestamp (float) | |||
| """ | |||
| return self._create_time | |||
| @property | |||
| def labels(self): | |||
| """Return the label contained in the job.""" | |||
| return self._labels | |||
| @property | |||
| def latest_update_time(self): | |||
| """ | |||
| Return last modification time stamp of summary file. | |||
| Returns: | |||
| float, last_modification_time stamp | |||
| """ | |||
| return self._latest_update_time | |||
| @latest_update_time.setter | |||
| def latest_update_time(self, new_time: Union[float, datetime]): | |||
| """ | |||
| Update the latest_update_time timestamp manually. | |||
| Args: | |||
| new_time stamp (union[float, datetime]): updated time for the job | |||
| """ | |||
| if isinstance(new_time, datetime): | |||
| self._latest_update_time = new_time.timestamp() | |||
| elif isinstance(new_time, str): | |||
| self._latest_update_time = new_time | |||
| else: | |||
| raise TypeError('new_time should have type of str or datetime') | |||
| @property | |||
| def loader_id(self): | |||
| """Return the job id.""" | |||
| return self._job_id | |||
| @property | |||
| def samples(self): | |||
| """Return the information of all samples in the job.""" | |||
| return self._samples_info | |||
| @staticmethod | |||
| def get_create_time(file_path: str) -> float: | |||
| """Return timestamp of create time of specific path.""" | |||
| create_time = os.stat(file_path).st_ctime | |||
| return create_time | |||
| @staticmethod | |||
| def get_update_time(file_path: str) -> float: | |||
| """Return timestamp of update time of specific path.""" | |||
| update_time = os.stat(file_path).st_mtime | |||
| return update_time | |||
| @staticmethod | |||
| def _total_score_to_dict(total_scores: Iterable): | |||
| """Transfer a list of benchmark score to a list of dict.""" | |||
| evaluation_info = [] | |||
| for total_score in total_scores: | |||
| metric_result = {'metric': total_score.benchmark_method, | |||
| 'score': total_score.score} | |||
| evaluation_info.append(metric_result) | |||
| return evaluation_info | |||
| @staticmethod | |||
| def _label_score_to_dict(label_scores: Iterable, labels: List[str]): | |||
| """Transfer a list of benchmark score.""" | |||
| evaluation_info = [{'label': label, 'evaluations': []} | |||
| for label in labels] | |||
| for label_score in label_scores: | |||
| metric = label_score.benchmark_method | |||
| for i, score in enumerate(label_score.score): | |||
| label_metric_score = dict() | |||
| label_metric_score['metric'] = metric | |||
| label_metric_score['score'] = score | |||
| evaluation_info[i]['evaluations'].append(label_metric_score) | |||
| return evaluation_info | |||
| def _initialize_labels_info(self): | |||
| """Initialize a dict for labels in the job.""" | |||
| if self._labels is None: | |||
| logger.warning('No labels is provided in job %s', self._job_id) | |||
| return | |||
| for label_id, label in enumerate(self._labels): | |||
| self._labels_info[label_id] = {'label': label, | |||
| 'sample_ids': set()} | |||
| def _explanation_to_dict(self, explanation, sample_id): | |||
| """Transfer the explanation from event to dict storage.""" | |||
| explainer_name = explanation.explain_method | |||
| explain_label = explanation.label | |||
| saliency = explanation.heatmap | |||
| saliency_id = '{}_{}_{}'.format( | |||
| sample_id, explain_label, explainer_name) | |||
| explain_info = { | |||
| 'explainer': explainer_name, | |||
| 'overlay': saliency_id, | |||
| } | |||
| self._overlay_dict[saliency_id] = saliency | |||
| return explain_info | |||
| def _image_container_to_dict(self, sample_data): | |||
| """Transfer the image container to dict storage.""" | |||
| sample_id = sample_data.image_id | |||
| sample_info = { | |||
| 'id': sample_id, | |||
| 'name': sample_id, | |||
| 'labels': [self._labels_info[x]['label'] | |||
| for x in sample_data.ground_truth_label], | |||
| 'inferences': []} | |||
| self._image_dict[sample_id] = sample_data.image_data | |||
| ground_truth_labels = list(sample_data.ground_truth_label) | |||
| ground_truth_probs = list(sample_data.inference.ground_truth_prob) | |||
| predicted_labels = list(sample_data.inference.predicted_label) | |||
| predicted_probs = list(sample_data.inference.predicted_prob) | |||
| inference_info = {} | |||
| for label, prob in zip( | |||
| ground_truth_labels + predicted_labels, | |||
| ground_truth_probs + predicted_probs): | |||
| inference_info[label] = { | |||
| 'label': self._labels_info[label]['label'], | |||
| 'confidence': prob, | |||
| 'saliency_maps': []} | |||
| if EventParser.is_attr_ready(sample_data, 'explanation'): | |||
| for explanation in sample_data.explanation: | |||
| explanation_dict = self._explanation_to_dict( | |||
| explanation, sample_id) | |||
| inference_info[explanation.label]['saliency_maps'].append(explanation_dict) | |||
| sample_info['inferences'] = list(inference_info.values()) | |||
| return sample_info | |||
| def _import_sample(self, sample): | |||
| """Add sample object of given sample id.""" | |||
| for label_id in sample.ground_truth_label: | |||
| self._labels_info[label_id]['sample_ids'].add(sample.image_id) | |||
| sample_info = self._image_container_to_dict(sample) | |||
| self._samples_info.update({sample_info['id']: sample_info}) | |||
| def retrieve_image(self, image_id: str): | |||
| """ | |||
| Retrieve image data from the job given image_id. | |||
| Return: | |||
| string, image data in base64 byte | |||
| """ | |||
| return self._image_dict.get(image_id, None) | |||
| def retrieve_overlay(self, overlay_id: str): | |||
| """ | |||
| Retrieve sample map from the job given overlay_id. | |||
| Return: | |||
| string, saliency_map data in base64 byte | |||
| """ | |||
| return self._overlay_dict.get(overlay_id, None) | |||
| def get_all_samples(self): | |||
| """ | |||
| Return a list of sample information cachced in the explain job | |||
| Returns: | |||
| sample_list (List[SampleObj]): a list of sample objects, each object | |||
| consists of: | |||
| - id (int): sample id | |||
| - name (str): basename of image | |||
| - labels (list[str]): list of labels | |||
| - inferences list[dict]) | |||
| """ | |||
| samples_in_list = list(self._samples_info.values()) | |||
| return samples_in_list | |||
| def _is_metadata_empty(self): | |||
| """Check whether metadata is loaded first.""" | |||
| if not self._explainers or not self._metrics or not self._labels: | |||
| return True | |||
| return False | |||
| def _import_data_from_event(self, event): | |||
| """Parse and import data from the event data.""" | |||
| tags = { | |||
| 'image_id': PluginNameEnum.IMAGE_ID, | |||
| 'benchmark': PluginNameEnum.BENCHMARK, | |||
| 'metadata': PluginNameEnum.METADATA | |||
| } | |||
| if 'metadata' not in event and self._is_metadata_empty(): | |||
| raise ValueError('metadata is empty, should write metadata first' | |||
| 'in the summary.') | |||
| for tag in tags: | |||
| if tag not in event: | |||
| continue | |||
| if tag == PluginNameEnum.IMAGE_ID.value: | |||
| sample_event = event[tag] | |||
| sample_data = self._event_parser.parse_sample(sample_event) | |||
| if sample_data is not None: | |||
| self._import_sample(sample_data) | |||
| continue | |||
| if tag == PluginNameEnum.BENCHMARK.value: | |||
| benchmark_event = event[tag].benchmark | |||
| benchmark = self._event_parser.parse_benchmark(benchmark_event) | |||
| self._benchmark = benchmark | |||
| elif tag == PluginNameEnum.METADATA.value: | |||
| metadata_event = event[tag].metadata | |||
| metadata = self._event_parser.parse_metadata(metadata_event) | |||
| self._explainers, self._metrics, self._labels = metadata | |||
| self._initialize_labels_info() | |||
| def load(self): | |||
| """ | |||
| Start loading data from parser. | |||
| """ | |||
| valid_file_names = [] | |||
| for filename in FileHandler.list_dir(self._summary_dir): | |||
| if FileHandler.is_file( | |||
| FileHandler.join(self._summary_dir, filename)): | |||
| valid_file_names.append(filename) | |||
| if not valid_file_names: | |||
| raise TrainJobNotExistError('No summary file found in %s, explain job will be delete.' % self._summary_dir) | |||
| is_end = False | |||
| while not is_end: | |||
| is_clean, is_end, event = self._parser.parse_explain(valid_file_names) | |||
| if is_clean: | |||
| logger.info('Summary file in %s update, reload the clean the loaded data.', self._summary_dir) | |||
| self._clean_job() | |||
| if event: | |||
| self._import_data_from_event(event) | |||
| def _clean_job(self): | |||
| """Clean the cached data in job.""" | |||
| self._latest_update_time = ExplainJob.get_update_time(self._summary_dir) | |||
| self._create_time = ExplainJob.get_update_time(self._summary_dir) | |||
| self._labels.clear() | |||
| self._metrics.clear() | |||
| self._explainers.clear() | |||
| self._samples_info.clear() | |||
| self._labels_info.clear() | |||
| self._benchmark.clear() | |||
| self._overlay_dict.clear() | |||
| self._image_dict.clear() | |||
| self._event_parser.clear() | |||
| @@ -0,0 +1,314 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ExplainManager.""" | |||
| import os | |||
| import threading | |||
| import time | |||
| from mindinsight.datavisual.common import exceptions | |||
| from mindinsight.datavisual.common.enums import BaseEnum | |||
| from mindinsight.explainer.common.log import logger | |||
| from mindinsight.explainer.manager.explain_job import ExplainJob | |||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | |||
| from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher | |||
| from mindinsight.utils.exceptions import MindInsightException, ParamValueError | |||
| _MAX_LOADER_NUM = 3 | |||
| _MAX_INTERVAL = 3 | |||
| class _ExplainManagerStatus(BaseEnum): | |||
| """Manager status.""" | |||
| INIT = 'INIT' | |||
| LOADING = 'LOADING' | |||
| DONE = 'DONE' | |||
| INVALID = 'INVALID' | |||
| class ExplainManager: | |||
| """ExplainManager.""" | |||
| def __init__(self, summary_base_dir: str): | |||
| self._summary_base_dir = summary_base_dir | |||
| self._loader_pool = {} | |||
| self._deleted_ids = [] | |||
| self._status = _ExplainManagerStatus.INIT.value | |||
| self._status_mutex = threading.Lock() | |||
| self._loader_pool_mutex = threading.Lock() | |||
| self._max_loader_num = _MAX_LOADER_NUM | |||
| self._reload_interval = None | |||
| def _reload_data(self): | |||
| """periodically load summary from file.""" | |||
| while True: | |||
| self._load_data() | |||
| if not self._reload_interval: | |||
| break | |||
| time.sleep(self._reload_interval) | |||
| def _load_data(self): | |||
| """Loading the summary in the given base directory.""" | |||
| logger.info( | |||
| 'Start to load data, reload interval: %r.', self._reload_interval) | |||
| with self._status_mutex: | |||
| if self._status == _ExplainManagerStatus.LOADING.value: | |||
| logger.info('Current status is %s, will ignore to load data.', | |||
| self._status) | |||
| return | |||
| self._status = _ExplainManagerStatus.LOADING.value | |||
| self._generate_loaders() | |||
| self._execute_load_data() | |||
| if not self._loader_pool: | |||
| self._status = _ExplainManagerStatus.INVALID.value | |||
| else: | |||
| self._status = _ExplainManagerStatus.DONE.value | |||
| logger.info('Load event data end, status: %r, ' | |||
| 'and loader pool size is %r', | |||
| self._status, len(self._loader_pool)) | |||
| def _update_loader_latest_update_time(self, loader_id, latest_update_time=None): | |||
| """update the update time of loader of given id.""" | |||
| if latest_update_time is None: | |||
| latest_update_time = time.time() | |||
| self._loader_pool[loader_id].latest_update_time = latest_update_time | |||
| def _delete_loader(self, loader_id): | |||
| """delete loader given loader_id""" | |||
| if self._loader_pool.get(loader_id, None) is not None: | |||
| self._loader_pool.pop(loader_id) | |||
| logger.debug('delete loader %s', loader_id) | |||
| def _add_loader(self, loader): | |||
| """add loader to the loader_pool.""" | |||
| if len(self._loader_pool) >= _MAX_LOADER_NUM: | |||
| delete_num = len(self._loader_pool) - _MAX_LOADER_NUM + 1 | |||
| sorted_loaders = sorted( | |||
| self._loader_pool.items(), | |||
| key=lambda x: x[1].latest_update_time) | |||
| for index in range(delete_num): | |||
| delete_loader_id = sorted_loaders[index][0] | |||
| self._delete_loader(delete_loader_id) | |||
| self._loader_pool.update({loader.loader_id: loader}) | |||
| def _deal_loaders(self, latest_loaders): | |||
| """"update the loader pool.""" | |||
| with self._loader_pool_mutex: | |||
| for loader_id, loader in latest_loaders: | |||
| if self._loader_pool.get(loader_id, None) is None: | |||
| self._add_loader(loader) | |||
| continue | |||
| if (self._loader_pool[loader_id].latest_update_time | |||
| < loader.latest_update_time): | |||
| self._update_loader_latest_update_time( | |||
| loader_id, loader.latest_update_time) | |||
| @staticmethod | |||
| def _generate_loader_id(relative_path): | |||
| """Generate loader id for given path""" | |||
| loader_id = relative_path | |||
| return loader_id | |||
| @staticmethod | |||
| def _generate_loader_name(relative_path): | |||
| """Generate_loader name for given path.""" | |||
| loader_name = relative_path | |||
| return loader_name | |||
| def _generate_loader_by_relative_path(self, relative_path: str) -> ExplainJob: | |||
| """Generate explain job from given relative path.""" | |||
| current_dir = os.path.realpath(FileHandler.join( | |||
| self._summary_base_dir, relative_path | |||
| )) | |||
| loader_id = self._generate_loader_id(relative_path) | |||
| loader = ExplainJob( | |||
| job_id=loader_id, | |||
| summary_dir=current_dir, | |||
| create_time=ExplainJob.get_create_time(current_dir), | |||
| latest_update_time=ExplainJob.get_update_time(current_dir)) | |||
| return loader | |||
| def _generate_loaders(self): | |||
| """Generate job loaders from the summary watcher.""" | |||
| dir_map_mtime_dict = {} | |||
| loader_dict = {} | |||
| min_modify_time = None | |||
| _, summaries = SummaryWatcher().list_explain_directories( | |||
| self._summary_base_dir) | |||
| for item in summaries: | |||
| relative_path = item.get('relative_path') | |||
| modify_time = item.get('update_time').timestamp() | |||
| loader_id = self._generate_loader_id(relative_path) | |||
| loader = self._loader_pool.get(loader_id, None) | |||
| if loader is not None and loader.latest_update_time > modify_time: | |||
| modify_time = loader.latest_update_time | |||
| if min_modify_time is None: | |||
| min_modify_time = modify_time | |||
| if len(dir_map_mtime_dict) < _MAX_LOADER_NUM: | |||
| if modify_time < min_modify_time: | |||
| min_modify_time = modify_time | |||
| dir_map_mtime_dict.update({relative_path: modify_time}) | |||
| else: | |||
| if modify_time >= min_modify_time: | |||
| dir_map_mtime_dict.update({relative_path: modify_time}) | |||
| sorted_dir_tuple = sorted(dir_map_mtime_dict.items(), | |||
| key=lambda d: d[1])[-_MAX_LOADER_NUM:] | |||
| for relative_path, modify_time in sorted_dir_tuple: | |||
| loader_id = self._generate_loader_id(relative_path) | |||
| loader = self._generate_loader_by_relative_path(relative_path) | |||
| loader_dict.update({loader_id: loader}) | |||
| sorted_loaders = sorted(loader_dict.items(), | |||
| key=lambda x: x[1].latest_update_time) | |||
| latest_loaders = sorted_loaders[-_MAX_LOADER_NUM:] | |||
| self._deal_loaders(latest_loaders) | |||
| def _execute_loader(self, loader_id): | |||
| """Execute the data loading.""" | |||
| try: | |||
| with self._loader_pool_mutex: | |||
| loader = self._loader_pool.get(loader_id, None) | |||
| if loader is None: | |||
| logger.debug('Loader %r has been deleted, will not load' | |||
| 'data', loader_id) | |||
| return | |||
| loader.load() | |||
| except MindInsightException as e: | |||
| logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, e) | |||
| with self._loader_pool_mutex: | |||
| self._delete_loader(loader_id) | |||
| def _execute_load_data(self): | |||
| """Execute the loader in the pool to load data.""" | |||
| loader_pool = self._get_snapshot_loader_pool() | |||
| for loader_id in loader_pool: | |||
| self._execute_loader(loader_id) | |||
| def _get_snapshot_loader_pool(self): | |||
| """Get snapshot of loader_pool.""" | |||
| with self._loader_pool_mutex: | |||
| return dict(self._loader_pool) | |||
| def _check_status_valid(self): | |||
| """Check manager status.""" | |||
| if self._status == _ExplainManagerStatus.INIT.value: | |||
| raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._status) | |||
| @staticmethod | |||
| def _check_train_id_valid(train_id: str): | |||
| """Verify the train_id is valid.""" | |||
| if not train_id.startswith('./'): | |||
| logger.warning('train_id does not start with "./"') | |||
| return False | |||
| if len(train_id.split('/')) > 2: | |||
| logger.warning('train_id contains multiple "/"') | |||
| return False | |||
| return True | |||
| def _check_train_job_exist(self, train_id): | |||
| """Verify thee train_job is existed given train_id.""" | |||
| if train_id in self._loader_pool: | |||
| return | |||
| self._check_train_id_valid(train_id) | |||
| if SummaryWatcher().is_summary_directory(self._summary_base_dir, train_id): | |||
| return | |||
| raise ParamValueError('Can not find the train job in the manager, train_id: %s' % train_id) | |||
| def _reload_data_again(self): | |||
| """Reload the data one more time.""" | |||
| logger.debug('Start to reload data again.') | |||
| thread = threading.Thread(target=self._load_data, | |||
| name='reload_data_thread') | |||
| thread.daemon = False | |||
| thread.start() | |||
| def _get_job(self, train_id): | |||
| """Retrieve train_job given train_id.""" | |||
| is_reload = False | |||
| with self._loader_pool_mutex: | |||
| loader = self._loader_pool.get(train_id, None) | |||
| if loader is None: | |||
| relative_path = train_id | |||
| temp_loader = self._generate_loader_by_relative_path( | |||
| relative_path) | |||
| if temp_loader is None: | |||
| return None | |||
| self._add_loader(temp_loader) | |||
| is_reload = True | |||
| if is_reload: | |||
| self._reload_data_again() | |||
| return loader | |||
| @property | |||
| def summary_base_dir(self): | |||
| """Return the base directory for summary records.""" | |||
| return self._summary_base_dir | |||
| def get_job(self, train_id): | |||
| """ | |||
| Return ExplainJob given train_id. | |||
| If explain job w.r.t given train_id is not found, None will be returned. | |||
| Args: | |||
| train_id (str): The id of expected ExplainJob | |||
| Return: | |||
| explain_job | |||
| """ | |||
| self._check_status_valid() | |||
| self._check_train_job_exist(train_id) | |||
| loader = self._get_job(train_id) | |||
| if loader is None: | |||
| return None | |||
| return loader | |||
| def start_load_data(self, | |||
| reload_interval=_MAX_INTERVAL): | |||
| """ | |||
| Start threads for loading data. | |||
| Args: | |||
| reload_interval (int): interval to reload the summary from file | |||
| """ | |||
| self._reload_interval = reload_interval | |||
| thread = threading.Thread(target=self._reload_data, name='start_load_data_thread') | |||
| thread.daemon = True | |||
| thread.start() | |||
| # wait for data loading | |||
| time.sleep(1) | |||
| @@ -28,20 +28,36 @@ from mindinsight.explainer.common.log import logger | |||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | |||
| from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser | |||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||
| from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Explain | |||
| from mindinsight.utils.exceptions import UnknownError | |||
| HEADER_SIZE = 8 | |||
| CRC_STR_SIZE = 4 | |||
| MAX_EVENT_STRING = 500000000 | |||
| ImageDataContainer = collections.namedtuple('ImageDataContainer', | |||
| ['image_id', 'image_data', 'ground_truth_label', | |||
| 'inference', 'explanation', 'status']) | |||
| BenchmarkContainer = collections.namedtuple('BenchmarkContainer', ['benchmark', 'status']) | |||
| MetadataContainer = collections.namedtuple('MetadataContainer', ['metadata', 'status']) | |||
| class ImageDataContainer: | |||
| """ | |||
| Container for image data to allow pickling. | |||
| Args: | |||
| explain_message (Explain): Explain proto buffer message. | |||
| """ | |||
| def __init__(self, explain_message: Explain): | |||
| self.image_id = explain_message.image_id | |||
| self.image_data = explain_message.image_data | |||
| self.ground_truth_label = explain_message.ground_truth_label | |||
| self.inference = explain_message.inference | |||
| self.explanation = explain_message.explanation | |||
| self.status = explain_message.status | |||
| class _ExplainParser(_SummaryParser): | |||
| """The summary file parser.""" | |||
| def __init__(self, summary_dir): | |||
| super(_ExplainParser, self).__init__(summary_dir) | |||
| self._latest_filename = '' | |||
| @@ -165,7 +181,6 @@ class _ExplainParser(_SummaryParser): | |||
| tensor_value_list.append(tensor_value) | |||
| return field_list, tensor_value_list | |||
| @staticmethod | |||
| def _add_image_data(tensor_event_value): | |||
| """ | |||
| @@ -174,17 +189,9 @@ class _ExplainParser(_SummaryParser): | |||
| Args: | |||
| tensor_event_value: the object of Explain message | |||
| """ | |||
| image_data = ImageDataContainer( | |||
| image_id=tensor_event_value.image_id, | |||
| image_data=tensor_event_value.image_data, | |||
| ground_truth_label=tensor_event_value.ground_truth_label, | |||
| inference=tensor_event_value.inference, | |||
| explanation=tensor_event_value.explanation, | |||
| status=tensor_event_value.status | |||
| ) | |||
| image_data = ImageDataContainer(tensor_event_value) | |||
| return image_data | |||
| @staticmethod | |||
| def _add_benchmark(tensor_event_value): | |||
| """ | |||