The SummaryRecord.add_value() method is extended to record the data of MindExplain.tags/v1.1.0
| @@ -40,6 +40,8 @@ message Event { | |||||
| // Summary data | // Summary data | ||||
| Summary summary = 5; | Summary summary = 5; | ||||
| Explain explain = 6; | |||||
| } | } | ||||
| } | } | ||||
| @@ -101,3 +103,50 @@ message Summary { | |||||
| // Set of values for the summary. | // Set of values for the summary. | ||||
| repeated Value value = 1; | repeated Value value = 1; | ||||
| } | } | ||||
| message Explain { | |||||
| message Inference{ | |||||
| repeated float ground_truth_prob = 1; | |||||
| repeated int32 predicted_label = 2; | |||||
| repeated float predicted_prob = 3; | |||||
| } | |||||
| message Explanation{ | |||||
| optional string explain_method = 1; | |||||
| optional int32 label = 2; | |||||
| optional bytes heatmap = 3; | |||||
| } | |||||
| message Benchmark{ | |||||
| message TotalScore{ | |||||
| optional string benchmark_method = 1; | |||||
| optional float score = 2; | |||||
| } | |||||
| message LabelScore{ | |||||
| repeated float score = 1; | |||||
| optional string benchmark_method = 2; | |||||
| } | |||||
| optional string explain_method = 1; | |||||
| repeated TotalScore total_score = 2; | |||||
| repeated LabelScore label_score = 3; | |||||
| } | |||||
| message Metadata{ | |||||
| repeated string label = 1; | |||||
| repeated string explain_method = 2; | |||||
| repeated string benchmark_method = 3; | |||||
| } | |||||
| optional string image_id = 1; // The Metadata and image id must have one fill in | |||||
| optional bytes image_data = 2; | |||||
| repeated int32 ground_truth_label = 3; | |||||
| optional Inference inference = 4; | |||||
| repeated Explanation explanation = 5; | |||||
| repeated Benchmark benchmark = 6; | |||||
| optional Metadata metadata = 7; | |||||
| optional string status = 8; // enum value: run, end | |||||
| } | |||||
| @@ -26,7 +26,7 @@ from mindspore import log as logger | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.train.summary.summary_record import SummaryRecord | from mindspore.train.summary.summary_record import SummaryRecord | ||||
| from mindspore.train.summary.enum import PluginEnum, ModeEnum | |||||
| from mindspore.train.summary.enums import PluginEnum, ModeEnum | |||||
| from mindspore.train.callback import Callback, ModelCheckpoint | from mindspore.train.callback import Callback, ModelCheckpoint | ||||
| from mindspore.train import lineage_pb2 | from mindspore.train import lineage_pb2 | ||||
| from mindspore.train.callback._dataset_graph import DatasetGraph | from mindspore.train.callback._dataset_graph import DatasetGraph | ||||
| @@ -0,0 +1,48 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Generate the explain event which conform to proto format.""" | |||||
| import time | |||||
| from ..summary_pb2 import Event, Explain | |||||
| def check_explain_proto(explain): | |||||
| """ | |||||
| Package the explain event. | |||||
| Args: | |||||
| explain (Explain): The object of summary_pb2.Explain. | |||||
| """ | |||||
| if not isinstance(explain, Explain): | |||||
| raise TypeError(f'Plugin explainer expects a {Explain.__name__} value.') | |||||
| if not explain.image_id and not explain.metadata.label and not explain.benchmark: | |||||
| raise ValueError(f'The Metadata and image id and benchmark must have one fill in.') | |||||
| def package_explain_event(explain_str): | |||||
| """ | |||||
| Package the explain event. | |||||
| Args: | |||||
| explain_str (string): The serialize string of summary_pb2.Explain. | |||||
| Returns: | |||||
| Event, event object. | |||||
| """ | |||||
| event = Event() | |||||
| event.wall_time = time.time() | |||||
| event.explain.ParseFromString(explain_str) | |||||
| return event.SerializeToString() | |||||
| @@ -21,7 +21,8 @@ import mindspore.log as logger | |||||
| from ._lineage_adapter import serialize_to_lineage_event | from ._lineage_adapter import serialize_to_lineage_event | ||||
| from ._summary_adapter import package_graph_event, package_summary_event | from ._summary_adapter import package_graph_event, package_summary_event | ||||
| from ._summary_writer import LineageWriter, SummaryWriter | |||||
| from ._explain_adapter import package_explain_event | |||||
| from .writer import LineageWriter, SummaryWriter, ExplainWriter | |||||
| try: | try: | ||||
| from multiprocessing import get_context | from multiprocessing import get_context | ||||
| @@ -42,6 +43,8 @@ def _pack_data(datadict, wall_time): | |||||
| elif plugin in ('scalar', 'tensor', 'histogram', 'image'): | elif plugin in ('scalar', 'tensor', 'histogram', 'image'): | ||||
| summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')}) | summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')}) | ||||
| step = data.get('step') | step = data.get('step') | ||||
| elif plugin == 'explainer': | |||||
| result.append([plugin, package_explain_event(data.get('value'))]) | |||||
| if summaries: | if summaries: | ||||
| result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()]) | result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()]) | ||||
| return result | return result | ||||
| @@ -98,6 +101,8 @@ class WriterPool(ctx.Process): | |||||
| self._writers_.append(SummaryWriter(filepath, self._max_file_size)) | self._writers_.append(SummaryWriter(filepath, self._max_file_size)) | ||||
| elif plugin == 'lineage': | elif plugin == 'lineage': | ||||
| self._writers_.append(LineageWriter(filepath, self._max_file_size)) | self._writers_.append(LineageWriter(filepath, self._max_file_size)) | ||||
| elif plugin == 'explainer': | |||||
| self._writers_.append(ExplainWriter(filepath, self._max_file_size)) | |||||
| return self._writers_ | return self._writers_ | ||||
| def _write(self, plugin, data): | def _write(self, plugin, data): | ||||
| @@ -125,7 +130,6 @@ class WriterPool(ctx.Process): | |||||
| Write the event to file. | Write the event to file. | ||||
| Args: | Args: | ||||
| name (str): The key of a specified file. | |||||
| data (Optional[str, Tuple[list, int]]): The data to write. | data (Optional[str, Tuple[list, int]]): The data to write. | ||||
| """ | """ | ||||
| self._queue.put(('WRITE', data)) | self._queue.put(('WRITE', data)) | ||||
| @@ -17,6 +17,7 @@ import atexit | |||||
| import os | import os | ||||
| import re | import re | ||||
| import threading | import threading | ||||
| from collections import defaultdict | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| @@ -24,6 +25,7 @@ from ..._c_expression import Tensor | |||||
| from ..._checkparam import Validator | from ..._checkparam import Validator | ||||
| from .._utils import _check_lineage_value, _check_to_numpy, _make_directory | from .._utils import _check_lineage_value, _check_to_numpy, _make_directory | ||||
| from ._summary_adapter import get_event_file_name, package_graph_event | from ._summary_adapter import get_event_file_name, package_graph_event | ||||
| from ._explain_adapter import check_explain_proto | |||||
| from ._writer_pool import WriterPool | from ._writer_pool import WriterPool | ||||
| # for the moment, this lock is for caution's sake, | # for the moment, this lock is for caution's sake, | ||||
| @@ -55,7 +57,6 @@ def _get_summary_tensor_data(): | |||||
| def _dictlist(): | def _dictlist(): | ||||
| from collections import defaultdict | |||||
| return defaultdict(list) | return defaultdict(list) | ||||
| @@ -133,7 +134,8 @@ class SummaryRecord: | |||||
| self._event_writer = WriterPool(log_dir, | self._event_writer = WriterPool(log_dir, | ||||
| max_file_size, | max_file_size, | ||||
| summary=self.full_file_name, | summary=self.full_file_name, | ||||
| lineage=get_event_file_name(self.prefix, '_lineage')) | |||||
| lineage=get_event_file_name(self.prefix, '_lineage'), | |||||
| explainer=get_event_file_name(self.prefix, '_explain')) | |||||
| _get_summary_tensor_data() | _get_summary_tensor_data() | ||||
| atexit.register(self.close) | atexit.register(self.close) | ||||
| @@ -149,10 +151,11 @@ class SummaryRecord: | |||||
| def set_mode(self, mode): | def set_mode(self, mode): | ||||
| """ | """ | ||||
| Set the mode for the recorder to be aware. The mode is set to 'train' by default. | |||||
| Sets the training phase. Different training phases affect data recording. | |||||
| Args: | Args: | ||||
| mode (str): The mode to be set, which should be 'train' or 'eval'. | |||||
| mode (str): The mode to be set, which should be 'train' or 'eval'. When the mode is 'eval', | |||||
| summary_record will not record the data of summary operators. | |||||
| Raises: | Raises: | ||||
| ValueError: When the mode is not recognized. | ValueError: When the mode is not recognized. | ||||
| @@ -170,29 +173,26 @@ class SummaryRecord: | |||||
| """ | """ | ||||
| Add value to be recorded later. | Add value to be recorded later. | ||||
| When the plugin is 'tensor', 'scalar', 'image' or 'histogram', | |||||
| the name should be the tag name, and the value should be a Tensor. | |||||
| When the plugin is 'graph', the value should be a GraphProto. | |||||
| When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage', | |||||
| or 'custom_lineage_data', the value should be a proto message. | |||||
| Args: | Args: | ||||
| plugin (str): The value of the plugin. | plugin (str): The value of the plugin. | ||||
| name (str): The value of the name. | name (str): The value of the name. | ||||
| value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \ | value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \ | ||||
| The value to store. | The value to store. | ||||
| - The data type of value should be 'GraphProto' when the plugin is 'graph'. | |||||
| - The data type of value should be 'Tensor' when the plugin is 'scalar', 'image', 'tensor' | |||||
| - The data type of value should be 'GraphProto' (see mindspore/ccsrc/anf_ir.proto) object | |||||
| when the plugin is 'graph'. | |||||
| - The data type of value should be 'Tensor' object when the plugin is 'scalar', 'image', 'tensor' | |||||
| or 'histogram'. | or 'histogram'. | ||||
| - The data type of value should be 'TrainLineage' when the plugin is 'train_lineage'. | |||||
| - The data type of value should be 'EvaluationLineage' when the plugin is 'eval_lineage'. | |||||
| - The data type of value should be 'DatasetGraph' when the plugin is 'dataset_graph'. | |||||
| - The data type of value should be 'UserDefinedInfo' when the plugin is 'custom_lineage_data'. | |||||
| - The data type of value should be a 'TrainLineage' object when the plugin is 'train_lineage', | |||||
| see mindspore/ccsrc/lineage.proto. | |||||
| - The data type of value should be a 'EvaluationLineage' object when the plugin is 'eval_lineage', | |||||
| see mindspore/ccsrc/lineage.proto. | |||||
| - The data type of value should be a 'DatasetGraph' object when the plugin is 'dataset_graph', | |||||
| see mindspore/ccsrc/lineage.proto. | |||||
| - The data type of value should be a 'UserDefinedInfo' object when the plugin is 'custom_lineage_data', | |||||
| see mindspore/ccsrc/lineage.proto. | |||||
| - The data type of value should be a 'Explain' object when the plugin is 'explainer', | |||||
| see mindspore/ccsrc/summary.proto. | |||||
| Raises: | Raises: | ||||
| ValueError: When the name is not valid. | ValueError: When the name is not valid. | ||||
| TypeError: When the value is not a Tensor. | TypeError: When the value is not a Tensor. | ||||
| @@ -218,6 +218,9 @@ class SummaryRecord: | |||||
| elif plugin == 'graph': | elif plugin == 'graph': | ||||
| package_graph_event(value) | package_graph_event(value) | ||||
| self._data_pool[plugin].append(dict(value=value)) | self._data_pool[plugin].append(dict(value=value)) | ||||
| elif plugin == 'explainer': | |||||
| check_explain_proto(value) | |||||
| self._data_pool[plugin].append(dict(value=value.SerializeToString())) | |||||
| else: | else: | ||||
| raise ValueError(f'No such plugin of {repr(plugin)}') | raise ValueError(f'No such plugin of {repr(plugin)}') | ||||
| @@ -94,3 +94,12 @@ class LineageWriter(BaseWriter): | |||||
| """Write data to file.""" | """Write data to file.""" | ||||
| if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'): | if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'): | ||||
| super().write(plugin, data) | super().write(plugin, data) | ||||
| class ExplainWriter(BaseWriter): | |||||
| """ExplainWriter for write explain data.""" | |||||
| def write(self, plugin, data): | |||||
| """Write data to file.""" | |||||
| if plugin == 'explainer': | |||||
| super().write(plugin, data) | |||||
| @@ -26,7 +26,7 @@ from mindspore import Tensor | |||||
| from mindspore import Parameter | from mindspore import Parameter | ||||
| from mindspore.train.callback import SummaryCollector | from mindspore.train.callback import SummaryCollector | ||||
| from mindspore.train.callback import _InternalCallbackParam | from mindspore.train.callback import _InternalCallbackParam | ||||
| from mindspore.train.summary.enum import ModeEnum, PluginEnum | |||||
| from mindspore.train.summary.enums import ModeEnum, PluginEnum | |||||
| from mindspore.train.summary import SummaryRecord | from mindspore.train.summary import SummaryRecord | ||||
| from mindspore.nn import Cell | from mindspore.nn import Cell | ||||
| from mindspore.nn.optim.optimizer import Optimizer | from mindspore.nn.optim.optimizer import Optimizer | ||||