| @@ -0,0 +1,180 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """EventParser for summary event.""" | |||||
| from collections import namedtuple | |||||
| from typing import Dict, Iterable, List, Optional, Tuple | |||||
| from mindinsight.explainer.common.enums import PluginNameEnum | |||||
| from mindinsight.explainer.common.log import logger | |||||
| from mindinsight.utils.exceptions import UnknownError | |||||
| _IMAGE_DATA_TAGS = { | |||||
| 'image_data': PluginNameEnum.IMAGE_DATA.value, | |||||
| 'ground_truth_label': PluginNameEnum.GROUND_TRUTH_LABEL.value, | |||||
| 'inference': PluginNameEnum.INFERENCE.value, | |||||
| 'explanation': PluginNameEnum.EXPLANATION.value | |||||
| } | |||||
| class EventParser: | |||||
| """Parser for event data.""" | |||||
| def __init__(self, job): | |||||
| self._job = job | |||||
| self._sample_pool = {} | |||||
| def clear(self): | |||||
| """Clear the loaded data.""" | |||||
| self._sample_pool.clear() | |||||
| def parse_metadata(self, metadata) -> Tuple[List, List, List]: | |||||
| """Parse the metadata event.""" | |||||
| explainers = list(metadata.explain_method) | |||||
| metrics = list(metadata.benchmark_method) | |||||
| labels = list(metadata.label) | |||||
| return explainers, metrics, labels | |||||
| def parse_benchmark(self, benchmark) -> Dict: | |||||
| """Parse the benchmark event.""" | |||||
| imported_benchmark = {} | |||||
| for explainer_result in benchmark: | |||||
| explainer = explainer_result.explain_method | |||||
| total_score = explainer_result.total_score | |||||
| label_score = explainer_result.label_score | |||||
| explainer_benchmark = { | |||||
| 'explainer': explainer, | |||||
| 'evaluations': EventParser._total_score_to_dict(total_score), | |||||
| 'class_scores': EventParser._label_score_to_dict(label_score, self._job.labels) | |||||
| } | |||||
| imported_benchmark[explainer] = explainer_benchmark | |||||
| return imported_benchmark | |||||
| def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]: | |||||
| """Parse the sample event.""" | |||||
| sample_id = sample.image_id | |||||
| if sample_id not in self._sample_pool: | |||||
| self._sample_pool[sample_id] = sample | |||||
| return None | |||||
| for tag in _IMAGE_DATA_TAGS: | |||||
| try: | |||||
| if tag == PluginNameEnum.INFERENCE.value: | |||||
| self._parse_inference(sample, sample_id) | |||||
| elif tag == PluginNameEnum.EXPLANATION.value: | |||||
| self._parse_explanation(sample, sample_id) | |||||
| else: | |||||
| self._parse_sample_info(sample, sample_id, tag) | |||||
| except UnknownError as ex: | |||||
| logger.warning("Parse %s data failed within image related data," | |||||
| " detail: %r", tag, str(ex)) | |||||
| continue | |||||
| if EventParser._is_sample_data_complete(self._sample_pool[sample_id]): | |||||
| return self._sample_pool.pop(sample_id) | |||||
| if EventParser._is_ready_for_display(self._sample_pool[sample_id]): | |||||
| return self._sample_pool[sample_id] | |||||
| return None | |||||
| def _parse_inference(self, event, sample_id): | |||||
| """Parse the inference event.""" | |||||
| self._sample_pool[sample_id].inference.ground_truth_prob.extend( | |||||
| event.inference.ground_truth_prob) | |||||
| self._sample_pool[sample_id].inference.predicted_label.extend( | |||||
| event.inference.predicted_label) | |||||
| self._sample_pool[sample_id].inference.predicted_prob.extend( | |||||
| event.inference.predicted_prob) | |||||
| def _parse_explanation(self, event, sample_id): | |||||
| """Parse the explanation event.""" | |||||
| if event.explanation: | |||||
| for explanation_item in event.explanation: | |||||
| new_explanation = self._sample_pool[sample_id].explanation.add() | |||||
| new_explanation.explain_method = explanation_item.explain_method | |||||
| new_explanation.label = explanation_item.label | |||||
| new_explanation.heatmap = explanation_item.heatmap | |||||
| def _parse_sample_info(self, event, sample_id, tag): | |||||
| """Parse the event containing image info.""" | |||||
| if not getattr(self._sample_pool[sample_id], tag): | |||||
| setattr(self._sample_pool[sample_id], tag, getattr(event, tag)) | |||||
| @staticmethod | |||||
| def _total_score_to_dict(total_scores: Iterable): | |||||
| """Transfer a list of benchmark score to a list of dict.""" | |||||
| evaluation_info = [] | |||||
| for total_score in total_scores: | |||||
| metric_result = { | |||||
| 'metric': total_score.benchmark_method, | |||||
| 'score': total_score.score} | |||||
| evaluation_info.append(metric_result) | |||||
| return evaluation_info | |||||
| @staticmethod | |||||
| def _label_score_to_dict(label_scores: Iterable, labels: List[str]): | |||||
| """Transfer a list of benchmark score.""" | |||||
| evaluation_info = [{'label': label, 'evaluations': []} | |||||
| for label in labels] | |||||
| for label_score in label_scores: | |||||
| metric = label_score.benchmark_method | |||||
| for i, score in enumerate(label_score.score): | |||||
| label_metric_score = { | |||||
| 'metric': metric, | |||||
| 'score': score} | |||||
| evaluation_info[i]['evaluations'].append(label_metric_score) | |||||
| return evaluation_info | |||||
| @staticmethod | |||||
| def _is_sample_data_complete(image_container: namedtuple) -> bool: | |||||
| """Check whether sample data completely loaded.""" | |||||
| required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference', 'explanation'] | |||||
| for attr in required_attrs: | |||||
| if not EventParser.is_attr_ready(image_container, attr): | |||||
| return False | |||||
| return True | |||||
| @staticmethod | |||||
| def _is_ready_for_display(image_container: namedtuple) -> bool: | |||||
| """ | |||||
| Check whether the image_container is ready for frontend display. | |||||
| Args: | |||||
| image_container (nametuple): container consists of sample data | |||||
| Return: | |||||
| bool: whether the image_container if ready for display | |||||
| """ | |||||
| required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference'] | |||||
| for attr in required_attrs: | |||||
| if not EventParser.is_attr_ready(image_container, attr): | |||||
| return False | |||||
| return True | |||||
| @staticmethod | |||||
| def is_attr_ready(image_container: namedtuple, attr: str) -> bool: | |||||
| """ | |||||
| Check whether the given attribute is ready in image_container. | |||||
| Args: | |||||
| image_container (nametuple): container consist of sample data | |||||
| attr (str): attribute to check | |||||
| Returns: | |||||
| bool, whether the attr is ready | |||||
| """ | |||||
| if getattr(image_container, attr, False): | |||||
| return True | |||||
| return False | |||||
| @@ -0,0 +1,395 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ExplainJob.""" | |||||
| import os | |||||
| from datetime import datetime | |||||
| from typing import List, Iterable, Union | |||||
| from mindinsight.explainer.common.enums import PluginNameEnum | |||||
| from mindinsight.explainer.common.log import logger | |||||
| from mindinsight.explainer.manager.explain_parser import _ExplainParser | |||||
| from mindinsight.explainer.manager.event_parse import EventParser | |||||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | |||||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||||
| class ExplainJob: | |||||
| """ExplainJob which manage the record in the summary file.""" | |||||
| def __init__(self, | |||||
| job_id: str, | |||||
| summary_dir: str, | |||||
| create_time: float, | |||||
| latest_update_time: float): | |||||
| self._job_id = job_id | |||||
| self._summary_dir = summary_dir | |||||
| self._parser = _ExplainParser(summary_dir) | |||||
| self._event_parser = EventParser(self) | |||||
| self._latest_update_time = latest_update_time | |||||
| self._create_time = create_time | |||||
| self._labels = [] | |||||
| self._metrics = [] | |||||
| self._explainers = [] | |||||
| self._samples_info = {} | |||||
| self._labels_info = {} | |||||
| self._benchmark = {} | |||||
| self._overlay_dict = {} | |||||
| self._image_dict = {} | |||||
| @property | |||||
| def all_classes(self): | |||||
| """ | |||||
| Return a list of label info | |||||
| Returns: | |||||
| class_objs (List[ClassObj]): a list of class_objects, each object | |||||
| contains: | |||||
| - id (int): label id | |||||
| - label (str): label name | |||||
| - sample_count (int): number of samples for each label | |||||
| """ | |||||
| all_classes_return = [] | |||||
| for label_id, label_info in self._labels_info.items(): | |||||
| single_info = {'id': label_id, | |||||
| 'label': label_info['label'], | |||||
| 'sample_count': len(label_info['sample_ids'])} | |||||
| all_classes_return.append(single_info) | |||||
| return all_classes_return | |||||
| @property | |||||
| def explainers(self): | |||||
| """ | |||||
| Return a list of explainer names | |||||
| Returns: | |||||
| list(str), explainer names | |||||
| """ | |||||
| return self._explainers | |||||
| @property | |||||
| def explainer_scores(self): | |||||
| """Return evaluation results for every explainer.""" | |||||
| return [score for score in self._benchmark.values()] | |||||
| @property | |||||
| def sample_count(self): | |||||
| """ | |||||
| Return total number of samples in the job. | |||||
| Return: | |||||
| int, total number of samples | |||||
| """ | |||||
| return len(self._samples_info) | |||||
| @property | |||||
| def train_id(self): | |||||
| """ | |||||
| Return ID of explain job | |||||
| Returns: | |||||
| str, id of ExplainJob object | |||||
| """ | |||||
| return self._job_id | |||||
| @property | |||||
| def metrics(self): | |||||
| """ | |||||
| Return a list of metric names | |||||
| Returns: | |||||
| list(str), metric names | |||||
| """ | |||||
| return self._metrics | |||||
| @property | |||||
| def min_confidence(self): | |||||
| """ | |||||
| Return minimum confidence | |||||
| Returns: | |||||
| min_confidence (float): | |||||
| """ | |||||
| return None | |||||
| @property | |||||
| def create_time(self): | |||||
| """ | |||||
| Return the create time of summary file | |||||
| Returns: | |||||
| creation timestamp (float) | |||||
| """ | |||||
| return self._create_time | |||||
| @property | |||||
| def labels(self): | |||||
| """Return the label contained in the job.""" | |||||
| return self._labels | |||||
| @property | |||||
| def latest_update_time(self): | |||||
| """ | |||||
| Return last modification time stamp of summary file. | |||||
| Returns: | |||||
| float, last_modification_time stamp | |||||
| """ | |||||
| return self._latest_update_time | |||||
| @latest_update_time.setter | |||||
| def latest_update_time(self, new_time: Union[float, datetime]): | |||||
| """ | |||||
| Update the latest_update_time timestamp manually. | |||||
| Args: | |||||
| new_time stamp (union[float, datetime]): updated time for the job | |||||
| """ | |||||
| if isinstance(new_time, datetime): | |||||
| self._latest_update_time = new_time.timestamp() | |||||
| elif isinstance(new_time, str): | |||||
| self._latest_update_time = new_time | |||||
| else: | |||||
| raise TypeError('new_time should have type of str or datetime') | |||||
| @property | |||||
| def loader_id(self): | |||||
| """Return the job id.""" | |||||
| return self._job_id | |||||
| @property | |||||
| def samples(self): | |||||
| """Return the information of all samples in the job.""" | |||||
| return self._samples_info | |||||
| @staticmethod | |||||
| def get_create_time(file_path: str) -> float: | |||||
| """Return timestamp of create time of specific path.""" | |||||
| create_time = os.stat(file_path).st_ctime | |||||
| return create_time | |||||
| @staticmethod | |||||
| def get_update_time(file_path: str) -> float: | |||||
| """Return timestamp of update time of specific path.""" | |||||
| update_time = os.stat(file_path).st_mtime | |||||
| return update_time | |||||
| @staticmethod | |||||
| def _total_score_to_dict(total_scores: Iterable): | |||||
| """Transfer a list of benchmark score to a list of dict.""" | |||||
| evaluation_info = [] | |||||
| for total_score in total_scores: | |||||
| metric_result = {'metric': total_score.benchmark_method, | |||||
| 'score': total_score.score} | |||||
| evaluation_info.append(metric_result) | |||||
| return evaluation_info | |||||
| @staticmethod | |||||
| def _label_score_to_dict(label_scores: Iterable, labels: List[str]): | |||||
| """Transfer a list of benchmark score.""" | |||||
| evaluation_info = [{'label': label, 'evaluations': []} | |||||
| for label in labels] | |||||
| for label_score in label_scores: | |||||
| metric = label_score.benchmark_method | |||||
| for i, score in enumerate(label_score.score): | |||||
| label_metric_score = dict() | |||||
| label_metric_score['metric'] = metric | |||||
| label_metric_score['score'] = score | |||||
| evaluation_info[i]['evaluations'].append(label_metric_score) | |||||
| return evaluation_info | |||||
| def _initialize_labels_info(self): | |||||
| """Initialize a dict for labels in the job.""" | |||||
| if self._labels is None: | |||||
| logger.warning('No labels is provided in job %s', self._job_id) | |||||
| return | |||||
| for label_id, label in enumerate(self._labels): | |||||
| self._labels_info[label_id] = {'label': label, | |||||
| 'sample_ids': set()} | |||||
| def _explanation_to_dict(self, explanation, sample_id): | |||||
| """Transfer the explanation from event to dict storage.""" | |||||
| explainer_name = explanation.explain_method | |||||
| explain_label = explanation.label | |||||
| saliency = explanation.heatmap | |||||
| saliency_id = '{}_{}_{}'.format( | |||||
| sample_id, explain_label, explainer_name) | |||||
| explain_info = { | |||||
| 'explainer': explainer_name, | |||||
| 'overlay': saliency_id, | |||||
| } | |||||
| self._overlay_dict[saliency_id] = saliency | |||||
| return explain_info | |||||
| def _image_container_to_dict(self, sample_data): | |||||
| """Transfer the image container to dict storage.""" | |||||
| sample_id = sample_data.image_id | |||||
| sample_info = { | |||||
| 'id': sample_id, | |||||
| 'name': sample_id, | |||||
| 'labels': [self._labels_info[x]['label'] | |||||
| for x in sample_data.ground_truth_label], | |||||
| 'inferences': []} | |||||
| self._image_dict[sample_id] = sample_data.image_data | |||||
| ground_truth_labels = list(sample_data.ground_truth_label) | |||||
| ground_truth_probs = list(sample_data.inference.ground_truth_prob) | |||||
| predicted_labels = list(sample_data.inference.predicted_label) | |||||
| predicted_probs = list(sample_data.inference.predicted_prob) | |||||
| inference_info = {} | |||||
| for label, prob in zip( | |||||
| ground_truth_labels + predicted_labels, | |||||
| ground_truth_probs + predicted_probs): | |||||
| inference_info[label] = { | |||||
| 'label': self._labels_info[label]['label'], | |||||
| 'confidence': prob, | |||||
| 'saliency_maps': []} | |||||
| if EventParser.is_attr_ready(sample_data, 'explanation'): | |||||
| for explanation in sample_data.explanation: | |||||
| explanation_dict = self._explanation_to_dict( | |||||
| explanation, sample_id) | |||||
| inference_info[explanation.label]['saliency_maps'].append(explanation_dict) | |||||
| sample_info['inferences'] = list(inference_info.values()) | |||||
| return sample_info | |||||
| def _import_sample(self, sample): | |||||
| """Add sample object of given sample id.""" | |||||
| for label_id in sample.ground_truth_label: | |||||
| self._labels_info[label_id]['sample_ids'].add(sample.image_id) | |||||
| sample_info = self._image_container_to_dict(sample) | |||||
| self._samples_info.update({sample_info['id']: sample_info}) | |||||
| def retrieve_image(self, image_id: str): | |||||
| """ | |||||
| Retrieve image data from the job given image_id. | |||||
| Return: | |||||
| string, image data in base64 byte | |||||
| """ | |||||
| return self._image_dict.get(image_id, None) | |||||
| def retrieve_overlay(self, overlay_id: str): | |||||
| """ | |||||
| Retrieve sample map from the job given overlay_id. | |||||
| Return: | |||||
| string, saliency_map data in base64 byte | |||||
| """ | |||||
| return self._overlay_dict.get(overlay_id, None) | |||||
| def get_all_samples(self): | |||||
| """ | |||||
| Return a list of sample information cachced in the explain job | |||||
| Returns: | |||||
| sample_list (List[SampleObj]): a list of sample objects, each object | |||||
| consists of: | |||||
| - id (int): sample id | |||||
| - name (str): basename of image | |||||
| - labels (list[str]): list of labels | |||||
| - inferences list[dict]) | |||||
| """ | |||||
| samples_in_list = list(self._samples_info.values()) | |||||
| return samples_in_list | |||||
| def _is_metadata_empty(self): | |||||
| """Check whether metadata is loaded first.""" | |||||
| if not self._explainers or not self._metrics or not self._labels: | |||||
| return True | |||||
| return False | |||||
| def _import_data_from_event(self, event): | |||||
| """Parse and import data from the event data.""" | |||||
| tags = { | |||||
| 'image_id': PluginNameEnum.IMAGE_ID, | |||||
| 'benchmark': PluginNameEnum.BENCHMARK, | |||||
| 'metadata': PluginNameEnum.METADATA | |||||
| } | |||||
| if 'metadata' not in event and self._is_metadata_empty(): | |||||
| raise ValueError('metadata is empty, should write metadata first' | |||||
| 'in the summary.') | |||||
| for tag in tags: | |||||
| if tag not in event: | |||||
| continue | |||||
| if tag == PluginNameEnum.IMAGE_ID.value: | |||||
| sample_event = event[tag] | |||||
| sample_data = self._event_parser.parse_sample(sample_event) | |||||
| if sample_data is not None: | |||||
| self._import_sample(sample_data) | |||||
| continue | |||||
| if tag == PluginNameEnum.BENCHMARK.value: | |||||
| benchmark_event = event[tag].benchmark | |||||
| benchmark = self._event_parser.parse_benchmark(benchmark_event) | |||||
| self._benchmark = benchmark | |||||
| elif tag == PluginNameEnum.METADATA.value: | |||||
| metadata_event = event[tag].metadata | |||||
| metadata = self._event_parser.parse_metadata(metadata_event) | |||||
| self._explainers, self._metrics, self._labels = metadata | |||||
| self._initialize_labels_info() | |||||
| def load(self): | |||||
| """ | |||||
| Start loading data from parser. | |||||
| """ | |||||
| valid_file_names = [] | |||||
| for filename in FileHandler.list_dir(self._summary_dir): | |||||
| if FileHandler.is_file( | |||||
| FileHandler.join(self._summary_dir, filename)): | |||||
| valid_file_names.append(filename) | |||||
| if not valid_file_names: | |||||
| raise TrainJobNotExistError('No summary file found in %s, explain job will be delete.' % self._summary_dir) | |||||
| is_end = False | |||||
| while not is_end: | |||||
| is_clean, is_end, event = self._parser.parse_explain(valid_file_names) | |||||
| if is_clean: | |||||
| logger.info('Summary file in %s update, reload the clean the loaded data.', self._summary_dir) | |||||
| self._clean_job() | |||||
| if event: | |||||
| self._import_data_from_event(event) | |||||
| def _clean_job(self): | |||||
| """Clean the cached data in job.""" | |||||
| self._latest_update_time = ExplainJob.get_update_time(self._summary_dir) | |||||
| self._create_time = ExplainJob.get_update_time(self._summary_dir) | |||||
| self._labels.clear() | |||||
| self._metrics.clear() | |||||
| self._explainers.clear() | |||||
| self._samples_info.clear() | |||||
| self._labels_info.clear() | |||||
| self._benchmark.clear() | |||||
| self._overlay_dict.clear() | |||||
| self._image_dict.clear() | |||||
| self._event_parser.clear() | |||||
| @@ -0,0 +1,314 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ExplainManager.""" | |||||
| import os | |||||
| import threading | |||||
| import time | |||||
| from mindinsight.datavisual.common import exceptions | |||||
| from mindinsight.datavisual.common.enums import BaseEnum | |||||
| from mindinsight.explainer.common.log import logger | |||||
| from mindinsight.explainer.manager.explain_job import ExplainJob | |||||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | |||||
| from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher | |||||
| from mindinsight.utils.exceptions import MindInsightException, ParamValueError | |||||
| _MAX_LOADER_NUM = 3 | |||||
| _MAX_INTERVAL = 3 | |||||
| class _ExplainManagerStatus(BaseEnum): | |||||
| """Manager status.""" | |||||
| INIT = 'INIT' | |||||
| LOADING = 'LOADING' | |||||
| DONE = 'DONE' | |||||
| INVALID = 'INVALID' | |||||
| class ExplainManager: | |||||
| """ExplainManager.""" | |||||
| def __init__(self, summary_base_dir: str): | |||||
| self._summary_base_dir = summary_base_dir | |||||
| self._loader_pool = {} | |||||
| self._deleted_ids = [] | |||||
| self._status = _ExplainManagerStatus.INIT.value | |||||
| self._status_mutex = threading.Lock() | |||||
| self._loader_pool_mutex = threading.Lock() | |||||
| self._max_loader_num = _MAX_LOADER_NUM | |||||
| self._reload_interval = None | |||||
| def _reload_data(self): | |||||
| """periodically load summary from file.""" | |||||
| while True: | |||||
| self._load_data() | |||||
| if not self._reload_interval: | |||||
| break | |||||
| time.sleep(self._reload_interval) | |||||
| def _load_data(self): | |||||
| """Loading the summary in the given base directory.""" | |||||
| logger.info( | |||||
| 'Start to load data, reload interval: %r.', self._reload_interval) | |||||
| with self._status_mutex: | |||||
| if self._status == _ExplainManagerStatus.LOADING.value: | |||||
| logger.info('Current status is %s, will ignore to load data.', | |||||
| self._status) | |||||
| return | |||||
| self._status = _ExplainManagerStatus.LOADING.value | |||||
| self._generate_loaders() | |||||
| self._execute_load_data() | |||||
| if not self._loader_pool: | |||||
| self._status = _ExplainManagerStatus.INVALID.value | |||||
| else: | |||||
| self._status = _ExplainManagerStatus.DONE.value | |||||
| logger.info('Load event data end, status: %r, ' | |||||
| 'and loader pool size is %r', | |||||
| self._status, len(self._loader_pool)) | |||||
| def _update_loader_latest_update_time(self, loader_id, latest_update_time=None): | |||||
| """update the update time of loader of given id.""" | |||||
| if latest_update_time is None: | |||||
| latest_update_time = time.time() | |||||
| self._loader_pool[loader_id].latest_update_time = latest_update_time | |||||
| def _delete_loader(self, loader_id): | |||||
| """delete loader given loader_id""" | |||||
| if self._loader_pool.get(loader_id, None) is not None: | |||||
| self._loader_pool.pop(loader_id) | |||||
| logger.debug('delete loader %s', loader_id) | |||||
| def _add_loader(self, loader): | |||||
| """add loader to the loader_pool.""" | |||||
| if len(self._loader_pool) >= _MAX_LOADER_NUM: | |||||
| delete_num = len(self._loader_pool) - _MAX_LOADER_NUM + 1 | |||||
| sorted_loaders = sorted( | |||||
| self._loader_pool.items(), | |||||
| key=lambda x: x[1].latest_update_time) | |||||
| for index in range(delete_num): | |||||
| delete_loader_id = sorted_loaders[index][0] | |||||
| self._delete_loader(delete_loader_id) | |||||
| self._loader_pool.update({loader.loader_id: loader}) | |||||
| def _deal_loaders(self, latest_loaders): | |||||
| """"update the loader pool.""" | |||||
| with self._loader_pool_mutex: | |||||
| for loader_id, loader in latest_loaders: | |||||
| if self._loader_pool.get(loader_id, None) is None: | |||||
| self._add_loader(loader) | |||||
| continue | |||||
| if (self._loader_pool[loader_id].latest_update_time | |||||
| < loader.latest_update_time): | |||||
| self._update_loader_latest_update_time( | |||||
| loader_id, loader.latest_update_time) | |||||
| @staticmethod | |||||
| def _generate_loader_id(relative_path): | |||||
| """Generate loader id for given path""" | |||||
| loader_id = relative_path | |||||
| return loader_id | |||||
| @staticmethod | |||||
| def _generate_loader_name(relative_path): | |||||
| """Generate_loader name for given path.""" | |||||
| loader_name = relative_path | |||||
| return loader_name | |||||
| def _generate_loader_by_relative_path(self, relative_path: str) -> ExplainJob: | |||||
| """Generate explain job from given relative path.""" | |||||
| current_dir = os.path.realpath(FileHandler.join( | |||||
| self._summary_base_dir, relative_path | |||||
| )) | |||||
| loader_id = self._generate_loader_id(relative_path) | |||||
| loader = ExplainJob( | |||||
| job_id=loader_id, | |||||
| summary_dir=current_dir, | |||||
| create_time=ExplainJob.get_create_time(current_dir), | |||||
| latest_update_time=ExplainJob.get_update_time(current_dir)) | |||||
| return loader | |||||
| def _generate_loaders(self): | |||||
| """Generate job loaders from the summary watcher.""" | |||||
| dir_map_mtime_dict = {} | |||||
| loader_dict = {} | |||||
| min_modify_time = None | |||||
| _, summaries = SummaryWatcher().list_explain_directories( | |||||
| self._summary_base_dir) | |||||
| for item in summaries: | |||||
| relative_path = item.get('relative_path') | |||||
| modify_time = item.get('update_time').timestamp() | |||||
| loader_id = self._generate_loader_id(relative_path) | |||||
| loader = self._loader_pool.get(loader_id, None) | |||||
| if loader is not None and loader.latest_update_time > modify_time: | |||||
| modify_time = loader.latest_update_time | |||||
| if min_modify_time is None: | |||||
| min_modify_time = modify_time | |||||
| if len(dir_map_mtime_dict) < _MAX_LOADER_NUM: | |||||
| if modify_time < min_modify_time: | |||||
| min_modify_time = modify_time | |||||
| dir_map_mtime_dict.update({relative_path: modify_time}) | |||||
| else: | |||||
| if modify_time >= min_modify_time: | |||||
| dir_map_mtime_dict.update({relative_path: modify_time}) | |||||
| sorted_dir_tuple = sorted(dir_map_mtime_dict.items(), | |||||
| key=lambda d: d[1])[-_MAX_LOADER_NUM:] | |||||
| for relative_path, modify_time in sorted_dir_tuple: | |||||
| loader_id = self._generate_loader_id(relative_path) | |||||
| loader = self._generate_loader_by_relative_path(relative_path) | |||||
| loader_dict.update({loader_id: loader}) | |||||
| sorted_loaders = sorted(loader_dict.items(), | |||||
| key=lambda x: x[1].latest_update_time) | |||||
| latest_loaders = sorted_loaders[-_MAX_LOADER_NUM:] | |||||
| self._deal_loaders(latest_loaders) | |||||
| def _execute_loader(self, loader_id): | |||||
| """Execute the data loading.""" | |||||
| try: | |||||
| with self._loader_pool_mutex: | |||||
| loader = self._loader_pool.get(loader_id, None) | |||||
| if loader is None: | |||||
| logger.debug('Loader %r has been deleted, will not load' | |||||
| 'data', loader_id) | |||||
| return | |||||
| loader.load() | |||||
| except MindInsightException as e: | |||||
| logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, e) | |||||
| with self._loader_pool_mutex: | |||||
| self._delete_loader(loader_id) | |||||
| def _execute_load_data(self): | |||||
| """Execute the loader in the pool to load data.""" | |||||
| loader_pool = self._get_snapshot_loader_pool() | |||||
| for loader_id in loader_pool: | |||||
| self._execute_loader(loader_id) | |||||
| def _get_snapshot_loader_pool(self): | |||||
| """Get snapshot of loader_pool.""" | |||||
| with self._loader_pool_mutex: | |||||
| return dict(self._loader_pool) | |||||
| def _check_status_valid(self): | |||||
| """Check manager status.""" | |||||
| if self._status == _ExplainManagerStatus.INIT.value: | |||||
| raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._status) | |||||
| @staticmethod | |||||
| def _check_train_id_valid(train_id: str): | |||||
| """Verify the train_id is valid.""" | |||||
| if not train_id.startswith('./'): | |||||
| logger.warning('train_id does not start with "./"') | |||||
| return False | |||||
| if len(train_id.split('/')) > 2: | |||||
| logger.warning('train_id contains multiple "/"') | |||||
| return False | |||||
| return True | |||||
| def _check_train_job_exist(self, train_id): | |||||
| """Verify thee train_job is existed given train_id.""" | |||||
| if train_id in self._loader_pool: | |||||
| return | |||||
| self._check_train_id_valid(train_id) | |||||
| if SummaryWatcher().is_summary_directory(self._summary_base_dir, train_id): | |||||
| return | |||||
| raise ParamValueError('Can not find the train job in the manager, train_id: %s' % train_id) | |||||
| def _reload_data_again(self): | |||||
| """Reload the data one more time.""" | |||||
| logger.debug('Start to reload data again.') | |||||
| thread = threading.Thread(target=self._load_data, | |||||
| name='reload_data_thread') | |||||
| thread.daemon = False | |||||
| thread.start() | |||||
| def _get_job(self, train_id): | |||||
| """Retrieve train_job given train_id.""" | |||||
| is_reload = False | |||||
| with self._loader_pool_mutex: | |||||
| loader = self._loader_pool.get(train_id, None) | |||||
| if loader is None: | |||||
| relative_path = train_id | |||||
| temp_loader = self._generate_loader_by_relative_path( | |||||
| relative_path) | |||||
| if temp_loader is None: | |||||
| return None | |||||
| self._add_loader(temp_loader) | |||||
| is_reload = True | |||||
| if is_reload: | |||||
| self._reload_data_again() | |||||
| return loader | |||||
| @property | |||||
| def summary_base_dir(self): | |||||
| """Return the base directory for summary records.""" | |||||
| return self._summary_base_dir | |||||
| def get_job(self, train_id): | |||||
| """ | |||||
| Return ExplainJob given train_id. | |||||
| If explain job w.r.t given train_id is not found, None will be returned. | |||||
| Args: | |||||
| train_id (str): The id of expected ExplainJob | |||||
| Return: | |||||
| explain_job | |||||
| """ | |||||
| self._check_status_valid() | |||||
| self._check_train_job_exist(train_id) | |||||
| loader = self._get_job(train_id) | |||||
| if loader is None: | |||||
| return None | |||||
| return loader | |||||
| def start_load_data(self, | |||||
| reload_interval=_MAX_INTERVAL): | |||||
| """ | |||||
| Start threads for loading data. | |||||
| Args: | |||||
| reload_interval (int): interval to reload the summary from file | |||||
| """ | |||||
| self._reload_interval = reload_interval | |||||
| thread = threading.Thread(target=self._reload_data, name='start_load_data_thread') | |||||
| thread.daemon = True | |||||
| thread.start() | |||||
| # wait for data loading | |||||
| time.sleep(1) | |||||
| @@ -28,20 +28,36 @@ from mindinsight.explainer.common.log import logger | |||||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | from mindinsight.datavisual.data_access.file_handler import FileHandler | ||||
| from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser | from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser | ||||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | ||||
| from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Explain | |||||
| from mindinsight.utils.exceptions import UnknownError | from mindinsight.utils.exceptions import UnknownError | ||||
| HEADER_SIZE = 8 | HEADER_SIZE = 8 | ||||
| CRC_STR_SIZE = 4 | CRC_STR_SIZE = 4 | ||||
| MAX_EVENT_STRING = 500000000 | MAX_EVENT_STRING = 500000000 | ||||
| ImageDataContainer = collections.namedtuple('ImageDataContainer', | |||||
| ['image_id', 'image_data', 'ground_truth_label', | |||||
| 'inference', 'explanation', 'status']) | |||||
| BenchmarkContainer = collections.namedtuple('BenchmarkContainer', ['benchmark', 'status']) | BenchmarkContainer = collections.namedtuple('BenchmarkContainer', ['benchmark', 'status']) | ||||
| MetadataContainer = collections.namedtuple('MetadataContainer', ['metadata', 'status']) | MetadataContainer = collections.namedtuple('MetadataContainer', ['metadata', 'status']) | ||||
| class ImageDataContainer: | |||||
| """ | |||||
| Container for image data to allow pickling. | |||||
| Args: | |||||
| explain_message (Explain): Explain proto buffer message. | |||||
| """ | |||||
| def __init__(self, explain_message: Explain): | |||||
| self.image_id = explain_message.image_id | |||||
| self.image_data = explain_message.image_data | |||||
| self.ground_truth_label = explain_message.ground_truth_label | |||||
| self.inference = explain_message.inference | |||||
| self.explanation = explain_message.explanation | |||||
| self.status = explain_message.status | |||||
| class _ExplainParser(_SummaryParser): | class _ExplainParser(_SummaryParser): | ||||
| """The summary file parser.""" | """The summary file parser.""" | ||||
| def __init__(self, summary_dir): | def __init__(self, summary_dir): | ||||
| super(_ExplainParser, self).__init__(summary_dir) | super(_ExplainParser, self).__init__(summary_dir) | ||||
| self._latest_filename = '' | self._latest_filename = '' | ||||
| @@ -165,7 +181,6 @@ class _ExplainParser(_SummaryParser): | |||||
| tensor_value_list.append(tensor_value) | tensor_value_list.append(tensor_value) | ||||
| return field_list, tensor_value_list | return field_list, tensor_value_list | ||||
| @staticmethod | @staticmethod | ||||
| def _add_image_data(tensor_event_value): | def _add_image_data(tensor_event_value): | ||||
| """ | """ | ||||
| @@ -174,17 +189,9 @@ class _ExplainParser(_SummaryParser): | |||||
| Args: | Args: | ||||
| tensor_event_value: the object of Explain message | tensor_event_value: the object of Explain message | ||||
| """ | """ | ||||
| image_data = ImageDataContainer( | |||||
| image_id=tensor_event_value.image_id, | |||||
| image_data=tensor_event_value.image_data, | |||||
| ground_truth_label=tensor_event_value.ground_truth_label, | |||||
| inference=tensor_event_value.inference, | |||||
| explanation=tensor_event_value.explanation, | |||||
| status=tensor_event_value.status | |||||
| ) | |||||
| image_data = ImageDataContainer(tensor_event_value) | |||||
| return image_data | return image_data | ||||
| @staticmethod | @staticmethod | ||||
| def _add_benchmark(tensor_event_value): | def _add_benchmark(tensor_event_value): | ||||
| """ | """ | ||||