The SummaryRecord.add_value() method is extended to record the data of MindExplain.tags/v1.1.0
| @@ -40,6 +40,8 @@ message Event { | |||
| // Summary data | |||
| Summary summary = 5; | |||
| Explain explain = 6; | |||
| } | |||
| } | |||
| @@ -101,3 +103,50 @@ message Summary { | |||
| // Set of values for the summary. | |||
| repeated Value value = 1; | |||
| } | |||
| message Explain { | |||
| message Inference{ | |||
| repeated float ground_truth_prob = 1; | |||
| repeated int32 predicted_label = 2; | |||
| repeated float predicted_prob = 3; | |||
| } | |||
| message Explanation{ | |||
| optional string explain_method = 1; | |||
| optional int32 label = 2; | |||
| optional bytes heatmap = 3; | |||
| } | |||
| message Benchmark{ | |||
| message TotalScore{ | |||
| optional string benchmark_method = 1; | |||
| optional float score = 2; | |||
| } | |||
| message LabelScore{ | |||
| repeated float score = 1; | |||
| optional string benchmark_method = 2; | |||
| } | |||
| optional string explain_method = 1; | |||
| repeated TotalScore total_score = 2; | |||
| repeated LabelScore label_score = 3; | |||
| } | |||
| message Metadata{ | |||
| repeated string label = 1; | |||
| repeated string explain_method = 2; | |||
| repeated string benchmark_method = 3; | |||
| } | |||
| optional string image_id = 1; // The Metadata and image id must have one fill in | |||
| optional bytes image_data = 2; | |||
| repeated int32 ground_truth_label = 3; | |||
| optional Inference inference = 4; | |||
| repeated Explanation explanation = 5; | |||
| repeated Benchmark benchmark = 6; | |||
| optional Metadata metadata = 7; | |||
| optional string status = 8; // enum value: run, end | |||
| } | |||
| @@ -26,7 +26,7 @@ from mindspore import log as logger | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.train.summary.summary_record import SummaryRecord | |||
| from mindspore.train.summary.enum import PluginEnum, ModeEnum | |||
| from mindspore.train.summary.enums import PluginEnum, ModeEnum | |||
| from mindspore.train.callback import Callback, ModelCheckpoint | |||
| from mindspore.train import lineage_pb2 | |||
| from mindspore.train.callback._dataset_graph import DatasetGraph | |||
| @@ -0,0 +1,48 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Generate the explain event which conform to proto format.""" | |||
| import time | |||
| from ..summary_pb2 import Event, Explain | |||
| def check_explain_proto(explain): | |||
| """ | |||
| Package the explain event. | |||
| Args: | |||
| explain (Explain): The object of summary_pb2.Explain. | |||
| """ | |||
| if not isinstance(explain, Explain): | |||
| raise TypeError(f'Plugin explainer expects a {Explain.__name__} value.') | |||
| if not explain.image_id and not explain.metadata.label and not explain.benchmark: | |||
| raise ValueError(f'The Metadata and image id and benchmark must have one fill in.') | |||
| def package_explain_event(explain_str): | |||
| """ | |||
| Package the explain event. | |||
| Args: | |||
| explain_str (string): The serialize string of summary_pb2.Explain. | |||
| Returns: | |||
| Event, event object. | |||
| """ | |||
| event = Event() | |||
| event.wall_time = time.time() | |||
| event.explain.ParseFromString(explain_str) | |||
| return event.SerializeToString() | |||
| @@ -21,7 +21,8 @@ import mindspore.log as logger | |||
| from ._lineage_adapter import serialize_to_lineage_event | |||
| from ._summary_adapter import package_graph_event, package_summary_event | |||
| from ._summary_writer import LineageWriter, SummaryWriter | |||
| from ._explain_adapter import package_explain_event | |||
| from .writer import LineageWriter, SummaryWriter, ExplainWriter | |||
| try: | |||
| from multiprocessing import get_context | |||
| @@ -42,6 +43,8 @@ def _pack_data(datadict, wall_time): | |||
| elif plugin in ('scalar', 'tensor', 'histogram', 'image'): | |||
| summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')}) | |||
| step = data.get('step') | |||
| elif plugin == 'explainer': | |||
| result.append([plugin, package_explain_event(data.get('value'))]) | |||
| if summaries: | |||
| result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()]) | |||
| return result | |||
| @@ -98,6 +101,8 @@ class WriterPool(ctx.Process): | |||
| self._writers_.append(SummaryWriter(filepath, self._max_file_size)) | |||
| elif plugin == 'lineage': | |||
| self._writers_.append(LineageWriter(filepath, self._max_file_size)) | |||
| elif plugin == 'explainer': | |||
| self._writers_.append(ExplainWriter(filepath, self._max_file_size)) | |||
| return self._writers_ | |||
| def _write(self, plugin, data): | |||
| @@ -125,7 +130,6 @@ class WriterPool(ctx.Process): | |||
| Write the event to file. | |||
| Args: | |||
| name (str): The key of a specified file. | |||
| data (Optional[str, Tuple[list, int]]): The data to write. | |||
| """ | |||
| self._queue.put(('WRITE', data)) | |||
| @@ -17,6 +17,7 @@ import atexit | |||
| import os | |||
| import re | |||
| import threading | |||
| from collections import defaultdict | |||
| from mindspore import log as logger | |||
| @@ -24,6 +25,7 @@ from ..._c_expression import Tensor | |||
| from ..._checkparam import Validator | |||
| from .._utils import _check_lineage_value, _check_to_numpy, _make_directory | |||
| from ._summary_adapter import get_event_file_name, package_graph_event | |||
| from ._explain_adapter import check_explain_proto | |||
| from ._writer_pool import WriterPool | |||
| # for the moment, this lock is for caution's sake, | |||
| @@ -55,7 +57,6 @@ def _get_summary_tensor_data(): | |||
| def _dictlist(): | |||
| from collections import defaultdict | |||
| return defaultdict(list) | |||
| @@ -133,7 +134,8 @@ class SummaryRecord: | |||
| self._event_writer = WriterPool(log_dir, | |||
| max_file_size, | |||
| summary=self.full_file_name, | |||
| lineage=get_event_file_name(self.prefix, '_lineage')) | |||
| lineage=get_event_file_name(self.prefix, '_lineage'), | |||
| explainer=get_event_file_name(self.prefix, '_explain')) | |||
| _get_summary_tensor_data() | |||
| atexit.register(self.close) | |||
| @@ -149,10 +151,11 @@ class SummaryRecord: | |||
| def set_mode(self, mode): | |||
| """ | |||
| Set the mode for the recorder to be aware. The mode is set to 'train' by default. | |||
| Sets the training phase. Different training phases affect data recording. | |||
| Args: | |||
| mode (str): The mode to be set, which should be 'train' or 'eval'. | |||
| mode (str): The mode to be set, which should be 'train' or 'eval'. When the mode is 'eval', | |||
| summary_record will not record the data of summary operators. | |||
| Raises: | |||
| ValueError: When the mode is not recognized. | |||
| @@ -170,29 +173,26 @@ class SummaryRecord: | |||
| """ | |||
| Add value to be recorded later. | |||
| When the plugin is 'tensor', 'scalar', 'image' or 'histogram', | |||
| the name should be the tag name, and the value should be a Tensor. | |||
| When the plugin is 'graph', the value should be a GraphProto. | |||
| When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage', | |||
| or 'custom_lineage_data', the value should be a proto message. | |||
| Args: | |||
| plugin (str): The value of the plugin. | |||
| name (str): The value of the name. | |||
| value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \ | |||
| The value to store. | |||
| - The data type of value should be 'GraphProto' when the plugin is 'graph'. | |||
| - The data type of value should be 'Tensor' when the plugin is 'scalar', 'image', 'tensor' | |||
| - The data type of value should be 'GraphProto' (see mindspore/ccsrc/anf_ir.proto) object | |||
| when the plugin is 'graph'. | |||
| - The data type of value should be 'Tensor' object when the plugin is 'scalar', 'image', 'tensor' | |||
| or 'histogram'. | |||
| - The data type of value should be 'TrainLineage' when the plugin is 'train_lineage'. | |||
| - The data type of value should be 'EvaluationLineage' when the plugin is 'eval_lineage'. | |||
| - The data type of value should be 'DatasetGraph' when the plugin is 'dataset_graph'. | |||
| - The data type of value should be 'UserDefinedInfo' when the plugin is 'custom_lineage_data'. | |||
| - The data type of value should be a 'TrainLineage' object when the plugin is 'train_lineage', | |||
| see mindspore/ccsrc/lineage.proto. | |||
| - The data type of value should be a 'EvaluationLineage' object when the plugin is 'eval_lineage', | |||
| see mindspore/ccsrc/lineage.proto. | |||
| - The data type of value should be a 'DatasetGraph' object when the plugin is 'dataset_graph', | |||
| see mindspore/ccsrc/lineage.proto. | |||
| - The data type of value should be a 'UserDefinedInfo' object when the plugin is 'custom_lineage_data', | |||
| see mindspore/ccsrc/lineage.proto. | |||
| - The data type of value should be a 'Explain' object when the plugin is 'explainer', | |||
| see mindspore/ccsrc/summary.proto. | |||
| Raises: | |||
| ValueError: When the name is not valid. | |||
| TypeError: When the value is not a Tensor. | |||
| @@ -218,6 +218,9 @@ class SummaryRecord: | |||
| elif plugin == 'graph': | |||
| package_graph_event(value) | |||
| self._data_pool[plugin].append(dict(value=value)) | |||
| elif plugin == 'explainer': | |||
| check_explain_proto(value) | |||
| self._data_pool[plugin].append(dict(value=value.SerializeToString())) | |||
| else: | |||
| raise ValueError(f'No such plugin of {repr(plugin)}') | |||
| @@ -94,3 +94,12 @@ class LineageWriter(BaseWriter): | |||
| """Write data to file.""" | |||
| if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'): | |||
| super().write(plugin, data) | |||
| class ExplainWriter(BaseWriter): | |||
| """ExplainWriter for write explain data.""" | |||
| def write(self, plugin, data): | |||
| """Write data to file.""" | |||
| if plugin == 'explainer': | |||
| super().write(plugin, data) | |||
| @@ -26,7 +26,7 @@ from mindspore import Tensor | |||
| from mindspore import Parameter | |||
| from mindspore.train.callback import SummaryCollector | |||
| from mindspore.train.callback import _InternalCallbackParam | |||
| from mindspore.train.summary.enum import ModeEnum, PluginEnum | |||
| from mindspore.train.summary.enums import ModeEnum, PluginEnum | |||
| from mindspore.train.summary import SummaryRecord | |||
| from mindspore.nn import Cell | |||
| from mindspore.nn.optim.optimizer import Optimizer | |||