# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExplainJob."""
import os
from datetime import datetime
from typing import List, Iterable, Union

from mindinsight.explainer.common.enums import PluginNameEnum
from mindinsight.explainer.common.log import logger
from mindinsight.explainer.manager.explain_parser import _ExplainParser
from mindinsight.explainer.manager.event_parse import EventParser
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError


class ExplainJob:
    """ExplainJob which manages the records in the summary file."""

    def __init__(self, job_id: str, summary_dir: str,
                 create_time: float, latest_update_time: float):
        self._job_id = job_id
        self._summary_dir = summary_dir
        self._parser = _ExplainParser(summary_dir)
        self._event_parser = EventParser(self)
        self._latest_update_time = latest_update_time
        self._create_time = create_time
        self._labels = []
        self._metrics = []
        self._explainers = []
        self._samples_info = {}
        self._labels_info = {}
        self._benchmark = {}
        self._overlay_dict = {}
        self._image_dict = {}

    @property
    def all_classes(self):
        """
        Return a list of label info.

        Returns:
            list[dict], a list of class objects, each object contains:

                - id (int): label id
                - label (str): label name
                - sample_count (int): number of samples for the label
        """
        all_classes_return = []
        for label_id, label_info in self._labels_info.items():
            single_info = {
                'id': label_id,
                'label': label_info['label'],
                'sample_count': len(label_info['sample_ids'])
            }
            all_classes_return.append(single_info)
        return all_classes_return

    @property
    def explainers(self):
        """
        Return a list of explainer names.

        Returns:
            list[str], explainer names.
        """
        return self._explainers

    @property
    def explainer_scores(self):
        """Return evaluation results for every explainer."""
        return list(self._benchmark.values())

    @property
    def sample_count(self):
        """
        Return the total number of samples in the job.

        Returns:
            int, total number of samples.
        """
        return len(self._samples_info)

    @property
    def train_id(self):
        """
        Return the ID of the explain job.

        Returns:
            str, id of the ExplainJob object.
        """
        return self._job_id

    @property
    def metrics(self):
        """
        Return a list of metric names.

        Returns:
            list[str], metric names.
        """
        return self._metrics

    @property
    def min_confidence(self):
        """
        Return the minimum confidence.

        Returns:
            float, minimum confidence.
        """
        return None

    @property
    def create_time(self):
        """
        Return the creation time of the summary file.

        Returns:
            float, creation timestamp.
        """
        return self._create_time

    @property
    def labels(self):
        """Return the labels contained in the job."""
        return self._labels

    @property
    def latest_update_time(self):
        """
        Return the last modification timestamp of the summary file.

        Returns:
            float, last modification timestamp.
        """
        return self._latest_update_time

    @latest_update_time.setter
    def latest_update_time(self, new_time: Union[float, datetime]):
        """
        Update the latest_update_time timestamp manually.

        Args:
            new_time (Union[float, datetime]): updated time for the job.
        """
        if isinstance(new_time, datetime):
            self._latest_update_time = new_time.timestamp()
        elif isinstance(new_time, float):
            self._latest_update_time = new_time
        else:
            raise TypeError('new_time should have type of float or datetime')

    @property
    def loader_id(self):
        """Return the job id."""
        return self._job_id

    @property
    def samples(self):
        """Return the information of all samples in the job."""
        return self._samples_info

    @staticmethod
    def get_create_time(file_path: str) -> float:
        """Return the creation timestamp of the given path."""
        create_time = os.stat(file_path).st_ctime
        return create_time

    @staticmethod
    def get_update_time(file_path: str) -> float:
        """Return the last modification timestamp of the given path."""
        update_time = os.stat(file_path).st_mtime
        return update_time

    @staticmethod
    def _total_score_to_dict(total_scores: Iterable):
        """Transfer a list of benchmark scores to a list of dicts."""
        evaluation_info = []
        for total_score in total_scores:
            metric_result = {'metric': total_score.benchmark_method,
                             'score': total_score.score}
            evaluation_info.append(metric_result)
        return evaluation_info

    @staticmethod
    def _label_score_to_dict(label_scores: Iterable, labels: List[str]):
        """Transfer a list of per-label benchmark scores to a list of dicts."""
        evaluation_info = [{'label': label, 'evaluations': []} for label in labels]
        for label_score in label_scores:
            metric = label_score.benchmark_method
            for i, score in enumerate(label_score.score):
                label_metric_score = {'metric': metric, 'score': score}
                evaluation_info[i]['evaluations'].append(label_metric_score)
        return evaluation_info

    def _initialize_labels_info(self):
        """Initialize a dict for the labels in the job."""
        if self._labels is None:
            logger.warning('No labels are provided in job %s', self._job_id)
            return

        for label_id, label in enumerate(self._labels):
            self._labels_info[label_id] = {'label': label,
                                           'sample_ids': set()}

    def _explanation_to_dict(self, explanation, sample_id):
        """Transfer the explanation from the event to dict storage."""
        explainer_name = explanation.explain_method
        explain_label = explanation.label
        saliency = explanation.heatmap
        saliency_id = '{}_{}_{}'.format(sample_id, explain_label, explainer_name)

        explain_info = {
            'explainer': explainer_name,
            'overlay': saliency_id,
        }
        self._overlay_dict[saliency_id] = saliency
        return explain_info

    def _image_container_to_dict(self, sample_data):
        """Transfer the image container to dict storage."""
        sample_id = sample_data.image_id

        sample_info = {
            'id': sample_id,
            'name': sample_id,
            'labels': [self._labels_info[x]['label']
                       for x in sample_data.ground_truth_label],
            'inferences': []}

        self._image_dict[sample_id] = sample_data.image_data

        ground_truth_labels = list(sample_data.ground_truth_label)
        ground_truth_probs = list(sample_data.inference.ground_truth_prob)
        predicted_labels = list(sample_data.inference.predicted_label)
        predicted_probs = list(sample_data.inference.predicted_prob)

        inference_info = {}
        for label, prob in zip(ground_truth_labels + predicted_labels,
                               ground_truth_probs + predicted_probs):
            inference_info[label] = {
                'label': self._labels_info[label]['label'],
                'confidence': prob,
                'saliency_maps': []}

        if EventParser.is_attr_ready(sample_data, 'explanation'):
            for explanation in sample_data.explanation:
                explanation_dict = self._explanation_to_dict(explanation, sample_id)
                inference_info[explanation.label]['saliency_maps'].append(explanation_dict)

        sample_info['inferences'] = list(inference_info.values())
        return sample_info

    def _import_sample(self, sample):
        """Add the sample object of the given sample id."""
        for label_id in sample.ground_truth_label:
            self._labels_info[label_id]['sample_ids'].add(sample.image_id)

        sample_info = self._image_container_to_dict(sample)
        self._samples_info.update({sample_info['id']: sample_info})

    def retrieve_image(self, image_id: str):
        """
        Retrieve image data from the job given an image_id.

        Returns:
            str, image data in base64 bytes, None if not found.
        """
        return self._image_dict.get(image_id, None)

    def retrieve_overlay(self, overlay_id: str):
        """
        Retrieve a saliency map from the job given an overlay_id.

        Returns:
            str, saliency map data in base64 bytes, None if not found.
        """
        return self._overlay_dict.get(overlay_id, None)

    def get_all_samples(self):
        """
        Return a list of sample information cached in the explain job.

        Returns:
            list[dict], a list of sample objects, each object consists of:

                - id (int): sample id
                - name (str): basename of image
                - labels (list[str]): list of labels
                - inferences (list[dict]): list of inference results
        """
        samples_in_list = list(self._samples_info.values())
        return samples_in_list

    def _is_metadata_empty(self):
        """Check whether the metadata has been loaded."""
        if not self._explainers or not self._metrics or not self._labels:
            return True
        return False

    def _import_data_from_event(self, event):
        """Parse and import data from the event data."""
        tags = {
            'image_id': PluginNameEnum.IMAGE_ID,
            'benchmark': PluginNameEnum.BENCHMARK,
            'metadata': PluginNameEnum.METADATA
        }

        if 'metadata' not in event and self._is_metadata_empty():
            raise ValueError('metadata is empty, should write metadata first '
                             'in the summary.')

        for tag in tags:
            if tag not in event:
                continue

            if tag == PluginNameEnum.IMAGE_ID.value:
                sample_event = event[tag]
                sample_data = self._event_parser.parse_sample(sample_event)
                if sample_data is not None:
                    self._import_sample(sample_data)
                continue

            if tag == PluginNameEnum.BENCHMARK.value:
                benchmark_event = event[tag].benchmark
                benchmark = self._event_parser.parse_benchmark(benchmark_event)
                self._benchmark = benchmark
            elif tag == PluginNameEnum.METADATA.value:
                metadata_event = event[tag].metadata
                metadata = self._event_parser.parse_metadata(metadata_event)
                self._explainers, self._metrics, self._labels = metadata
                self._initialize_labels_info()

    def load(self):
        """Start loading data from the parser."""
        valid_file_names = []
        for filename in FileHandler.list_dir(self._summary_dir):
            if FileHandler.is_file(FileHandler.join(self._summary_dir, filename)):
                valid_file_names.append(filename)

        if not valid_file_names:
            raise TrainJobNotExistError('No summary file found in %s, explain job will be deleted.'
                                        % self._summary_dir)

        is_end = False
        while not is_end:
            is_clean, is_end, event = self._parser.parse_explain(valid_file_names)

            if is_clean:
                logger.info('Summary file in %s updated, clean the loaded data and reload.',
                            self._summary_dir)
                self._clean_job()

            if event:
                self._import_data_from_event(event)

    def _clean_job(self):
        """Clean the cached data in the job."""
        self._latest_update_time = ExplainJob.get_update_time(self._summary_dir)
        self._create_time = ExplainJob.get_update_time(self._summary_dir)
        self._labels.clear()
        self._metrics.clear()
        self._explainers.clear()
        self._samples_info.clear()
        self._labels_info.clear()
        self._benchmark.clear()
        self._overlay_dict.clear()
        self._image_dict.clear()
        self._event_parser.clear()
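

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how an
# ExplainJob is typically constructed from a summary directory and loaded.
# The directory path below is a placeholder assumption for illustration only.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    demo_summary_dir = '/path/to/explain_summary'  # placeholder: directory holding explain summary files
    demo_job = ExplainJob(job_id='explain_job_demo',
                          summary_dir=demo_summary_dir,
                          create_time=ExplainJob.get_create_time(demo_summary_dir),
                          latest_update_time=ExplainJob.get_update_time(demo_summary_dir))
    demo_job.load()  # parse the summary files and populate samples, labels and benchmark results
    print('samples:', demo_job.sample_count)
    print('explainers:', demo_job.explainers)
    print('metrics:', demo_job.metrics)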