Merge pull request !1935 from 李鸿章/policy_writertags/v0.5.0-beta
| @@ -79,6 +79,7 @@ if (ENABLE_DUMP_PROTO) | |||||
| file(GLOB_RECURSE PROTO_PY RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE PROTO_PY RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | ||||
| "utils/anf_ir.proto" | "utils/anf_ir.proto" | ||||
| "utils/summary.proto" | "utils/summary.proto" | ||||
| "utils/lineage.proto" | |||||
| "utils/checkpoint.proto" | "utils/checkpoint.proto" | ||||
| ) | ) | ||||
| ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY}) | ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY}) | ||||
| @@ -0,0 +1,129 @@ | |||||
| // Copyright 2020 Huawei Technologies Co., Ltd. | |||||
| // | |||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| // you may not use this file except in compliance with the License. | |||||
| // You may obtain a copy of the License at | |||||
| // | |||||
| // http://www.apache.org/licenses/LICENSE-2.0 | |||||
| // | |||||
| // Unless required by applicable law or agreed to in writing, software | |||||
| // distributed under the License is distributed on an "AS IS" BASIS, | |||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| // See the License for the specific language governing permissions and | |||||
| // limitations under the License. | |||||
| syntax = "proto2"; | |||||
| package mindspore.irpb; | |||||
| option cc_enable_arenas = true; | |||||
// Top-level message of the lineage event protocol.
message LineageEvent {
  // Timestamp of the event.
  required double wall_time = 1;

  // The training step this event belongs to.
  optional int64 step = 2;

  // Payload of the event. Being a oneof, at most one of these fields is set.
  // Field numbers 4, 5 and 8 are not used by this message.
  oneof what {
    // An event file was started, with the specified version.
    // The current version is "Mindspore.Event:1".
    string version = 3;

    // Train lineage
    TrainLineage train_lineage = 6;

    // Evaluation lineage
    EvaluationLineage evaluation_lineage = 7;

    // Dataset graph
    DatasetGraph dataset_graph = 9;

    // User defined info
    UserDefinedInfo user_defined_info = 10;
  }
}
// User-defined lineage information.
// The message is recursive: it can nest further UserDefinedInfo messages,
// either as a list (user_info) or keyed by string (map_dict).
message UserDefinedInfo{
  // Repeated user-defined info.
  repeated UserDefinedInfo user_info = 1;

  // Key/value pairs; values are either a nested dict (map_dict) or a scalar
  // of the type named by the field (int32, string or double).
  map<string, UserDefinedInfo> map_dict = 2;
  map<string, int32> map_int32 = 3;
  map<string, string> map_str = 4;
  map<string, double> map_double = 5;
}
// TrainLineage records information about one training run.
message TrainLineage{
  // Hyper-parameters of the training run. Field number 7 is not used.
  message HyperParameters{
    optional string optimizer = 1;
    optional float learning_rate = 2;
    optional string loss_function = 3;
    optional int32 epoch = 4;
    optional string parallel_mode = 5;
    optional int32 device_num = 6;
    optional int32 batch_size = 8;
  }

  // Location and size of the training dataset.
  message TrainDataset{
    optional string train_dataset_path = 1;
    optional int32 train_dataset_size = 2;
  }

  // The network that was trained and its loss value.
  message Algorithm{
    optional string network = 1;
    optional float loss = 2;
  }

  // The produced model file. Field numbers 1 and 2 are not used.
  // NOTE(review): the unit of `size` is not stated here; presumably bytes, confirm with the producer.
  message Model{
    optional string path = 3;
    optional int64 size = 4;
  }

  optional HyperParameters hyper_parameters = 1;
  optional TrainDataset train_dataset = 2;
  optional Algorithm algorithm = 3;
  optional Model model = 4;
}
// EvaluationLineage records information about one evaluation run.
// Field number 1 is not used by this message.
message EvaluationLineage{
  // Location and size of the validation dataset.
  message ValidDataset{
    optional string valid_dataset_path = 1;
    optional int32 valid_dataset_size = 2;
  }

  // Evaluation metric, stored as a string.
  // NOTE(review): the encoding of the string (e.g. JSON) is not visible here; confirm with the producer.
  optional string metric = 2;
  optional ValidDataset valid_dataset = 3;
}
// DatasetGraph is a recursive description of a dataset pipeline node:
// each node carries its own parameters, a list of operations, an optional
// sampler, and any number of child nodes of the same type.
message DatasetGraph {
  repeated DatasetGraph children = 1;
  optional OperationParameter parameter = 2;
  repeated Operation operations = 3;
  optional Operation sampler = 4;
}
// Operation describes a single operation (or sampler) referenced by a DatasetGraph node.
message Operation {
  // Named parameters of the operation.
  optional OperationParameter operationParam = 1;
  // NOTE(review): the meaning of `size` and `weights` is not documented in this file; confirm with the producer.
  repeated int32 size = 2;
  repeated float weights = 3;
}
// OperationParameter holds named parameters, grouped into one map per value type.
message OperationParameter{
  map<string, string> mapStr = 1;
  map<string, StrList> mapStrList = 2;
  map<string, bool> mapBool = 3;
  map<string, int32> mapInt = 4;
  map<string, double> mapDouble = 5;
}
// StrList wraps a list of strings so that it can be used as a map value
// (see OperationParameter.mapStrList).
message StrList {
  repeated string strValue = 1;
}
| @@ -21,6 +21,7 @@ from mindspore.common import dtype as mstype | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo | |||||
| def _convert_type(types): | def _convert_type(types): | ||||
| """ | """ | ||||
| @@ -193,3 +194,38 @@ def _to_full_shapes(shapes, device_num): | |||||
| new_shape += (item,) | new_shape += (item,) | ||||
| new_shapes.append(new_shape) | new_shapes.append(new_shape) | ||||
| return new_shapes | return new_shapes | ||||
| def _check_to_numpy(plugin, tensor): | |||||
| """Check the tensor and return a numpy.ndarray.""" | |||||
| np_value = tensor.asnumpy() | |||||
| if plugin == 'scalar': | |||||
| if np_value.size == 1: | |||||
| return np_value | |||||
| raise ValueError('The tensor holds more than one value, but the scalar plugin expects on value.') | |||||
| if plugin == 'image': | |||||
| if np_value.ndim == 4: | |||||
| return np_value | |||||
| raise ValueError('The tensor seems not to hold a valid image.') | |||||
| if plugin in ('tensor', 'histogram'): | |||||
| if np_value.ndim > 0: | |||||
| return np_value | |||||
| raise ValueError('The tensor should not be empty.') | |||||
| return np_value | |||||
| def _check_lineage_value(plugin, value): | |||||
| """Check the lineage value.""" | |||||
| def raises(plugin, prototype): | |||||
| raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.') | |||||
| if plugin == 'dataset_graph' and not isinstance(value, DatasetGraph): | |||||
| raises(plugin, DatasetGraph) | |||||
| if plugin == 'eval_lineage' and not isinstance(value, EvaluationLineage): | |||||
| raises(plugin, EvaluationLineage) | |||||
| if plugin == 'train_lineage' and not isinstance(value, TrainLineage): | |||||
| raises(plugin, TrainLineage) | |||||
| if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo): | |||||
| raises(plugin, UserDefinedInfo) | |||||
| @@ -1,88 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Writes events to disk in a logdir.""" | |||||
| import os | |||||
| import stat | |||||
| from collections import deque | |||||
| from multiprocessing import Pool, Process, Queue, cpu_count | |||||
| from ..._c_expression import EventWriter_ | |||||
| from ._summary_adapter import package_summary_event | |||||
def _pack(result, step):
    """Package the summary data of one step into an event and return it serialized."""
    return package_summary_event(result, step).SerializeToString()
class EventWriter(Process):
    """
    Creates a `EventWriter` and write event to file.

    The object is a `multiprocessing.Process` that starts itself on construction.
    Callers enqueue actions with `write`, `flush` and `close`; the process consumes
    them in `run`.

    Args:
        filepath (str): Summary event file path and file name.
        flush_interval (int): The flush seconds to flush the pending events to disk.
            The value is accepted but not used by this implementation.
    """

    def __init__(self, filepath: str, flush_interval: int) -> None:
        super().__init__()
        # The interval is deliberately ignored; flushing only happens on explicit `flush()`.
        _ = flush_interval
        # Create (or truncate) the file and restrict it to owner read/write.
        with open(filepath, 'w'):
            os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
        self._writer = EventWriter_(filepath)
        # Bounded queue: `put` blocks once cpu_count() * 2 actions are pending.
        self._queue = Queue(cpu_count() * 2)
        self.start()

    def run(self):
        """Consume queued actions until 'END' is received, then drain pending results and shut down."""
        with Pool(min(cpu_count(), 32)) as pool:
            # Pending async packing results, kept in submission order.
            deq = deque()
            while True:
                # Write out results from the head of the deque only, so events keep their order.
                while deq and deq[0].ready():
                    self._writer.Write(deq.popleft().get())

                # NOTE(review): this loop polls without blocking, so it spins while the queue is empty.
                if not self._queue.empty():
                    action, data = self._queue.get()
                    if action == 'WRITE':
                        if not isinstance(data, (str, bytes)):
                            # `data` is unpacked as the positional arguments of `_pack`.
                            deq.append(pool.apply_async(_pack, data))
                        else:
                            # Already serialized: write directly.
                            self._writer.Write(data)
                    elif action == 'FLUSH':
                        self._writer.Flush()
                    elif action == 'END':
                        break
            # Wait for and write whatever is still being packed.
            for res in deq:
                self._writer.Write(res.get())

            self._writer.Shut()

    def write(self, data) -> None:
        """
        Write the event to file.

        Args:
            data (Union[str, bytes, Tuple[list, int]]): The data to write. A str or bytes value
                is written as is; anything else is passed as the arguments of `_pack`.
        """
        self._queue.put(('WRITE', data))

    def flush(self):
        """Flush the writer."""
        self._queue.put(('FLUSH', None))

    def close(self) -> None:
        """Close the writer and wait for the writer process to finish."""
        self._queue.put(('END', None))
        self.join()
| @@ -0,0 +1,39 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Generate the lineage event which conform to proto format.""" | |||||
| import time | |||||
| from ..lineage_pb2 import LineageEvent | |||||
def serialize_to_lineage_event(name, value):
    """
    Wrap a serialized lineage message into a LineageEvent and serialize the event.

    Args:
        name (str): The lineage plugin name selecting the event field to fill.
        value (bytes): The serialized lineage message to parse into that field.

    Returns:
        The serialized LineageEvent, stamped with the current wall time.
    """
    lineage_event = LineageEvent()
    lineage_event.wall_time = time.time()
    _get_lineage_content(name, lineage_event).ParseFromString(value)
    return lineage_event.SerializeToString()
| def _get_lineage_content(name, event): | |||||
| if name == 'dataset_graph': | |||||
| return event.dataset_graph | |||||
| if name == 'eval_lineage': | |||||
| return event.evaluation_lineage | |||||
| if name == 'train_lineage': | |||||
| return event.train_lineage | |||||
| if name == 'custom_lineage_data': | |||||
| return event.user_defined_info | |||||
| raise KeyError(f'No such field in LineageEvent') | |||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Generate the summary event which conform to proto format.""" | """Generate the summary event which conform to proto format.""" | ||||
| import socket | |||||
| import platform | |||||
| import time | import time | ||||
| import numpy as np | import numpy as np | ||||
| @@ -51,7 +51,7 @@ def get_event_file_name(prefix, suffix): | |||||
| _check_str_by_regular(suffix) | _check_str_by_regular(suffix) | ||||
| file_name = "" | file_name = "" | ||||
| time_second = str(int(time.time())) | time_second = str(int(time.time())) | ||||
| hostname = socket.gethostname() | |||||
| hostname = platform.node() | |||||
| if prefix is not None: | if prefix is not None: | ||||
| file_name = file_name + prefix | file_name = file_name + prefix | ||||
| @@ -0,0 +1,79 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Writes events to disk in a logdir.""" | |||||
| import os | |||||
| import stat | |||||
| from ..._c_expression import EventWriter_ | |||||
| from ._summary_adapter import package_init_event | |||||
class BaseWriter:
    """Common base of the event-file writers; subclasses implement `write`."""

    def __init__(self, filepath) -> None:
        # The underlying writer is created lazily by the `writer` property.
        self._writer: EventWriter_ = None
        self._filepath = filepath

    def init_writer(self):
        """Hook called once, right after the underlying writer is created (e.g. to write metadata)."""

    @property
    def writer(self) -> EventWriter_:
        """Get the underlying writer, creating the file and the writer on first access."""
        if self._writer is None:
            # Create (or truncate) the file and restrict it to owner read/write.
            with open(self._filepath, 'w'):
                os.chmod(self._filepath, stat.S_IWUSR | stat.S_IRUSR)
            self._writer = EventWriter_(self._filepath)
            # Must run after `_writer` is set, since the hook may use `self.writer`.
            self.init_writer()
        return self._writer

    def write(self, plugin, mode, data):
        """Write data to file; must be overridden by subclasses."""
        raise NotImplementedError()

    def flush(self):
        """Flush the writer, if one has been created."""
        if self._writer is None:
            return
        self._writer.Flush()

    def close(self):
        """Close the writer, if one has been created."""
        if self._writer is None:
            return
        self._writer.Shut()
class SummaryWriter(BaseWriter):
    """Writer for the summary file; handles the 'summary' and 'graph' plugins."""

    def init_writer(self):
        """Write the init event as the first record of the file."""
        write_event = self.writer.Write
        write_event(package_init_event().SerializeToString())

    def write(self, plugin, mode, data):
        """Write data to file if the plugin belongs to the summary file; ignore it otherwise."""
        if plugin not in ('summary', 'graph'):
            return
        self.writer.Write(data)
class LineageWriter(BaseWriter):
    """Writer for the lineage file; handles the four lineage plugins."""

    def write(self, plugin, mode, data):
        """Write data to file if the plugin is a lineage plugin; ignore it otherwise."""
        if plugin not in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'):
            return
        self.writer.Write(data)
| @@ -0,0 +1,114 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Write events to disk in a base directory.""" | |||||
| import os | |||||
| from collections import deque | |||||
| from multiprocessing import Pool, Process, Queue, cpu_count | |||||
| from ._lineage_adapter import serialize_to_lineage_event | |||||
| from ._summary_adapter import package_graph_event, package_summary_event | |||||
| from ._summary_writer import SummaryWriter, LineageWriter | |||||
| def _pack_data(datadict): | |||||
| """Pack data according to which plugin.""" | |||||
| result = [] | |||||
| summaries, step, mode = [], None, None | |||||
| for plugin, datalist in datadict.items(): | |||||
| for data in datalist: | |||||
| if plugin == 'graph': | |||||
| result.append([plugin, data.get('mode'), package_graph_event(data.get('value')).SerializeToString()]) | |||||
| elif plugin in ('train_lineage', 'eval_lineage', 'custom_lineage_data', 'dataset_graph'): | |||||
| result.append([plugin, data.get('mode'), serialize_to_lineage_event(plugin, data.get('value'))]) | |||||
| elif plugin in ('scalar', 'tensor', 'histogram', 'image'): | |||||
| summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')}) | |||||
| step = data.get('step') | |||||
| mode = data.get('mode') | |||||
| if summaries: | |||||
| result.append(['summary', mode, package_summary_event(summaries, step).SerializeToString()]) | |||||
| return result | |||||
class WriterPool(Process):
    """
    A resident process that packs recorded data with a worker pool and writes it to a set of files.

    The process starts itself on construction. Callers enqueue actions with `write`,
    `flush` and `close`; the process consumes them in `run`.

    Args:
        base_dir (str): The base directory to hold all the files.
        filedict (dict): Keyword arguments mapping a writer kind to a file name under `base_dir`.
            The kinds 'summary' and 'lineage' are recognized; other keys are ignored.
    """

    def __init__(self, base_dir, **filedict) -> None:
        super().__init__()
        self._base_dir, self._filedict = base_dir, filedict
        # Bounded queue: `put` blocks once cpu_count() * 2 actions are pending.
        self._queue = Queue(cpu_count() * 2)
        self.start()

    def run(self):
        """Consume queued actions until 'END' is received, then drain pending results and close the writers."""
        writers = self._get_writers()

        with Pool() as pool:
            # Pending async packing results, kept in submission order.
            deq = deque()
            while True:
                # Only consume from the head of the deque, so data is written in submission order.
                while deq and deq[0].ready():
                    for plugin, mode, data in deq.popleft().get():
                        # Every writer sees every item and decides by plugin whether to write it.
                        for writer in writers:
                            writer.write(plugin, mode, data)

                # NOTE(review): this loop polls without blocking, so it spins while the queue is empty.
                if not self._queue.empty():
                    action, data = self._queue.get()
                    if action == 'WRITE':
                        deq.append(pool.apply_async(_pack_data, (data,)))
                    elif action == 'FLUSH':
                        for writer in writers:
                            writer.flush()
                    elif action == 'END':
                        break
            # Wait for and write whatever is still being packed.
            for result in deq:
                for plugin, mode, data in result.get():
                    for writer in writers:
                        writer.write(plugin, mode, data)

            for writer in writers:
                writer.close()

    def _get_writers(self):
        """Build one writer per recognized entry of the file mapping."""
        writers = []
        for plugin, filename in self._filedict.items():
            filepath = os.path.join(self._base_dir, filename)
            if plugin == 'summary':
                writers.append(SummaryWriter(filepath))
            elif plugin == 'lineage':
                writers.append(LineageWriter(filepath))
        return writers

    def write(self, data) -> None:
        """
        Write the event to file.

        Args:
            data (dict): Mapping from plugin name to a list of entry dicts,
                as consumed by `_pack_data`.
        """
        self._queue.put(('WRITE', data))

    def flush(self):
        """Flush the writer and sync data to disk."""
        self._queue.put(('FLUSH', None))

    def close(self) -> None:
        """Close the writer and wait for the writer process to finish."""
        self._queue.put(('END', None))
        self.join()
| @@ -21,9 +21,9 @@ from mindspore import log as logger | |||||
| from ..._c_expression import Tensor | from ..._c_expression import Tensor | ||||
| from ..._checkparam import _check_str_by_regular | from ..._checkparam import _check_str_by_regular | ||||
| from .._utils import _make_directory | |||||
| from ._event_writer import EventWriter | |||||
| from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event | |||||
| from .._utils import _make_directory, _check_to_numpy, _check_lineage_value | |||||
| from ._summary_adapter import get_event_file_name, package_graph_event | |||||
| from ._writer_pool import WriterPool | |||||
| # for the moment, this lock is for caution's sake, | # for the moment, this lock is for caution's sake, | ||||
| # there are actually no any concurrencies happening. | # there are actually no any concurrencies happening. | ||||
| @@ -53,16 +53,20 @@ def _get_summary_tensor_data(): | |||||
| return data | return data | ||||
| def _dictlist(): | |||||
| from collections import defaultdict | |||||
| return defaultdict(list) | |||||
| class SummaryRecord: | class SummaryRecord: | ||||
| """ | """ | ||||
| SummaryRecord is used to record the summary value. | |||||
| SummaryRecord is used to record the summary data and lineage data. | |||||
| Note: | Note: | ||||
| The API will create an event file in a given directory and add summaries and events to it. | |||||
| It writes the event log to a file by executing the record method. In addition, | |||||
| if the SummaryRecord object is created and the summary operator is used in the network, | |||||
| even if the record method is not called, the event in the cache will be written to the | |||||
| file at the end of execution. Make sure to close the SummaryRecord object at the end. | |||||
| The API will create a summary file and a lineage file lazily in a given directory and writes data to them. | |||||
| It writes the data to files by executing the record method. In addition to record the data bubbled up from | |||||
| the network by defining the summary operators, SummaryRecord also supports to record extra data which | |||||
| can be added by calling add_value. Finally, make sure to close the SummaryRecord object at the end. | |||||
| Args: | Args: | ||||
| log_dir (str): The log_dir is a directory location to save the summary. | log_dir (str): The log_dir is a directory location to save the summary. | ||||
| @@ -89,10 +93,12 @@ class SummaryRecord: | |||||
| file_suffix="_MS", | file_suffix="_MS", | ||||
| network=None): | network=None): | ||||
| self._event_writer, self._closed = None, False | |||||
| self._closed, self._mode = False, 'train' | |||||
| self._data_pool = _dictlist() | |||||
| _check_str_by_regular(file_prefix) | _check_str_by_regular(file_prefix) | ||||
| _check_str_by_regular(file_suffix) | _check_str_by_regular(file_suffix) | ||||
| self.log_path = _make_directory(log_dir) | self.log_path = _make_directory(log_dir) | ||||
| if not isinstance(queue_max_size, int) or not isinstance(flush_time, int): | if not isinstance(queue_max_size, int) or not isinstance(flush_time, int): | ||||
| @@ -123,16 +129,12 @@ class SummaryRecord: | |||||
| except Exception as ex: | except Exception as ex: | ||||
| raise RuntimeError(ex) | raise RuntimeError(ex) | ||||
| def _init_event_writer(self): | |||||
| """Init event writer and write metadata.""" | |||||
| event_writer = EventWriter(self.full_file_name, self.flush_time) | |||||
| event_writer.write(package_init_event().SerializeToString()) | |||||
| return event_writer | |||||
| self._event_writer = WriterPool(log_dir, | |||||
| summary=self.full_file_name, | |||||
| lineage=get_event_file_name('events', '_lineage')) | |||||
| def __enter__(self): | def __enter__(self): | ||||
| """Enter the context manager.""" | """Enter the context manager.""" | ||||
| if not self._event_writer: | |||||
| self._event_writer = self._init_event_writer() | |||||
| if self._closed: | if self._closed: | ||||
| raise ValueError('SummaryRecord has been closed.') | raise ValueError('SummaryRecord has been closed.') | ||||
| return self | return self | ||||
| @@ -141,6 +143,76 @@ class SummaryRecord: | |||||
| """Exit the context manager.""" | """Exit the context manager.""" | ||||
| self.close() | self.close() | ||||
def set_mode(self, mode):
    """
    Set the mode the recorder works in. The mode is 'train' by default.

    Args:
        mode (str): The mode to set, which should be 'train' or 'eval'.

    Raises:
        ValueError: When the mode is not recognized.

    Examples:
        >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
        >>>     summary_record.set_mode('eval')
    """
    if mode in ('train', 'eval'):
        self._mode = mode
        return
    raise ValueError(f'{repr(mode)} is not a recognized mode.')
def add_value(self, plugin, name, value):
    """
    Add a value to be recorded later on.

    When the plugin is 'tensor', 'scalar', 'image' or 'histogram',
    the name should be the tag name, and the value should be a Tensor.

    When the plugin is 'graph', the value should be a GraphProto.

    When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage',
    or 'custom_lineage_data', the value should be a proto message.

    Args:
        plugin (str): The plugin for the value.
        name (str): The name for the value.
        value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
            The value to store.

            - GraphProto: The 'value' should be a serialized string this type when the plugin is 'graph'.
            - Tensor: The 'value' should be this type when the plugin is 'scalar', 'image', 'tensor' or 'histogram'.
            - TrainLineage: The 'value' should be this type when the plugin is 'train_lineage'.
            - EvaluationLineage: The 'value' should be this type when the plugin is 'eval_lineage'.
            - DatasetGraph: The 'value' should be this type when the plugin is 'dataset_graph'.
            - UserDefinedInfo: The 'value' should be this type when the plugin is 'custom_lineage_data'.

    Raises:
        ValueError: When the name is not valid, or the plugin is unknown.
        TypeError: When the value is not of the type the plugin expects.

    Examples:
        >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
        >>>     summary_record.add_value('scalar', 'loss', Tensor(0.1))
    """
    summary_plugins = ('tensor', 'scalar', 'image', 'histogram')
    lineage_plugins = ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data')

    if plugin in summary_plugins:
        if not (name and isinstance(name, str)):
            raise ValueError(f'{repr(name)} is not a valid tag name.')
        if not isinstance(value, Tensor):
            raise TypeError(f'Expect the value to be Tensor, but got {type(value).__name__}')
        np_value = _check_to_numpy(plugin, value)
        self._data_pool[plugin].append({'tag': name, 'mode': self._mode, 'value': np_value})
        return

    if plugin in lineage_plugins:
        _check_lineage_value(plugin, value)
        bucket = self._data_pool[plugin]
        # Lineage messages are stored serialized.
        bucket.append({'mode': self._mode, 'value': value.SerializeToString()})
        return

    if plugin == 'graph':
        # Called for validation only; the packaged event is discarded and the raw value is stored.
        package_graph_event(value)
        self._data_pool[plugin].append({'mode': self._mode, 'value': value})
        return

    raise ValueError(f'No such plugin of {repr(plugin)}')
| def record(self, step, train_network=None): | def record(self, step, train_network=None): | ||||
| """ | """ | ||||
| Record the summary. | Record the summary. | ||||
| @@ -149,12 +221,12 @@ class SummaryRecord: | |||||
| step (int): Represents training step number. | step (int): Represents training step number. | ||||
| train_network (Cell): The network that called the callback. | train_network (Cell): The network that called the callback. | ||||
| Returns: | |||||
| bool, whether the record process is successful or not. | |||||
| Examples: | Examples: | ||||
| >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | ||||
| >>> summary_record.record(step=2) | >>> summary_record.record(step=2) | ||||
| Returns: | |||||
| bool, whether the record process is successful or not. | |||||
| """ | """ | ||||
| logger.info("SummaryRecord step is %r.", step) | logger.info("SummaryRecord step is %r.", step) | ||||
| if self._closed: | if self._closed: | ||||
| @@ -163,10 +235,6 @@ class SummaryRecord: | |||||
| if not isinstance(step, int) or isinstance(step, bool): | if not isinstance(step, int) or isinstance(step, bool): | ||||
| raise ValueError("`step` should be int") | raise ValueError("`step` should be int") | ||||
| # Set the current summary of train step | # Set the current summary of train step | ||||
| if not self._event_writer: | |||||
| self._event_writer = self._init_event_writer() | |||||
| logger.warning('SummaryRecord should be used as context manager for a with statement.') | |||||
| if self.network is not None and not self.has_graph: | if self.network is not None and not self.has_graph: | ||||
| graph_proto = self.network.get_func_graph_proto() | graph_proto = self.network.get_func_graph_proto() | ||||
| if graph_proto is None and train_network is not None: | if graph_proto is None and train_network is not None: | ||||
| @@ -174,39 +242,48 @@ class SummaryRecord: | |||||
| if graph_proto is None: | if graph_proto is None: | ||||
| logger.error("Failed to get proto for graph") | logger.error("Failed to get proto for graph") | ||||
| else: | else: | ||||
| self._event_writer.write(package_graph_event(graph_proto).SerializeToString()) | |||||
| self._event_writer.write({'graph': [{'step': step, 'value': graph_proto}]}) | |||||
| self.has_graph = True | self.has_graph = True | ||||
| if not _summary_tensor_cache: | if not _summary_tensor_cache: | ||||
| return True | return True | ||||
| data = _get_summary_tensor_data() | |||||
| if not data: | |||||
| logger.info("The step(%r) does not have record data.", step) | |||||
| return False | |||||
| if self.queue_max_size > 0 and len(data) > self.queue_max_size: | |||||
| logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data), | |||||
| self.queue_max_size) | |||||
| # process the data | |||||
| result = self._data_convert(data) | |||||
| if not result: | |||||
| logger.error("The step(%r) summary data is invalid.", step) | |||||
| return False | |||||
| self._event_writer.write((result, step)) | |||||
| logger.debug("Send the summary data to scheduler for saving, step = %d", step) | |||||
| if self._mode == 'train': | |||||
| self._add_summary_tensor_data() | |||||
| self._event_writer.write(self._consume_data_pool(step)) | |||||
| return True | return True | ||||
def _add_summary_tensor_data(self):
    """
    Move the summary tensors bubbled up from the network into the data pool.

    Each cached name is expected in the form 'TAG[:TYPE]'; the type, lower-cased,
    is used as the plugin for `add_value`. Names that cannot be parsed are skipped
    with a warning.
    """
    summary_data = _get_summary_tensor_data()
    if not summary_data:
        logger.debug('No summary data bubbled from the network.')
        # Nothing to add; also avoids calling `.items()` on a None result.
        return
    for name, tensor in summary_data.items():
        tag, plugin = SummaryRecord._parse_from(name)
        if (tag, plugin) == (None, None):
            logger.warning("The name(%r) is invalid, expected 'TAG[:TYPE]'.", name)
        else:
            self.add_value(plugin.lower(), tag, tensor)
def _consume_data_pool(self, step):
    """
    Stamp every pooled entry with `step`, hand the pool over and start a fresh one.

    Returns:
        The data pool that was filled so far; the recorder gets a new empty pool.
    """
    filled_pool = self._data_pool
    try:
        for entries in filled_pool.values():
            for entry in entries:
                entry['step'] = step
        return filled_pool
    finally:
        # The pool is replaced even if stamping fails.
        self._data_pool = _dictlist()
| @property | @property | ||||
| def log_dir(self): | def log_dir(self): | ||||
| """ | """ | ||||
| Get the full path of the log file. | Get the full path of the log file. | ||||
| Returns: | |||||
| str, the full path of log file. | |||||
| Examples: | Examples: | ||||
| >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | ||||
| >>> print(summary_record.log_dir) | >>> print(summary_record.log_dir) | ||||
| Returns: | |||||
| String, the full path of log file. | |||||
| """ | """ | ||||
| return self.full_file_name | return self.full_file_name | ||||
| @@ -235,46 +312,19 @@ class SummaryRecord: | |||||
| """ | """ | ||||
| if not self._closed and self._event_writer: | if not self._closed and self._event_writer: | ||||
| # event writer flush and close | # event writer flush and close | ||||
| logger.info('Please wait it may take quite some time to finish writing and closing.') | |||||
| self._event_writer.close() | self._event_writer.close() | ||||
| self._closed = True | self._closed = True | ||||
| def __del__(self) -> None: | def __del__(self) -> None: | ||||
| self.close() | self.close() | ||||
| def _data_convert(self, summary): | |||||
| """Convert the data.""" | |||||
| # convert the summary to numpy | |||||
| result = [] | |||||
| for name, data in summary.items(): | |||||
| # confirm the data is valid | |||||
| summary_tag, summary_type = SummaryRecord._parse_from(name) | |||||
| if summary_tag is None: | |||||
| logger.error("The data type is invalid, name = %r, tensor = %r", name, data) | |||||
| return None | |||||
| if isinstance(data, Tensor): | |||||
| result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type}) | |||||
| else: | |||||
| logger.error("The data type is invalid, name = %r, tensor = %r", name, data) | |||||
| return None | |||||
| return result | |||||
| @staticmethod | @staticmethod | ||||
| def _parse_from(name: str = None): | def _parse_from(name: str = None): | ||||
| """ | |||||
| Parse the tag and type from name. | |||||
| Args: | |||||
| name (str): Format: TAG[:TYPE]. | |||||
| Returns: | |||||
| Tuple, (summary_tag, summary_type). | |||||
| """ | |||||
| if name is None: | |||||
| logger.error("The name is None") | |||||
| """Parse the tag and type from name.""" | |||||
| if not isinstance(name, str): | |||||
| return None, None | return None, None | ||||
| match = re.match(r'(.+)\[:(.+)\]', name) | match = re.match(r'(.+)\[:(.+)\]', name) | ||||
| if match: | if match: | ||||
| return match.groups() | return match.groups() | ||||
| logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name) | |||||
| return None, None | return None, None | ||||
| @@ -84,21 +84,6 @@ def test_histogram_multi_summary(): | |||||
| event = reader.read_event() | event = reader.read_event() | ||||
| assert event.summary.value[0].histogram.count == size | assert event.summary.value[0].histogram.count == size | ||||
def test_histogram_summary_scalar_tensor():
    """Test histogram summary, input is a scalar tensor."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
            scalar_data = _wrap_test_data(Tensor(1))
            _cache_summary_tensor_data(scalar_data)
            test_writer.record(step=1)

        # Read the event back only after the record has been closed.
        event_path = os.path.join(tmp_dir, test_writer.event_file_name)
        with SummaryReader(event_path) as reader:
            first_event = reader.read_event()
            assert first_event.summary.value[0].histogram.count == 1
| def test_histogram_summary_empty_tensor(): | def test_histogram_summary_empty_tensor(): | ||||
| """Test histogram summary, input is an empty tensor.""" | """Test histogram summary, input is an empty tensor.""" | ||||
| with tempfile.TemporaryDirectory() as tmp_dir: | with tempfile.TemporaryDirectory() as tmp_dir: | ||||