
add ExplainJob, ExplainManager and EventParser for XAI backend

tags/v1.1.0
YuhanShi53 5 years ago
parent commit 225bf22efb
4 changed files with 909 additions and 13 deletions
  1. +180 -0 mindinsight/explainer/manager/event_parse.py
  2. +395 -0 mindinsight/explainer/manager/explain_job.py
  3. +314 -0 mindinsight/explainer/manager/explain_manager.py
  4. +20 -13 mindinsight/explainer/manager/explain_parser.py

+180 -0 mindinsight/explainer/manager/event_parse.py

@@ -0,0 +1,180 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""EventParser for summary event."""
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple

from mindinsight.explainer.common.enums import PluginNameEnum
from mindinsight.explainer.common.log import logger
from mindinsight.utils.exceptions import UnknownError

_IMAGE_DATA_TAGS = {
    'image_data': PluginNameEnum.IMAGE_DATA.value,
    'ground_truth_label': PluginNameEnum.GROUND_TRUTH_LABEL.value,
    'inference': PluginNameEnum.INFERENCE.value,
    'explanation': PluginNameEnum.EXPLANATION.value
}


class EventParser:
    """Parser for event data."""

    def __init__(self, job):
        self._job = job
        self._sample_pool = {}

    def clear(self):
        """Clear the loaded data."""
        self._sample_pool.clear()

    def parse_metadata(self, metadata) -> Tuple[List, List, List]:
        """Parse the metadata event."""
        explainers = list(metadata.explain_method)
        metrics = list(metadata.benchmark_method)
        labels = list(metadata.label)
        return explainers, metrics, labels

    def parse_benchmark(self, benchmark) -> Dict:
        """Parse the benchmark event."""
        imported_benchmark = {}
        for explainer_result in benchmark:
            explainer = explainer_result.explain_method
            total_score = explainer_result.total_score
            label_score = explainer_result.label_score

            explainer_benchmark = {
                'explainer': explainer,
                'evaluations': EventParser._total_score_to_dict(total_score),
                'class_scores': EventParser._label_score_to_dict(label_score, self._job.labels)
            }
            imported_benchmark[explainer] = explainer_benchmark
        return imported_benchmark

    def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]:
        """Parse the sample event."""
        sample_id = sample.image_id

        if sample_id not in self._sample_pool:
            self._sample_pool[sample_id] = sample
            return None

        for tag in _IMAGE_DATA_TAGS:
            try:
                if tag == PluginNameEnum.INFERENCE.value:
                    self._parse_inference(sample, sample_id)
                elif tag == PluginNameEnum.EXPLANATION.value:
                    self._parse_explanation(sample, sample_id)
                else:
                    self._parse_sample_info(sample, sample_id, tag)
            except UnknownError as ex:
                logger.warning("Failed to parse %s data from an image-related event,"
                               " detail: %r", tag, str(ex))
                continue

        if EventParser._is_sample_data_complete(self._sample_pool[sample_id]):
            return self._sample_pool.pop(sample_id)
        if EventParser._is_ready_for_display(self._sample_pool[sample_id]):
            return self._sample_pool[sample_id]
        return None

    def _parse_inference(self, event, sample_id):
        """Parse the inference event."""
        self._sample_pool[sample_id].inference.ground_truth_prob.extend(
            event.inference.ground_truth_prob)
        self._sample_pool[sample_id].inference.predicted_label.extend(
            event.inference.predicted_label)
        self._sample_pool[sample_id].inference.predicted_prob.extend(
            event.inference.predicted_prob)

    def _parse_explanation(self, event, sample_id):
        """Parse the explanation event."""
        if event.explanation:
            for explanation_item in event.explanation:
                new_explanation = self._sample_pool[sample_id].explanation.add()
                new_explanation.explain_method = explanation_item.explain_method
                new_explanation.label = explanation_item.label
                new_explanation.heatmap = explanation_item.heatmap

    def _parse_sample_info(self, event, sample_id, tag):
        """Parse the event containing image info."""
        if not getattr(self._sample_pool[sample_id], tag):
            setattr(self._sample_pool[sample_id], tag, getattr(event, tag))

    @staticmethod
    def _total_score_to_dict(total_scores: Iterable):
        """Convert a list of benchmark total scores to a list of dicts."""
        evaluation_info = []
        for total_score in total_scores:
            metric_result = {
                'metric': total_score.benchmark_method,
                'score': total_score.score}
            evaluation_info.append(metric_result)
        return evaluation_info

    @staticmethod
    def _label_score_to_dict(label_scores: Iterable, labels: List[str]):
        """Convert a list of label-wise benchmark scores to a list of dicts."""
        evaluation_info = [{'label': label, 'evaluations': []}
                           for label in labels]
        for label_score in label_scores:
            metric = label_score.benchmark_method
            for i, score in enumerate(label_score.score):
                label_metric_score = {
                    'metric': metric,
                    'score': score}
                evaluation_info[i]['evaluations'].append(label_metric_score)
        return evaluation_info

    @staticmethod
    def _is_sample_data_complete(image_container: namedtuple) -> bool:
        """Check whether the sample data is completely loaded."""
        required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference', 'explanation']
        for attr in required_attrs:
            if not EventParser.is_attr_ready(image_container, attr):
                return False
        return True

    @staticmethod
    def _is_ready_for_display(image_container: namedtuple) -> bool:
        """
        Check whether the image_container is ready for frontend display.

        Args:
            image_container (namedtuple): container consisting of sample data

        Returns:
            bool, whether the image_container is ready for display
        """
        required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference']
        for attr in required_attrs:
            if not EventParser.is_attr_ready(image_container, attr):
                return False
        return True

    @staticmethod
    def is_attr_ready(image_container: namedtuple, attr: str) -> bool:
        """
        Check whether the given attribute is ready in image_container.

        Args:
            image_container (namedtuple): container consisting of sample data
            attr (str): attribute to check

        Returns:
            bool, whether the attr is ready
        """
        if getattr(image_container, attr, False):
            return True
        return False
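
Note (an aside, not part of the commit): parse_sample() follows an accumulate-until-complete pattern. The first event for an image_id is cached and None is returned; later events for the same id are merged into the cached container, and the container is only handed back once enough fields are present. Below is a self-contained sketch of that pattern, using a plain dict in place of the protobuf container; the SamplePool class and its field handling are simplified stand-ins, not the commit's code.

# Hedged sketch of the accumulate-until-complete pattern behind
# EventParser.parse_sample(); the dict-based container is hypothetical.
REQUIRED = ('image_id', 'image_data', 'ground_truth_label', 'inference')

class SamplePool:
    def __init__(self):
        self._pool = {}

    def feed(self, event):
        """Merge a partial event; return the sample once it is displayable."""
        sample = self._pool.setdefault(event['image_id'], {})
        for key, value in event.items():
            if value is not None and sample.get(key) is None:
                sample[key] = value                      # first writer wins
        if all(sample.get(key) is not None for key in REQUIRED):
            return self._pool.pop(event['image_id'])     # complete: emit and drop
        return None

pool = SamplePool()
first = pool.feed({'image_id': 1, 'image_data': b'raw', 'ground_truth_label': [0]})
assert first is None                                     # inference still missing
done = pool.feed({'image_id': 1, 'inference': {'predicted_label': [2]}})
assert done is not None and done['ground_truth_label'] == [0]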

+395 -0 mindinsight/explainer/manager/explain_job.py

@@ -0,0 +1,395 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExplainJob."""

import os
from datetime import datetime
from typing import List, Iterable, Union

from mindinsight.explainer.common.enums import PluginNameEnum
from mindinsight.explainer.common.log import logger
from mindinsight.explainer.manager.explain_parser import _ExplainParser
from mindinsight.explainer.manager.event_parse import EventParser
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError


class ExplainJob:
    """ExplainJob manages the records in a summary file."""

    def __init__(self,
                 job_id: str,
                 summary_dir: str,
                 create_time: float,
                 latest_update_time: float):

        self._job_id = job_id
        self._summary_dir = summary_dir
        self._parser = _ExplainParser(summary_dir)

        self._event_parser = EventParser(self)
        self._latest_update_time = latest_update_time
        self._create_time = create_time
        self._labels = []
        self._metrics = []
        self._explainers = []
        self._samples_info = {}
        self._labels_info = {}
        self._benchmark = {}
        self._overlay_dict = {}
        self._image_dict = {}

    @property
    def all_classes(self):
        """
        Return a list of label info.

        Returns:
            class_objs (List[ClassObj]): a list of class objects, each of which
            contains:

                - id (int): label id
                - label (str): label name
                - sample_count (int): number of samples for the label
        """
        all_classes_return = []
        for label_id, label_info in self._labels_info.items():
            single_info = {'id': label_id,
                           'label': label_info['label'],
                           'sample_count': len(label_info['sample_ids'])}
            all_classes_return.append(single_info)
        return all_classes_return

    @property
    def explainers(self):
        """
        Return a list of explainer names.

        Returns:
            list[str], explainer names
        """
        return self._explainers

    @property
    def explainer_scores(self):
        """Return evaluation results for every explainer."""
        return list(self._benchmark.values())

    @property
    def sample_count(self):
        """
        Return the total number of samples in the job.

        Returns:
            int, total number of samples
        """
        return len(self._samples_info)

    @property
    def train_id(self):
        """
        Return the ID of the explain job.

        Returns:
            str, id of the ExplainJob object
        """
        return self._job_id

    @property
    def metrics(self):
        """
        Return a list of metric names.

        Returns:
            list[str], metric names
        """
        return self._metrics

    @property
    def min_confidence(self):
        """
        Return the minimum confidence.

        Returns:
            float, minimum confidence, or None if it is not set
        """
        return None

    @property
    def create_time(self):
        """
        Return the creation time of the summary file.

        Returns:
            float, creation timestamp
        """
        return self._create_time

    @property
    def labels(self):
        """Return the labels contained in the job."""
        return self._labels

    @property
    def latest_update_time(self):
        """
        Return the last modification timestamp of the summary file.

        Returns:
            float, last modification timestamp
        """
        return self._latest_update_time

    @latest_update_time.setter
    def latest_update_time(self, new_time: Union[float, datetime]):
        """
        Update the latest_update_time timestamp manually.

        Args:
            new_time (Union[float, datetime]): updated time for the job
        """
        if isinstance(new_time, datetime):
            self._latest_update_time = new_time.timestamp()
        elif isinstance(new_time, (int, float)):
            self._latest_update_time = new_time
        else:
            raise TypeError('new_time should have type of float or datetime')

    @property
    def loader_id(self):
        """Return the job id."""
        return self._job_id

    @property
    def samples(self):
        """Return the information of all samples in the job."""
        return self._samples_info

    @staticmethod
    def get_create_time(file_path: str) -> float:
        """Return the creation timestamp of the given path."""
        create_time = os.stat(file_path).st_ctime
        return create_time

    @staticmethod
    def get_update_time(file_path: str) -> float:
        """Return the modification timestamp of the given path."""
        update_time = os.stat(file_path).st_mtime
        return update_time

    @staticmethod
    def _total_score_to_dict(total_scores: Iterable):
        """Convert a list of benchmark total scores to a list of dicts."""
        evaluation_info = []
        for total_score in total_scores:
            metric_result = {'metric': total_score.benchmark_method,
                             'score': total_score.score}
            evaluation_info.append(metric_result)
        return evaluation_info

    @staticmethod
    def _label_score_to_dict(label_scores: Iterable, labels: List[str]):
        """Convert a list of label-wise benchmark scores to a list of dicts."""
        evaluation_info = [{'label': label, 'evaluations': []}
                           for label in labels]
        for label_score in label_scores:
            metric = label_score.benchmark_method
            for i, score in enumerate(label_score.score):
                label_metric_score = {'metric': metric,
                                      'score': score}
                evaluation_info[i]['evaluations'].append(label_metric_score)
        return evaluation_info

    def _initialize_labels_info(self):
        """Initialize a dict for the labels in the job."""
        if self._labels is None:
            logger.warning('No labels are provided in job %s', self._job_id)
            return

        for label_id, label in enumerate(self._labels):
            self._labels_info[label_id] = {'label': label,
                                           'sample_ids': set()}

    def _explanation_to_dict(self, explanation, sample_id):
        """Convert an explanation event to dict storage."""
        explainer_name = explanation.explain_method
        explain_label = explanation.label
        saliency = explanation.heatmap
        saliency_id = '{}_{}_{}'.format(
            sample_id, explain_label, explainer_name)
        explain_info = {
            'explainer': explainer_name,
            'overlay': saliency_id,
        }
        self._overlay_dict[saliency_id] = saliency
        return explain_info

    def _image_container_to_dict(self, sample_data):
        """Convert an image container to dict storage."""
        sample_id = sample_data.image_id

        sample_info = {
            'id': sample_id,
            'name': sample_id,
            'labels': [self._labels_info[x]['label']
                       for x in sample_data.ground_truth_label],
            'inferences': []}
        self._image_dict[sample_id] = sample_data.image_data

        ground_truth_labels = list(sample_data.ground_truth_label)
        ground_truth_probs = list(sample_data.inference.ground_truth_prob)
        predicted_labels = list(sample_data.inference.predicted_label)
        predicted_probs = list(sample_data.inference.predicted_prob)

        inference_info = {}
        for label, prob in zip(
                ground_truth_labels + predicted_labels,
                ground_truth_probs + predicted_probs):
            inference_info[label] = {
                'label': self._labels_info[label]['label'],
                'confidence': prob,
                'saliency_maps': []}

        if EventParser.is_attr_ready(sample_data, 'explanation'):
            for explanation in sample_data.explanation:
                explanation_dict = self._explanation_to_dict(
                    explanation, sample_id)
                inference_info[explanation.label]['saliency_maps'].append(explanation_dict)

        sample_info['inferences'] = list(inference_info.values())
        return sample_info

    def _import_sample(self, sample):
        """Add a sample object for the given sample id."""
        for label_id in sample.ground_truth_label:
            self._labels_info[label_id]['sample_ids'].add(sample.image_id)

        sample_info = self._image_container_to_dict(sample)
        self._samples_info.update({sample_info['id']: sample_info})

    def retrieve_image(self, image_id: str):
        """
        Retrieve image data from the job given an image_id.

        Returns:
            str, image data in base64 bytes
        """
        return self._image_dict.get(image_id, None)

    def retrieve_overlay(self, overlay_id: str):
        """
        Retrieve a saliency map from the job given an overlay_id.

        Returns:
            str, saliency map data in base64 bytes
        """
        return self._overlay_dict.get(overlay_id, None)

    def get_all_samples(self):
        """
        Return a list of sample information cached in the explain job.

        Returns:
            sample_list (List[SampleObj]): a list of sample objects, each of
            which consists of:

                - id (int): sample id
                - name (str): basename of the image
                - labels (list[str]): list of labels
                - inferences (list[dict]): inference results
        """
        samples_in_list = list(self._samples_info.values())
        return samples_in_list

    def _is_metadata_empty(self):
        """Check whether metadata has been loaded."""
        if not self._explainers or not self._metrics or not self._labels:
            return True
        return False

    def _import_data_from_event(self, event):
        """Parse and import data from the event data."""
        tags = {
            'image_id': PluginNameEnum.IMAGE_ID,
            'benchmark': PluginNameEnum.BENCHMARK,
            'metadata': PluginNameEnum.METADATA
        }

        if 'metadata' not in event and self._is_metadata_empty():
            raise ValueError('metadata is empty, metadata should be written'
                             ' first in the summary.')
        for tag in tags:
            if tag not in event:
                continue

            if tag == PluginNameEnum.IMAGE_ID.value:
                sample_event = event[tag]
                sample_data = self._event_parser.parse_sample(sample_event)
                if sample_data is not None:
                    self._import_sample(sample_data)
                continue

            if tag == PluginNameEnum.BENCHMARK.value:
                benchmark_event = event[tag].benchmark
                benchmark = self._event_parser.parse_benchmark(benchmark_event)
                self._benchmark = benchmark

            elif tag == PluginNameEnum.METADATA.value:
                metadata_event = event[tag].metadata
                metadata = self._event_parser.parse_metadata(metadata_event)
                self._explainers, self._metrics, self._labels = metadata
                self._initialize_labels_info()

    def load(self):
        """Start loading data from the parser."""
        valid_file_names = []
        for filename in FileHandler.list_dir(self._summary_dir):
            if FileHandler.is_file(
                    FileHandler.join(self._summary_dir, filename)):
                valid_file_names.append(filename)

        if not valid_file_names:
            raise TrainJobNotExistError('No summary file found in %s, explain job will be deleted.'
                                        % self._summary_dir)

        is_end = False
        while not is_end:
            is_clean, is_end, event = self._parser.parse_explain(valid_file_names)

            if is_clean:
                logger.info('Summary file in %s changed, clean the loaded data and reload.',
                            self._summary_dir)
                self._clean_job()

            if event:
                self._import_data_from_event(event)

    def _clean_job(self):
        """Clean the cached data in the job."""
        self._latest_update_time = ExplainJob.get_update_time(self._summary_dir)
        self._create_time = ExplainJob.get_create_time(self._summary_dir)
        self._labels.clear()
        self._metrics.clear()
        self._explainers.clear()
        self._samples_info.clear()
        self._labels_info.clear()
        self._benchmark.clear()
        self._overlay_dict.clear()
        self._image_dict.clear()
        self._event_parser.clear()
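
Note (an aside, not part of the commit): a hedged sketch of the intended ExplainJob lifecycle. The directory path and train id below are hypothetical, and in practice ExplainManager (next file) constructs and drives jobs rather than callers doing it directly.

# Hedged usage sketch; assumes MindInsight is installed and a summary
# file exists under the hypothetical directory below.
from mindinsight.explainer.manager.explain_job import ExplainJob

summary_dir = '/tmp/summary/explain_run'    # hypothetical directory
job = ExplainJob(job_id='./explain_run',
                 summary_dir=summary_dir,
                 create_time=ExplainJob.get_create_time(summary_dir),
                 latest_update_time=ExplainJob.get_update_time(summary_dir))
job.load()                                  # parse events until the file ends

for sample in job.get_all_samples():        # dicts produced by _import_sample()
    print(sample['id'], sample['labels'])
    image_bytes = job.retrieve_image(sample['id'])   # base64 payload, or None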

+314 -0 mindinsight/explainer/manager/explain_manager.py

@@ -0,0 +1,314 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExplainManager."""

import os
import threading
import time

from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import BaseEnum
from mindinsight.explainer.common.log import logger
from mindinsight.explainer.manager.explain_job import ExplainJob
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.utils.exceptions import MindInsightException, ParamValueError

_MAX_LOADER_NUM = 3
_MAX_INTERVAL = 3


class _ExplainManagerStatus(BaseEnum):
    """Manager status."""
    INIT = 'INIT'
    LOADING = 'LOADING'
    DONE = 'DONE'
    INVALID = 'INVALID'


class ExplainManager:
    """ExplainManager."""

    def __init__(self, summary_base_dir: str):
        self._summary_base_dir = summary_base_dir
        self._loader_pool = {}
        self._deleted_ids = []
        self._status = _ExplainManagerStatus.INIT.value
        self._status_mutex = threading.Lock()
        self._loader_pool_mutex = threading.Lock()
        self._max_loader_num = _MAX_LOADER_NUM
        self._reload_interval = None

    def _reload_data(self):
        """Periodically load summaries from file."""
        while True:
            self._load_data()

            if not self._reload_interval:
                break
            time.sleep(self._reload_interval)

    def _load_data(self):
        """Load the summaries in the given base directory."""
        logger.info(
            'Start to load data, reload interval: %r.', self._reload_interval)

        with self._status_mutex:
            if self._status == _ExplainManagerStatus.LOADING.value:
                logger.info('Current status is %s, will skip loading data.',
                            self._status)
                return

            self._status = _ExplainManagerStatus.LOADING.value

        self._generate_loaders()
        self._execute_load_data()

        if not self._loader_pool:
            self._status = _ExplainManagerStatus.INVALID.value
        else:
            self._status = _ExplainManagerStatus.DONE.value

        logger.info('Load event data end, status: %r, '
                    'and loader pool size is %r',
                    self._status, len(self._loader_pool))

    def _update_loader_latest_update_time(self, loader_id, latest_update_time=None):
        """Update the update time of the loader with the given id."""
        if latest_update_time is None:
            latest_update_time = time.time()
        self._loader_pool[loader_id].latest_update_time = latest_update_time

    def _delete_loader(self, loader_id):
        """Delete the loader with the given loader_id."""
        if self._loader_pool.get(loader_id, None) is not None:
            self._loader_pool.pop(loader_id)
            logger.debug('Delete loader %s', loader_id)

    def _add_loader(self, loader):
        """Add a loader to the loader pool."""
        if len(self._loader_pool) >= _MAX_LOADER_NUM:
            delete_num = len(self._loader_pool) - _MAX_LOADER_NUM + 1
            sorted_loaders = sorted(
                self._loader_pool.items(),
                key=lambda x: x[1].latest_update_time)

            for index in range(delete_num):
                delete_loader_id = sorted_loaders[index][0]
                self._delete_loader(delete_loader_id)
        self._loader_pool.update({loader.loader_id: loader})

    def _deal_loaders(self, latest_loaders):
        """Update the loader pool."""
        with self._loader_pool_mutex:
            for loader_id, loader in latest_loaders:
                if self._loader_pool.get(loader_id, None) is None:
                    self._add_loader(loader)
                    continue

                if (self._loader_pool[loader_id].latest_update_time
                        < loader.latest_update_time):
                    self._update_loader_latest_update_time(
                        loader_id, loader.latest_update_time)

    @staticmethod
    def _generate_loader_id(relative_path):
        """Generate a loader id for the given path."""
        loader_id = relative_path
        return loader_id

    @staticmethod
    def _generate_loader_name(relative_path):
        """Generate a loader name for the given path."""
        loader_name = relative_path
        return loader_name

    def _generate_loader_by_relative_path(self, relative_path: str) -> ExplainJob:
        """Generate an explain job from the given relative path."""
        current_dir = os.path.realpath(FileHandler.join(
            self._summary_base_dir, relative_path
        ))
        loader_id = self._generate_loader_id(relative_path)
        loader = ExplainJob(
            job_id=loader_id,
            summary_dir=current_dir,
            create_time=ExplainJob.get_create_time(current_dir),
            latest_update_time=ExplainJob.get_update_time(current_dir))
        return loader

    def _generate_loaders(self):
        """Generate job loaders from the summary watcher."""
        dir_map_mtime_dict = {}
        loader_dict = {}
        min_modify_time = None
        _, summaries = SummaryWatcher().list_explain_directories(
            self._summary_base_dir)

        for item in summaries:
            relative_path = item.get('relative_path')
            modify_time = item.get('update_time').timestamp()
            loader_id = self._generate_loader_id(relative_path)

            loader = self._loader_pool.get(loader_id, None)
            if loader is not None and loader.latest_update_time > modify_time:
                modify_time = loader.latest_update_time

            if min_modify_time is None:
                min_modify_time = modify_time

            if len(dir_map_mtime_dict) < _MAX_LOADER_NUM:
                if modify_time < min_modify_time:
                    min_modify_time = modify_time
                dir_map_mtime_dict.update({relative_path: modify_time})
            else:
                if modify_time >= min_modify_time:
                    dir_map_mtime_dict.update({relative_path: modify_time})

        sorted_dir_tuple = sorted(dir_map_mtime_dict.items(),
                                  key=lambda d: d[1])[-_MAX_LOADER_NUM:]

        for relative_path, modify_time in sorted_dir_tuple:
            loader_id = self._generate_loader_id(relative_path)
            loader = self._generate_loader_by_relative_path(relative_path)
            loader_dict.update({loader_id: loader})

        sorted_loaders = sorted(loader_dict.items(),
                                key=lambda x: x[1].latest_update_time)
        latest_loaders = sorted_loaders[-_MAX_LOADER_NUM:]
        self._deal_loaders(latest_loaders)

    def _execute_loader(self, loader_id):
        """Execute the data loading."""
        try:
            with self._loader_pool_mutex:
                loader = self._loader_pool.get(loader_id, None)
                if loader is None:
                    logger.debug('Loader %r has been deleted, will not load'
                                 ' data', loader_id)
                    return
            loader.load()

        except MindInsightException as e:
            logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, e)
            with self._loader_pool_mutex:
                self._delete_loader(loader_id)

    def _execute_load_data(self):
        """Execute the loaders in the pool to load data."""
        loader_pool = self._get_snapshot_loader_pool()
        for loader_id in loader_pool:
            self._execute_loader(loader_id)

    def _get_snapshot_loader_pool(self):
        """Get a snapshot of the loader pool."""
        with self._loader_pool_mutex:
            return dict(self._loader_pool)

    def _check_status_valid(self):
        """Check the manager status."""
        if self._status == _ExplainManagerStatus.INIT.value:
            raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._status)

    @staticmethod
    def _check_train_id_valid(train_id: str):
        """Verify the train_id is valid."""
        if not train_id.startswith('./'):
            logger.warning('train_id does not start with "./"')
            return False

        if len(train_id.split('/')) > 2:
            logger.warning('train_id contains multiple "/"')
            return False
        return True

    def _check_train_job_exist(self, train_id):
        """Verify that a train job exists for the given train_id."""
        if train_id in self._loader_pool:
            return
        self._check_train_id_valid(train_id)
        if SummaryWatcher().is_summary_directory(self._summary_base_dir, train_id):
            return
        raise ParamValueError('Can not find the train job in the manager, train_id: %s' % train_id)

    def _reload_data_again(self):
        """Reload the data one more time."""
        logger.debug('Start to reload data again.')
        thread = threading.Thread(target=self._load_data,
                                  name='reload_data_thread')
        thread.daemon = False
        thread.start()

    def _get_job(self, train_id):
        """Retrieve a train job given train_id."""
        is_reload = False
        with self._loader_pool_mutex:
            loader = self._loader_pool.get(train_id, None)

            if loader is None:
                relative_path = train_id
                temp_loader = self._generate_loader_by_relative_path(
                    relative_path)

                if temp_loader is None:
                    return None

                self._add_loader(temp_loader)
                is_reload = True

        if is_reload:
            self._reload_data_again()
        return loader

    @property
    def summary_base_dir(self):
        """Return the base directory for summary records."""
        return self._summary_base_dir

    def get_job(self, train_id):
        """
        Return the ExplainJob given train_id.

        If no explain job is found for the given train_id, None is returned.

        Args:
            train_id (str): The id of the expected ExplainJob.

        Returns:
            ExplainJob, the explain job, or None if it is not found
        """
        self._check_status_valid()
        self._check_train_job_exist(train_id)

        loader = self._get_job(train_id)
        if loader is None:
            return None
        return loader

    def start_load_data(self,
                        reload_interval=_MAX_INTERVAL):
        """
        Start threads for loading data.

        Args:
            reload_interval (int): interval in seconds to reload the summaries from file
        """
        self._reload_interval = reload_interval

        thread = threading.Thread(target=self._reload_data, name='start_load_data_thread')
        thread.daemon = True
        thread.start()

        # wait for data loading
        time.sleep(1)
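
Note (an aside, not part of the commit): a hedged sketch of wiring ExplainManager into a service; the base directory and train id are hypothetical, and get_job() assumes the corresponding summary directory exists.

# Hedged usage sketch under the assumptions above.
from mindinsight.explainer.manager.explain_manager import ExplainManager

manager = ExplainManager(summary_base_dir='/tmp/summary')   # hypothetical path
manager.start_load_data(reload_interval=3)   # spawns a daemon reload thread
job = manager.get_job('./explain_run')       # train_id relative to the base dir
if job is not None:
    print(job.sample_count, job.explainers, job.explainer_scores)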

+20 -13 mindinsight/explainer/manager/explain_parser.py

@@ -28,20 +28,36 @@ from mindinsight.explainer.common.log import logger
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Explain
from mindinsight.utils.exceptions import UnknownError

HEADER_SIZE = 8
CRC_STR_SIZE = 4
MAX_EVENT_STRING = 500000000
ImageDataContainer = collections.namedtuple('ImageDataContainer',
                                            ['image_id', 'image_data', 'ground_truth_label',
                                             'inference', 'explanation', 'status'])
BenchmarkContainer = collections.namedtuple('BenchmarkContainer', ['benchmark', 'status'])
MetadataContainer = collections.namedtuple('MetadataContainer', ['metadata', 'status'])


class ImageDataContainer:
    """
    Container for image data to allow pickling.

    Args:
        explain_message (Explain): Explain proto buffer message.
    """

    def __init__(self, explain_message: Explain):
        self.image_id = explain_message.image_id
        self.image_data = explain_message.image_data
        self.ground_truth_label = explain_message.ground_truth_label
        self.inference = explain_message.inference
        self.explanation = explain_message.explanation
        self.status = explain_message.status
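
Aside (an observation, not part of the diff): besides the pickling motivation in the docstring, the class has a practical property the namedtuple it replaces lacked. Instances are mutable, which matches how EventParser fills fields incrementally across partial events (for example via setattr() in _parse_sample_info()). A minimal, self-contained illustration:

# Namedtuples reject field assignment; a plain class permits the
# in-place updates the parser relies on.
from collections import namedtuple

Point = namedtuple('Point', ['x'])
try:
    Point(x=1).x = 2            # raises: namedtuple fields are read-only
except AttributeError:
    pass

class Box:
    def __init__(self, x):
        self.x = x

box = Box(1)
box.x = 2                       # fine: attributes can be reassigned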


class _ExplainParser(_SummaryParser):
    """The summary file parser."""

    def __init__(self, summary_dir):
        super(_ExplainParser, self).__init__(summary_dir)
        self._latest_filename = ''
@@ -165,7 +181,6 @@ class _ExplainParser(_SummaryParser):
            tensor_value_list.append(tensor_value)
        return field_list, tensor_value_list


    @staticmethod
    def _add_image_data(tensor_event_value):
        """
@@ -174,17 +189,9 @@ class _ExplainParser(_SummaryParser):
        Args:
            tensor_event_value: the object of Explain message
        """
        image_data = ImageDataContainer(
            image_id=tensor_event_value.image_id,
            image_data=tensor_event_value.image_data,
            ground_truth_label=tensor_event_value.ground_truth_label,
            inference=tensor_event_value.inference,
            explanation=tensor_event_value.explanation,
            status=tensor_event_value.status
        )
        image_data = ImageDataContainer(tensor_event_value)
        return image_data


    @staticmethod
    def _add_benchmark(tensor_event_value):
        """

