Merge pull request !1935 from 李鸿章/policy_writertags/v0.5.0-beta
| @@ -79,6 +79,7 @@ if (ENABLE_DUMP_PROTO) | |||
| file(GLOB_RECURSE PROTO_PY RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "utils/anf_ir.proto" | |||
| "utils/summary.proto" | |||
| "utils/lineage.proto" | |||
| "utils/checkpoint.proto" | |||
| ) | |||
| ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY}) | |||
| @@ -0,0 +1,129 @@ | |||
| // Copyright 2020 Huawei Technologies Co., Ltd. | |||
| // | |||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||
| // you may not use this file except in compliance with the License. | |||
| // You may obtain a copy of the License at | |||
| // | |||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software | |||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| // See the License for the specific language governing permissions and | |||
| // limitations under the License. | |||
| syntax = "proto2"; | |||
| package mindspore.irpb; | |||
| option cc_enable_arenas = true; | |||
| // Event Protocol buffer, Top define | |||
| message LineageEvent { | |||
| // Timestamp | |||
| required double wall_time = 1; | |||
| // The step of train. | |||
| optional int64 step = 2; | |||
| oneof what { | |||
| // An event file was started, with the specified version. | |||
| // Now version is "Mindspore.Event:1" | |||
| string version = 3; | |||
| // Train lineage | |||
| TrainLineage train_lineage = 6; | |||
| // Evaluation lineage | |||
| EvaluationLineage evaluation_lineage = 7; | |||
| // Dataset graph | |||
| DatasetGraph dataset_graph = 9; | |||
| // User defined info | |||
| UserDefinedInfo user_defined_info = 10; | |||
| } | |||
| } | |||
| // User defined info | |||
| message UserDefinedInfo{ | |||
| // repeated user defined info | |||
| repeated UserDefinedInfo user_info = 1; | |||
| // key/value which contains both scalar and dict | |||
| map<string, UserDefinedInfo> map_dict = 2; | |||
| map<string, int32> map_int32 = 3; | |||
| map<string, string> map_str = 4; | |||
| map<string, double> map_double = 5; | |||
| } | |||
| // TrainLineage records infos of a train. | |||
| message TrainLineage{ | |||
| message HyperParameters{ | |||
| optional string optimizer = 1; | |||
| optional float learning_rate = 2; | |||
| optional string loss_function = 3; | |||
| optional int32 epoch = 4; | |||
| optional string parallel_mode = 5; | |||
| optional int32 device_num = 6; | |||
| optional int32 batch_size = 8; | |||
| } | |||
| message TrainDataset{ | |||
| optional string train_dataset_path = 1; | |||
| optional int32 train_dataset_size = 2; | |||
| } | |||
| message Algorithm{ | |||
| optional string network = 1; | |||
| optional float loss = 2; | |||
| } | |||
| message Model{ | |||
| optional string path = 3; | |||
| optional int64 size = 4; | |||
| } | |||
| optional HyperParameters hyper_parameters = 1; | |||
| optional TrainDataset train_dataset = 2; | |||
| optional Algorithm algorithm = 3; | |||
| optional Model model = 4; | |||
| } | |||
| //EvalLineage records infos of evaluation. | |||
| message EvaluationLineage{ | |||
| message ValidDataset{ | |||
| optional string valid_dataset_path = 1; | |||
| optional int32 valid_dataset_size = 2; | |||
| } | |||
| optional string metric = 2; | |||
| optional ValidDataset valid_dataset = 3; | |||
| } | |||
| // DatasetGraph | |||
| message DatasetGraph { | |||
| repeated DatasetGraph children = 1; | |||
| optional OperationParameter parameter = 2; | |||
| repeated Operation operations = 3; | |||
| optional Operation sampler = 4; | |||
| } | |||
| message Operation { | |||
| optional OperationParameter operationParam = 1; | |||
| repeated int32 size = 2; | |||
| repeated float weights = 3; | |||
| } | |||
| message OperationParameter{ | |||
| map<string, string> mapStr = 1; | |||
| map<string, StrList> mapStrList = 2; | |||
| map<string, bool> mapBool = 3; | |||
| map<string, int32> mapInt = 4; | |||
| map<string, double> mapDouble = 5; | |||
| } | |||
| message StrList { | |||
| repeated string strValue = 1; | |||
| } | |||
| @@ -21,6 +21,7 @@ from mindspore.common import dtype as mstype | |||
| from mindspore import log as logger | |||
| from mindspore.common.api import _executor | |||
| from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo | |||
| def _convert_type(types): | |||
| """ | |||
| @@ -193,3 +194,38 @@ def _to_full_shapes(shapes, device_num): | |||
| new_shape += (item,) | |||
| new_shapes.append(new_shape) | |||
| return new_shapes | |||
| def _check_to_numpy(plugin, tensor): | |||
| """Check the tensor and return a numpy.ndarray.""" | |||
| np_value = tensor.asnumpy() | |||
| if plugin == 'scalar': | |||
| if np_value.size == 1: | |||
| return np_value | |||
| raise ValueError('The tensor holds more than one value, but the scalar plugin expects on value.') | |||
| if plugin == 'image': | |||
| if np_value.ndim == 4: | |||
| return np_value | |||
| raise ValueError('The tensor seems not to hold a valid image.') | |||
| if plugin in ('tensor', 'histogram'): | |||
| if np_value.ndim > 0: | |||
| return np_value | |||
| raise ValueError('The tensor should not be empty.') | |||
| return np_value | |||
def _check_lineage_value(plugin, value):
    """Validate that a lineage value matches the proto type its plugin requires.

    Unknown plugins pass through silently, matching the original behavior.

    Raises:
        TypeError: If `value` is not an instance of the expected proto message.
    """
    expected_types = {
        'dataset_graph': DatasetGraph,
        'eval_lineage': EvaluationLineage,
        'train_lineage': TrainLineage,
        'custom_lineage_data': UserDefinedInfo,
    }
    prototype = expected_types.get(plugin)
    if prototype is not None and not isinstance(value, prototype):
        raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.')
| @@ -1,88 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Writes events to disk in a logdir.""" | |||
| import os | |||
| import stat | |||
| from collections import deque | |||
| from multiprocessing import Pool, Process, Queue, cpu_count | |||
| from ..._c_expression import EventWriter_ | |||
| from ._summary_adapter import package_summary_event | |||
def _pack(result, step):
    """Serialize one step's summary records into the event wire format."""
    return package_summary_event(result, step).SerializeToString()
class EventWriter(Process):
    """
    Creates a `EventWriter` and write event to file.

    The writer runs in a separate process: callers enqueue
    ('WRITE'|'FLUSH'|'END') commands and the child process serializes
    pending events with a worker pool and writes them in submission order.

    Args:
        filepath (str): Summary event file path and file name.
        flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120.
    """

    def __init__(self, filepath: str, flush_interval: int) -> None:
        super().__init__()
        # flush_interval is accepted for API compatibility but not used here.
        _ = flush_interval
        with open(filepath, 'w'):
            # Restrict the event file to owner read/write only.
            os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
        self._writer = EventWriter_(filepath)
        # Bounded queue so producers block when the writer falls behind.
        self._queue = Queue(cpu_count() * 2)
        self.start()

    def run(self):
        # Child-process loop; exits only on an 'END' command.
        with Pool(min(cpu_count(), 32)) as pool:
            deq = deque()
            while True:
                # Write finished serialization results, oldest first.
                while deq and deq[0].ready():
                    self._writer.Write(deq.popleft().get())

                # NOTE(review): when the queue is empty this loop spins
                # without sleeping — confirm the CPU cost is acceptable.
                if not self._queue.empty():
                    action, data = self._queue.get()
                    if action == 'WRITE':
                        if not isinstance(data, (str, bytes)):
                            # Raw (result, step) tuples are packed asynchronously.
                            deq.append(pool.apply_async(_pack, data))
                        else:
                            # Pre-serialized payloads are written directly.
                            self._writer.Write(data)
                    elif action == 'FLUSH':
                        self._writer.Flush()
                    elif action == 'END':
                        break
            # Drain any in-flight packing jobs before shutting down.
            for res in deq:
                self._writer.Write(res.get())

            self._writer.Shut()

    def write(self, data) -> None:
        """
        Write the event to file.

        Args:
            data (Optional[str, Tuple[list, int]]): The data to write.
        """
        self._queue.put(('WRITE', data))

    def flush(self):
        """Flush the writer."""
        self._queue.put(('FLUSH', None))

    def close(self) -> None:
        """Close the writer."""
        self._queue.put(('END', None))
        self.join()
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Generate the lineage event which conform to proto format.""" | |||
| import time | |||
| from ..lineage_pb2 import LineageEvent | |||
def serialize_to_lineage_event(name, value):
    """Wrap a serialized lineage message into a timestamped LineageEvent."""
    lineage_event = LineageEvent()
    lineage_event.wall_time = time.time()
    # Select the oneof sub-message for `name` and fill it from the raw bytes.
    target = _get_lineage_content(name, lineage_event)
    target.ParseFromString(value)
    return lineage_event.SerializeToString()
| def _get_lineage_content(name, event): | |||
| if name == 'dataset_graph': | |||
| return event.dataset_graph | |||
| if name == 'eval_lineage': | |||
| return event.evaluation_lineage | |||
| if name == 'train_lineage': | |||
| return event.train_lineage | |||
| if name == 'custom_lineage_data': | |||
| return event.user_defined_info | |||
| raise KeyError(f'No such field in LineageEvent') | |||
| @@ -13,7 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Generate the summary event which conform to proto format.""" | |||
| import socket | |||
| import platform | |||
| import time | |||
| import numpy as np | |||
| @@ -51,7 +51,7 @@ def get_event_file_name(prefix, suffix): | |||
| _check_str_by_regular(suffix) | |||
| file_name = "" | |||
| time_second = str(int(time.time())) | |||
| hostname = socket.gethostname() | |||
| hostname = platform.node() | |||
| if prefix is not None: | |||
| file_name = file_name + prefix | |||
| @@ -0,0 +1,79 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Writes events to disk in a logdir.""" | |||
| import os | |||
| import stat | |||
| from ..._c_expression import EventWriter_ | |||
| from ._summary_adapter import package_init_event | |||
class BaseWriter:
    """Base class that lazily opens an event file and exposes write/flush/close."""

    def __init__(self, filepath) -> None:
        self._filepath = filepath
        self._writer: EventWriter_ = None

    def init_writer(self):
        """Hook for subclasses to write metadata right after the file is opened."""

    @property
    def writer(self) -> EventWriter_:
        """Return the underlying EventWriter_, creating it on first access."""
        if self._writer is None:
            # Create the file first so its permissions can be restricted
            # to the owner before any data is written.
            with open(self._filepath, 'w'):
                os.chmod(self._filepath, stat.S_IWUSR | stat.S_IRUSR)
            self._writer = EventWriter_(self._filepath)
            self.init_writer()
        return self._writer

    def write(self, plugin, mode, data):
        """Write data to file."""
        raise NotImplementedError()

    def flush(self):
        """Flush pending events if the file was ever opened."""
        if self._writer is not None:
            self._writer.Flush()

    def close(self):
        """Close the underlying writer if the file was ever opened."""
        if self._writer is not None:
            self._writer.Shut()
class SummaryWriter(BaseWriter):
    """Writer that records summary and graph events."""

    def init_writer(self):
        """Write the initial metadata event as soon as the file is opened."""
        self.writer.Write(package_init_event().SerializeToString())

    def write(self, plugin, mode, data):
        """Write serialized data, accepting only summary-related plugins."""
        if plugin == 'summary' or plugin == 'graph':
            self.writer.Write(data)
class LineageWriter(BaseWriter):
    """Writer that records lineage events."""

    def write(self, plugin, mode, data):
        """Write serialized data, accepting only lineage-related plugins."""
        lineage_plugins = ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data')
        if plugin in lineage_plugins:
            self.writer.Write(data)
| @@ -0,0 +1,114 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Write events to disk in a base directory.""" | |||
| import os | |||
| from collections import deque | |||
| from multiprocessing import Pool, Process, Queue, cpu_count | |||
| from ._lineage_adapter import serialize_to_lineage_event | |||
| from ._summary_adapter import package_graph_event, package_summary_event | |||
| from ._summary_writer import SummaryWriter, LineageWriter | |||
| def _pack_data(datadict): | |||
| """Pack data according to which plugin.""" | |||
| result = [] | |||
| summaries, step, mode = [], None, None | |||
| for plugin, datalist in datadict.items(): | |||
| for data in datalist: | |||
| if plugin == 'graph': | |||
| result.append([plugin, data.get('mode'), package_graph_event(data.get('value')).SerializeToString()]) | |||
| elif plugin in ('train_lineage', 'eval_lineage', 'custom_lineage_data', 'dataset_graph'): | |||
| result.append([plugin, data.get('mode'), serialize_to_lineage_event(plugin, data.get('value'))]) | |||
| elif plugin in ('scalar', 'tensor', 'histogram', 'image'): | |||
| summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')}) | |||
| step = data.get('step') | |||
| mode = data.get('mode') | |||
| if summaries: | |||
| result.append(['summary', mode, package_summary_event(summaries, step).SerializeToString()]) | |||
| return result | |||
class WriterPool(Process):
    """
    Use a set of pooled resident processes for writing a list of file.

    The pool runs as a child process: the parent enqueues
    ('WRITE'|'FLUSH'|'END') actions, and the child packs the data with a
    multiprocessing.Pool and writes each result to every registered writer.

    Args:
        base_dir (str): The base directory to hold all the files.
        filedict (dict): The mapping from plugin short name ('summary' or
            'lineage') to the event file name within `base_dir`.
    """

    def __init__(self, base_dir, **filedict) -> None:
        super().__init__()
        self._base_dir, self._filedict = base_dir, filedict
        # Bounded queue so producers back off when packing falls behind.
        self._queue = Queue(cpu_count() * 2)
        self.start()

    def run(self):
        # Child-process loop; writers are created here so the file handles
        # live in the writing process.
        writers = self._get_writers()

        with Pool() as pool:
            deq = deque()
            while True:
                # Drain finished packing jobs in submission order.
                while deq and deq[0].ready():
                    for plugin, mode, data in deq.popleft().get():
                        for writer in writers:
                            writer.write(plugin, mode, data)

                # NOTE(review): when the queue is empty this loop spins
                # without sleeping — confirm the CPU cost is acceptable.
                if not self._queue.empty():
                    action, data = self._queue.get()
                    if action == 'WRITE':
                        deq.append(pool.apply_async(_pack_data, (data,)))
                    elif action == 'FLUSH':
                        for writer in writers:
                            writer.flush()
                    elif action == 'END':
                        break
            # Flush any in-flight packing jobs before closing.
            for result in deq:
                for plugin, mode, data in result.get():
                    for writer in writers:
                        writer.write(plugin, mode, data)

            for writer in writers:
                writer.close()

    def _get_writers(self):
        # Map each registered plugin name to its concrete writer; unknown
        # names in the filedict are silently ignored.
        writers = []
        for plugin, filename in self._filedict.items():
            filepath = os.path.join(self._base_dir, filename)
            if plugin == 'summary':
                writers.append(SummaryWriter(filepath))
            elif plugin == 'lineage':
                writers.append(LineageWriter(filepath))
        return writers

    def write(self, data) -> None:
        """
        Write the event to file.

        Args:
            data (dict): The plugin-to-datalist mapping to pack and write.
        """
        self._queue.put(('WRITE', data))

    def flush(self):
        """Flush the writer and sync data to disk."""
        self._queue.put(('FLUSH', None))

    def close(self) -> None:
        """Close the writer."""
        self._queue.put(('END', None))
        self.join()
| @@ -21,9 +21,9 @@ from mindspore import log as logger | |||
| from ..._c_expression import Tensor | |||
| from ..._checkparam import _check_str_by_regular | |||
| from .._utils import _make_directory | |||
| from ._event_writer import EventWriter | |||
| from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event | |||
| from .._utils import _make_directory, _check_to_numpy, _check_lineage_value | |||
| from ._summary_adapter import get_event_file_name, package_graph_event | |||
| from ._writer_pool import WriterPool | |||
| # for the moment, this lock is for caution's sake, | |||
| # there are actually no any concurrencies happening. | |||
| @@ -53,16 +53,20 @@ def _get_summary_tensor_data(): | |||
| return data | |||
| def _dictlist(): | |||
| from collections import defaultdict | |||
| return defaultdict(list) | |||
| class SummaryRecord: | |||
| """ | |||
| SummaryRecord is used to record the summary value. | |||
| SummaryRecord is used to record the summary data and lineage data. | |||
| Note: | |||
| The API will create an event file in a given directory and add summaries and events to it. | |||
| It writes the event log to a file by executing the record method. In addition, | |||
| if the SummaryRecord object is created and the summary operator is used in the network, | |||
| even if the record method is not called, the event in the cache will be written to the | |||
| file at the end of execution. Make sure to close the SummaryRecord object at the end. | |||
| The API will create a summary file and a lineage file lazily in a given directory and writes data to them. | |||
| It writes the data to files by executing the record method. In addition to record the data bubbled up from | |||
| the network by defining the summary operators, SummaryRecord also supports to record extra data which | |||
| can be added by calling add_value. Finally, make sure to close the SummaryRecord object at the end. | |||
| Args: | |||
| log_dir (str): The log_dir is a directory location to save the summary. | |||
| @@ -89,10 +93,12 @@ class SummaryRecord: | |||
| file_suffix="_MS", | |||
| network=None): | |||
| self._event_writer, self._closed = None, False | |||
| self._closed, self._mode = False, 'train' | |||
| self._data_pool = _dictlist() | |||
| _check_str_by_regular(file_prefix) | |||
| _check_str_by_regular(file_suffix) | |||
| self.log_path = _make_directory(log_dir) | |||
| if not isinstance(queue_max_size, int) or not isinstance(flush_time, int): | |||
| @@ -123,16 +129,12 @@ class SummaryRecord: | |||
| except Exception as ex: | |||
| raise RuntimeError(ex) | |||
| def _init_event_writer(self): | |||
| """Init event writer and write metadata.""" | |||
| event_writer = EventWriter(self.full_file_name, self.flush_time) | |||
| event_writer.write(package_init_event().SerializeToString()) | |||
| return event_writer | |||
| self._event_writer = WriterPool(log_dir, | |||
| summary=self.full_file_name, | |||
| lineage=get_event_file_name('events', '_lineage')) | |||
| def __enter__(self): | |||
| """Enter the context manager.""" | |||
| if not self._event_writer: | |||
| self._event_writer = self._init_event_writer() | |||
| if self._closed: | |||
| raise ValueError('SummaryRecord has been closed.') | |||
| return self | |||
| @@ -141,6 +143,76 @@ class SummaryRecord: | |||
| """Exit the context manager.""" | |||
| self.close() | |||
def set_mode(self, mode):
    """
    Set the mode for the recorder to be aware. The mode is set 'train' by default.

    Args:
        mode (str): The mode to set, which should be 'train' or 'eval'.

    Raises:
        ValueError: When the mode is not recognized.

    Examples:
        >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
        >>>     summary_record.set_mode('eval')
    """
    if mode not in ('train', 'eval'):
        raise ValueError(f'{repr(mode)} is not a recognized mode.')
    self._mode = mode
def add_value(self, plugin, name, value):
    """
    Add a value to be recorded later on.

    When the plugin is 'tensor', 'scalar', 'image' or 'histogram',
    the name should be the tag name, and the value should be a Tensor.

    When the plugin is 'graph', the value should be a GraphProto.

    When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage',
    or 'custom_lineage_data', the value should be a proto message.

    Args:
        plugin (str): The plugin for the value.
        name (str): The name for the value.
        value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
            The value to store.

            - GraphProto: The 'value' should be a serialized string of this type when the plugin is 'graph'.
            - Tensor: The 'value' should be this type when the plugin is 'scalar', 'image', 'tensor' or 'histogram'.
            - TrainLineage: The 'value' should be this type when the plugin is 'train_lineage'.
            - EvaluationLineage: The 'value' should be this type when the plugin is 'eval_lineage'.
            - DatasetGraph: The 'value' should be this type when the plugin is 'dataset_graph'.
            - UserDefinedInfo: The 'value' should be this type when the plugin is 'custom_lineage_data'.

    Raises:
        ValueError: When the name is not valid or the plugin is unknown.
        TypeError: When the value is not of the type the plugin requires.

    Examples:
        >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
        >>>     summary_record.add_value('scalar', 'loss', Tensor(0.1))
    """
    if plugin in ('tensor', 'scalar', 'image', 'histogram'):
        if not name or not isinstance(name, str):
            raise ValueError(f'{repr(name)} is not a valid tag name.')
        if not isinstance(value, Tensor):
            raise TypeError(f'Expect the value to be Tensor, but got {type(value).__name__}')
        # Convert and validate eagerly so bad values fail here, not at write time.
        np_value = _check_to_numpy(plugin, value)
        self._data_pool[plugin].append(dict(tag=name, mode=self._mode, value=np_value))
    elif plugin in ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data'):
        _check_lineage_value(plugin, value)
        # Serialize now; the writer process only handles raw bytes.
        self._data_pool[plugin].append(dict(mode=self._mode, value=value.SerializeToString()))
    elif plugin == 'graph':
        # Called for validation only; the raw proto is stored as-is.
        package_graph_event(value)
        self._data_pool[plugin].append(dict(mode=self._mode, value=value))
    else:
        raise ValueError(f'No such plugin of {repr(plugin)}')
| def record(self, step, train_network=None): | |||
| """ | |||
| Record the summary. | |||
| @@ -149,12 +221,12 @@ class SummaryRecord: | |||
| step (int): Represents training step number. | |||
| train_network (Cell): The network that called the callback. | |||
| Returns: | |||
| bool, whether the record process is successful or not. | |||
| Examples: | |||
| >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | |||
| >>> summary_record.record(step=2) | |||
| Returns: | |||
| bool, whether the record process is successful or not. | |||
| """ | |||
| logger.info("SummaryRecord step is %r.", step) | |||
| if self._closed: | |||
| @@ -163,10 +235,6 @@ class SummaryRecord: | |||
| if not isinstance(step, int) or isinstance(step, bool): | |||
| raise ValueError("`step` should be int") | |||
| # Set the current summary of train step | |||
| if not self._event_writer: | |||
| self._event_writer = self._init_event_writer() | |||
| logger.warning('SummaryRecord should be used as context manager for a with statement.') | |||
| if self.network is not None and not self.has_graph: | |||
| graph_proto = self.network.get_func_graph_proto() | |||
| if graph_proto is None and train_network is not None: | |||
| @@ -174,39 +242,48 @@ class SummaryRecord: | |||
| if graph_proto is None: | |||
| logger.error("Failed to get proto for graph") | |||
| else: | |||
| self._event_writer.write(package_graph_event(graph_proto).SerializeToString()) | |||
| self._event_writer.write({'graph': [{'step': step, 'value': graph_proto}]}) | |||
| self.has_graph = True | |||
| if not _summary_tensor_cache: | |||
| return True | |||
| data = _get_summary_tensor_data() | |||
| if not data: | |||
| logger.info("The step(%r) does not have record data.", step) | |||
| return False | |||
| if self.queue_max_size > 0 and len(data) > self.queue_max_size: | |||
| logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data), | |||
| self.queue_max_size) | |||
| # process the data | |||
| result = self._data_convert(data) | |||
| if not result: | |||
| logger.error("The step(%r) summary data is invalid.", step) | |||
| return False | |||
| self._event_writer.write((result, step)) | |||
| logger.debug("Send the summary data to scheduler for saving, step = %d", step) | |||
| if self._mode == 'train': | |||
| self._add_summary_tensor_data() | |||
| self._event_writer.write(self._consume_data_pool(step)) | |||
| return True | |||
def _add_summary_tensor_data(self):
    """Collect tensors bubbled up from the network and queue them as values."""
    summary_data = _get_summary_tensor_data()
    if not summary_data:
        logger.debug(f'No summary data bubbled from the network.')
    for name, tensor in summary_data.items():
        tag, plugin = SummaryRecord._parse_from(name)
        if tag is None and plugin is None:
            # Malformed names are skipped rather than raising.
            logger.warning("The name(%r) is invalid, expected 'TAG[:TYPE]'.", name)
        else:
            self.add_value(plugin.lower(), tag, tensor)
def _consume_data_pool(self, step):
    """Stamp every pooled value with `step` and hand the pool over, resetting it."""
    data_pool = self._data_pool
    # Reset first so the pool is fresh even if stamping raises.
    self._data_pool = _dictlist()
    for values in data_pool.values():
        for value in values:
            value['step'] = step
    return data_pool
| @property | |||
| def log_dir(self): | |||
| """ | |||
| Get the full path of the log file. | |||
| Returns: | |||
| str, the full path of log file. | |||
| Examples: | |||
| >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | |||
| >>> print(summary_record.log_dir) | |||
| Returns: | |||
| String, the full path of log file. | |||
| """ | |||
| return self.full_file_name | |||
| @@ -235,46 +312,19 @@ class SummaryRecord: | |||
| """ | |||
| if not self._closed and self._event_writer: | |||
| # event writer flush and close | |||
| logger.info('Please wait it may take quite some time to finish writing and closing.') | |||
| self._event_writer.close() | |||
| self._closed = True | |||
| def __del__(self) -> None: | |||
| self.close() | |||
def _data_convert(self, summary):
    """Convert the cached summary dict into a list of records, or None if invalid."""
    result = []
    for name, data in summary.items():
        summary_tag, summary_type = SummaryRecord._parse_from(name)
        # Reject entries with an unparsable name or a non-Tensor payload;
        # the original logged the same message for both cases.
        if summary_tag is None or not isinstance(data, Tensor):
            logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
            return None
        result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type})
    return result
| @staticmethod | |||
| def _parse_from(name: str = None): | |||
| """ | |||
| Parse the tag and type from name. | |||
| Args: | |||
| name (str): Format: TAG[:TYPE]. | |||
| Returns: | |||
| Tuple, (summary_tag, summary_type). | |||
| """ | |||
| if name is None: | |||
| logger.error("The name is None") | |||
| """Parse the tag and type from name.""" | |||
| if not isinstance(name, str): | |||
| return None, None | |||
| match = re.match(r'(.+)\[:(.+)\]', name) | |||
| if match: | |||
| return match.groups() | |||
| logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name) | |||
| return None, None | |||
| @@ -84,21 +84,6 @@ def test_histogram_multi_summary(): | |||
| event = reader.read_event() | |||
| assert event.summary.value[0].histogram.count == size | |||
def test_histogram_summary_scalar_tensor():
    """Test histogram summary, input is a scalar tensor."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
            # Cache a single scalar value as if it came from the network.
            test_data = _wrap_test_data(Tensor(1))
            _cache_summary_tensor_data(test_data)
            test_writer.record(step=1)

        file_name = os.path.join(tmp_dir, test_writer.event_file_name)
        with SummaryReader(file_name) as reader:
            event = reader.read_event()
            # A scalar tensor yields a histogram holding exactly one element.
            assert event.summary.value[0].histogram.count == 1
| def test_histogram_summary_empty_tensor(): | |||
| """Test histogram summary, input is an empty tensor.""" | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||