Merge pull request !1069 from 李鸿章/context_managertags/v0.3.0-alpha
| @@ -14,91 +14,74 @@ | |||
| # ============================================================================ | |||
| """Writes events to disk in a logdir.""" | |||
| import os | |||
| import time | |||
| import stat | |||
| from mindspore import log as logger | |||
| from collections import deque | |||
| from multiprocessing import Pool, Process, Queue, cpu_count | |||
| from ..._c_expression import EventWriter_ | |||
| from ._summary_adapter import package_init_event | |||
| from ._summary_adapter import package_summary_event | |||
| class _WrapEventWriter(EventWriter_): | |||
| """ | |||
| Wrap the c++ EventWriter object. | |||
| def _pack(result, step): | |||
| summary_event = package_summary_event(result, step) | |||
| return summary_event.SerializeToString() | |||
| Args: | |||
| full_file_name (str): Include directory and file name. | |||
| """ | |||
| def __init__(self, full_file_name): | |||
| if full_file_name is not None: | |||
| EventWriter_.__init__(self, full_file_name) | |||
| class EventRecord: | |||
| class EventWriter(Process): | |||
| """ | |||
| Creates a `EventFileWriter` and write event to file. | |||
| Creates a `EventWriter` and write event to file. | |||
| Args: | |||
| full_file_name (str): Summary event file path and file name. | |||
| flush_time (int): The flush seconds to flush the pending events to disk. Default: 120. | |||
| filepath (str): Summary event file path and file name. | |||
| flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120. | |||
| """ | |||
| def __init__(self, full_file_name: str, flush_time: int = 120): | |||
| self.full_file_name = full_file_name | |||
| # The first event will be flushed immediately. | |||
| self.flush_time = flush_time | |||
| self.next_flush_time = 0 | |||
| # create event write object | |||
| self.event_writer = self._create_event_file() | |||
| self._init_event_file() | |||
| # count the events | |||
| self.event_count = 0 | |||
| def _create_event_file(self): | |||
| """Create the event write file.""" | |||
| with open(self.full_file_name, 'w'): | |||
| os.chmod(self.full_file_name, stat.S_IWUSR | stat.S_IRUSR) | |||
| # create c++ event write object | |||
| event_writer = _WrapEventWriter(self.full_file_name) | |||
| return event_writer | |||
| def _init_event_file(self): | |||
| """Send the init event to file.""" | |||
| self.event_writer.Write((package_init_event()).SerializeToString()) | |||
| self.flush() | |||
| return True | |||
| def write_event_to_file(self, event_str): | |||
| """Write the event to file.""" | |||
| self.event_writer.Write(event_str) | |||
| def get_data_count(self): | |||
| """Return the event count.""" | |||
| return self.event_count | |||
| def flush_cycle(self): | |||
| """Flush file by timer.""" | |||
| self.event_count = self.event_count + 1 | |||
| # Flush the event writer every so often. | |||
| now = int(time.time()) | |||
| if now > self.next_flush_time: | |||
| self.flush() | |||
| # update the flush time | |||
| self.next_flush_time = now + self.flush_time | |||
| def count_event(self): | |||
| """Count event.""" | |||
| logger.debug("Write the event count is %r", self.event_count) | |||
| self.event_count = self.event_count + 1 | |||
| return self.event_count | |||
| def __init__(self, filepath: str, flush_interval: int) -> None: | |||
| super().__init__() | |||
| with open(filepath, 'w'): | |||
| os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR) | |||
| self._writer = EventWriter_(filepath) | |||
| self._queue = Queue(cpu_count() * 2) | |||
| self.start() | |||
| def run(self): | |||
| with Pool() as pool: | |||
| deq = deque() | |||
| while True: | |||
| while deq and deq[0].ready(): | |||
| self._writer.Write(deq.popleft().get()) | |||
| if not self._queue.empty(): | |||
| action, data = self._queue.get() | |||
| if action == 'WRITE': | |||
| if not isinstance(data, (str, bytes)): | |||
| deq.append(pool.apply_async(_pack, data)) | |||
| else: | |||
| self._writer.Write(data) | |||
| elif action == 'FLUSH': | |||
| self._writer.Flush() | |||
| elif action == 'END': | |||
| break | |||
| for res in deq: | |||
| self._writer.Write(res.get()) | |||
| self._writer.Shut() | |||
| def write(self, data) -> None: | |||
| """ | |||
| Write the event to file. | |||
| Args: | |||
| data (Optional[str, Tuple[list, int]]): The data to write. | |||
| """ | |||
| self._queue.put(('WRITE', data)) | |||
| def flush(self): | |||
| """Flush the event file to disk.""" | |||
| self.event_writer.Flush() | |||
| """Flush the writer.""" | |||
| self._queue.put(('FLUSH', None)) | |||
| def close(self): | |||
| """Flush the event file to disk and close the file.""" | |||
| self.flush() | |||
| self.event_writer.Shut() | |||
| def close(self) -> None: | |||
| """Close the writer.""" | |||
| self._queue.put(('END', None)) | |||
| self.join() | |||
| @@ -13,17 +13,17 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Generate the summary event which conform to proto format.""" | |||
| import time | |||
| import socket | |||
| import math | |||
| from enum import Enum, unique | |||
| import time | |||
| import numpy as np | |||
| from PIL import Image | |||
| from mindspore import log as logger | |||
| from ..summary_pb2 import Event | |||
| from ..anf_ir_pb2 import ModelProto, DataType | |||
| from ..._checkparam import _check_str_by_regular | |||
| from ..anf_ir_pb2 import DataType, ModelProto | |||
| from ..summary_pb2 import Event | |||
| # define the MindSpore image format | |||
| MS_IMAGE_TENSOR_FORMAT = 'NCHW' | |||
| @@ -32,55 +32,6 @@ EVENT_FILE_NAME_MARK = ".out.events.summary." | |||
| # Set the init event of version and mark | |||
| EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:" | |||
| EVENT_FILE_INIT_VERSION = 1 | |||
| # cache the summary data dict | |||
| # {id: SummaryData} | |||
| # |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...] | |||
| g_summary_data_dict = {} | |||
| def save_summary_data(data_id, data): | |||
| """Save the global summary cache.""" | |||
| global g_summary_data_dict | |||
| g_summary_data_dict[data_id] = data | |||
| def del_summary_data(data_id): | |||
| """Save the global summary cache.""" | |||
| global g_summary_data_dict | |||
| if data_id in g_summary_data_dict: | |||
| del g_summary_data_dict[data_id] | |||
| else: | |||
| logger.warning("Can't del the data because data_id(%r) " | |||
| "does not have data in g_summary_data_dict", data_id) | |||
| def get_summary_data(data_id): | |||
| """Save the global summary cache.""" | |||
| ret = None | |||
| global g_summary_data_dict | |||
| if data_id in g_summary_data_dict: | |||
| ret = g_summary_data_dict.get(data_id) | |||
| else: | |||
| logger.warning("The data_id(%r) does not have data in g_summary_data_dict", data_id) | |||
| return ret | |||
| @unique | |||
| class SummaryType(Enum): | |||
| """ | |||
| Summary type. | |||
| Args: | |||
| SCALAR (Number): Summary Scalar enum. | |||
| TENSOR (Number): Summary TENSOR enum. | |||
| IMAGE (Number): Summary image enum. | |||
| GRAPH (Number): Summary graph enum. | |||
| HISTOGRAM (Number): Summary histogram enum. | |||
| INVALID (Number): Unknow type. | |||
| """ | |||
| SCALAR = 1 # Scalar summary | |||
| TENSOR = 2 # Tensor summary | |||
| IMAGE = 3 # Image summary | |||
| GRAPH = 4 # graph | |||
| HISTOGRAM = 5 # Histogram Summary | |||
| INVALID = 0xFF # unknow type | |||
| def get_event_file_name(prefix, suffix): | |||
| @@ -138,7 +89,7 @@ def package_graph_event(data): | |||
| return graph_event | |||
| def package_summary_event(data_id, step): | |||
| def package_summary_event(data_list, step): | |||
| """ | |||
| Package the summary to event protobuffer. | |||
| @@ -149,50 +100,37 @@ def package_summary_event(data_id, step): | |||
| Returns: | |||
| Summary, the summary event. | |||
| """ | |||
| data_list = get_summary_data(data_id) | |||
| if data_list is None: | |||
| logger.error("The step(%r) does not have record data.", step) | |||
| del_summary_data(data_id) | |||
| # create the event of summary | |||
| summary_event = Event() | |||
| summary = summary_event.summary | |||
| summary_event.wall_time = time.time() | |||
| summary_event.step = int(step) | |||
| for value in data_list: | |||
| tag = value["name"] | |||
| summary_type = value["_type"] | |||
| data = value["data"] | |||
| summary_type = value["type"] | |||
| tag = value["name"] | |||
| logger.debug("Now process %r summary, tag = %r", summary_type, tag) | |||
| summary_value = summary.value.add() | |||
| summary_value.tag = tag | |||
| # get the summary type and parse the tag | |||
| if summary_type is SummaryType.SCALAR: | |||
| logger.debug("Now process Scalar summary, tag = %r", tag) | |||
| summary_value = summary.value.add() | |||
| summary_value.tag = tag | |||
| if summary_type == 'Scalar': | |||
| summary_value.scalar_value = _get_scalar_summary(tag, data) | |||
| elif summary_type is SummaryType.TENSOR: | |||
| logger.debug("Now process Tensor summary, tag = %r", tag) | |||
| summary_value = summary.value.add() | |||
| summary_value.tag = tag | |||
| elif summary_type == 'Tensor': | |||
| summary_tensor = summary_value.tensor | |||
| _get_tensor_summary(tag, data, summary_tensor) | |||
| elif summary_type is SummaryType.IMAGE: | |||
| logger.debug("Now process Image summary, tag = %r", tag) | |||
| summary_value = summary.value.add() | |||
| summary_value.tag = tag | |||
| elif summary_type == 'Image': | |||
| summary_image = summary_value.image | |||
| _get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT) | |||
| elif summary_type is SummaryType.HISTOGRAM: | |||
| logger.debug("Now process Histogram summary, tag = %r", tag) | |||
| summary_value = summary.value.add() | |||
| summary_value.tag = tag | |||
| elif summary_type == 'Histogram': | |||
| summary_histogram = summary_value.histogram | |||
| _fill_histogram_summary(tag, data, summary_histogram) | |||
| else: | |||
| # The data is invalid ,jump the data | |||
| logger.error("Summary type is error, tag = %r", tag) | |||
| continue | |||
| logger.error("Summary type(%r) is error, tag = %r", summary_type, tag) | |||
| summary_event.wall_time = time.time() | |||
| summary_event.step = int(step) | |||
| return summary_event | |||
| @@ -255,11 +193,11 @@ def _get_scalar_summary(tag: str, np_value): | |||
| # So consider the dim = 1, shape = (1,) tensor is scalar | |||
| scalar_value = np_value[0] | |||
| if np_value.shape != (1,): | |||
| logger.error("The tensor is not Scalar, tag = %r, Value = %r", tag, np_value) | |||
| logger.error("The tensor is not Scalar, tag = %r, Shape = %r", tag, np_value.shape) | |||
| else: | |||
| np_list = np_value.reshape(-1).tolist() | |||
| scalar_value = np_list[0] | |||
| logger.error("The value is not Scalar, tag = %r, Value = %r", tag, np_value) | |||
| logger.error("The value is not Scalar, tag = %r, ndim = %r", tag, np_value.ndim) | |||
| logger.debug("The tag(%r) value is: %r", tag, scalar_value) | |||
| return scalar_value | |||
| @@ -307,8 +245,7 @@ def _calc_histogram_bins(count): | |||
| Returns: | |||
| int, number of histogram bins. | |||
| """ | |||
| number_per_bucket = 10 | |||
| max_bins = 90 | |||
| max_bins, max_per_bin = 90, 10 | |||
| if not count: | |||
| return 1 | |||
| @@ -318,78 +255,50 @@ def _calc_histogram_bins(count): | |||
| return 3 | |||
| if count <= 880: | |||
| # note that math.ceil(881/10) + 1 equals 90 | |||
| return int(math.ceil(count / number_per_bucket) + 1) | |||
| return count // max_per_bin + 1 | |||
| return max_bins | |||
| def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None: | |||
| def _fill_histogram_summary(tag: str, np_value: np.ndarray, summary) -> None: | |||
| """ | |||
| Package the histogram summary. | |||
| Args: | |||
| tag (str): Summary tag describe. | |||
| np_value (np.array): Summary data. | |||
| summary_histogram (summary_pb2.Summary.Histogram): Summary histogram data. | |||
| np_value (np.ndarray): Summary data. | |||
| summary (summary_pb2.Summary.Histogram): Summary histogram data. | |||
| """ | |||
| logger.debug("Set(%r) the histogram summary value", tag) | |||
| # Default bucket for tensor with no valid data. | |||
| default_bucket_left = -0.5 | |||
| default_bucket_width = 1.0 | |||
| if np_value.size == 0: | |||
| bucket = summary_histogram.buckets.add() | |||
| bucket.left = default_bucket_left | |||
| bucket.width = default_bucket_width | |||
| bucket.count = 0 | |||
| summary_histogram.nan_count = 0 | |||
| summary_histogram.pos_inf_count = 0 | |||
| summary_histogram.neg_inf_count = 0 | |||
| summary_histogram.max = 0 | |||
| summary_histogram.min = 0 | |||
| summary_histogram.sum = 0 | |||
| summary_histogram.count = 0 | |||
| return | |||
| summary_histogram.nan_count = np.count_nonzero(np.isnan(np_value)) | |||
| summary_histogram.pos_inf_count = np.count_nonzero(np.isposinf(np_value)) | |||
| summary_histogram.neg_inf_count = np.count_nonzero(np.isneginf(np_value)) | |||
| summary_histogram.count = np_value.size | |||
| masked_value = np.ma.masked_invalid(np_value) | |||
| tensor_max = masked_value.max() | |||
| tensor_min = masked_value.min() | |||
| tensor_sum = masked_value.sum() | |||
| # No valid value in tensor. | |||
| if tensor_max is np.ma.masked: | |||
| bucket = summary_histogram.buckets.add() | |||
| bucket.left = default_bucket_left | |||
| bucket.width = default_bucket_width | |||
| bucket.count = 0 | |||
| summary_histogram.max = np.nan | |||
| summary_histogram.min = np.nan | |||
| summary_histogram.sum = 0 | |||
| return | |||
| bin_number = _calc_histogram_bins(masked_value.count()) | |||
| counts, edges = np.histogram(np_value, bins=bin_number, range=(tensor_min, tensor_max)) | |||
| ma_value = np.ma.masked_invalid(np_value) | |||
| total, valid = np_value.size, ma_value.count() | |||
| invalids = [] | |||
| for isfn in np.isnan, np.isposinf, np.isneginf: | |||
| if total - valid > sum(invalids): | |||
| count = np.count_nonzero(isfn(np_value)) | |||
| invalids.append(count) | |||
| else: | |||
| invalids.append(0) | |||
| for ind, count in enumerate(counts): | |||
| bucket = summary_histogram.buckets.add() | |||
| bucket.left = edges[ind] | |||
| bucket.width = edges[ind + 1] - edges[ind] | |||
| bucket.count = count | |||
| summary.count = total | |||
| summary.nan_count, summary.pos_inf_count, summary.neg_inf_count = invalids | |||
| if not valid: | |||
| logger.warning('There are no valid values in the ndarray(size=%d, shape=%d)', total, np_value.shape) | |||
| # summary.{min, max, sum} are 0s by default, no need to explicitly set | |||
| else: | |||
| summary.min = ma_value.min() | |||
| summary.max = ma_value.max() | |||
| summary.sum = ma_value.sum() | |||
| bins = _calc_histogram_bins(valid) | |||
| range_ = summary.min, summary.max | |||
| hists, edges = np.histogram(np_value, bins=bins, range=range_) | |||
| summary_histogram.max = tensor_max | |||
| summary_histogram.min = tensor_min | |||
| summary_histogram.sum = tensor_sum | |||
| for hist, edge1, edge2 in zip(hists, edges, edges[1:]): | |||
| bucket = summary.buckets.add() | |||
| bucket.width = edge2 - edge1 | |||
| bucket.count = hist | |||
| bucket.left = edge1 | |||
| def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'): | |||
| @@ -407,7 +316,7 @@ def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'): | |||
| """ | |||
| logger.debug("Set(%r) the image summary value", tag) | |||
| if np_value.ndim != 4: | |||
| logger.error("The value is not Image, tag = %r, Value = %r", tag, np_value) | |||
| logger.error("The value is not Image, tag = %r, ndim = %r", tag, np_value.ndim) | |||
| # convert the tensor format | |||
| tensor = _convert_image_format(np_value, input_format) | |||
| @@ -469,8 +378,8 @@ def _convert_image_format(np_tensor, input_format, out_format='HWC'): | |||
| """ | |||
| out_tensor = None | |||
| if np_tensor.ndim != len(input_format): | |||
| logger.error("The tensor(%r) can't convert the format(%r) because dim not same", | |||
| np_tensor, input_format) | |||
| logger.error("The tensor with dim(%r) can't convert the format(%r) because dim not same", np_tensor.ndim, | |||
| input_format) | |||
| return out_tensor | |||
| input_format = input_format.upper() | |||
| @@ -512,7 +421,7 @@ def _make_canvas_for_imgs(tensor, col_imgs=8): | |||
| # check the tensor format | |||
| if tensor.ndim != 4 or tensor.shape[1] != 3: | |||
| logger.error("The image tensor(%r) is not 'NCHW' format", tensor) | |||
| logger.error("The image tensor with ndim(%r) and shape(%r) is not 'NCHW' format", tensor.ndim, tensor.shape) | |||
| return out_canvas | |||
| # expand the N | |||
| @@ -1,308 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Schedule the event writer process.""" | |||
| import multiprocessing as mp | |||
| from enum import Enum, unique | |||
| from mindspore import log as logger | |||
| from ..._c_expression import Tensor | |||
| from ._summary_adapter import SummaryType, package_summary_event, save_summary_data | |||
| # define the type of summary | |||
| FORMAT_SCALAR_STR = "Scalar" | |||
| FORMAT_TENSOR_STR = "Tensor" | |||
| FORMAT_IMAGE_STR = "Image" | |||
| FORMAT_HISTOGRAM_STR = "Histogram" | |||
| FORMAT_BEGIN_SLICE = "[:" | |||
| FORMAT_END_SLICE = "]" | |||
| # cache the summary data dict | |||
| # {id: SummaryData} | |||
| # |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...] | |||
| g_summary_data_id = 0 | |||
| g_summary_data_dict = {} | |||
| # cache the summary data file | |||
| g_summary_writer_id = 0 | |||
| g_summary_file = {} | |||
| @unique | |||
| class ScheduleMethod(Enum): | |||
| """Schedule method type.""" | |||
| FORMAL_WORKER = 0 # use the formal worker that receive small size data by queue | |||
| TEMP_WORKER = 1 # use the Temp worker that receive big size data by the global value(avoid copy) | |||
| CACHE_DATA = 2 # Cache data util have idle worker to process it | |||
| @unique | |||
| class WorkerStatus(Enum): | |||
| """Worker status.""" | |||
| WORKER_INIT = 0 # data is exist but not process | |||
| WORKER_PROCESSING = 1 # data is processing | |||
| WORKER_PROCESSED = 2 # data already processed | |||
| def _parse_tag_format(tag: str): | |||
| """ | |||
| Parse the tag. | |||
| Args: | |||
| tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor]. | |||
| Returns: | |||
| Tuple, (SummaryType, summary_tag). | |||
| """ | |||
| summary_type = SummaryType.INVALID | |||
| summary_tag = tag | |||
| if tag is None: | |||
| logger.error("The tag is None") | |||
| return summary_type, summary_tag | |||
| # search the slice | |||
| slice_begin = FORMAT_BEGIN_SLICE | |||
| slice_end = FORMAT_END_SLICE | |||
| index = tag.rfind(slice_begin) | |||
| if index is -1: | |||
| logger.error("The tag(%s) have not the key slice.", tag) | |||
| return summary_type, summary_tag | |||
| # slice the tag | |||
| summary_tag = tag[:index] | |||
| # check the slice end | |||
| if tag[-1:] != slice_end: | |||
| logger.error("The tag(%s) end format is error", tag) | |||
| return summary_type, summary_tag | |||
| # check the type | |||
| type_str = tag[index + 2: -1] | |||
| logger.debug("The summary_tag is = %r", summary_tag) | |||
| logger.debug("The type_str value is = %r", type_str) | |||
| if type_str == FORMAT_SCALAR_STR: | |||
| summary_type = SummaryType.SCALAR | |||
| elif type_str == FORMAT_TENSOR_STR: | |||
| summary_type = SummaryType.TENSOR | |||
| elif type_str == FORMAT_IMAGE_STR: | |||
| summary_type = SummaryType.IMAGE | |||
| elif type_str == FORMAT_HISTOGRAM_STR: | |||
| summary_type = SummaryType.HISTOGRAM | |||
| else: | |||
| logger.error("The tag(%s) type is invalid.", tag) | |||
| summary_type = SummaryType.INVALID | |||
| return summary_type, summary_tag | |||
| class SummaryDataManager: | |||
| """Manage the summary global data cache.""" | |||
| def __init__(self): | |||
| global g_summary_data_dict | |||
| self.size = len(g_summary_data_dict) | |||
| @classmethod | |||
| def summary_data_save(cls, data): | |||
| """Save the global summary cache.""" | |||
| global g_summary_data_id | |||
| data_id = g_summary_data_id | |||
| save_summary_data(data_id, data) | |||
| g_summary_data_id += 1 | |||
| return data_id | |||
| @classmethod | |||
| def summary_file_set(cls, event_writer): | |||
| """Support the many event_writer.""" | |||
| global g_summary_file, g_summary_writer_id | |||
| g_summary_writer_id += 1 | |||
| g_summary_file[g_summary_writer_id] = event_writer | |||
| return g_summary_writer_id | |||
| @classmethod | |||
| def summary_file_get(cls, writer_id=1): | |||
| ret = None | |||
| global g_summary_file | |||
| if writer_id in g_summary_file: | |||
| ret = g_summary_file.get(writer_id) | |||
| return ret | |||
| class WorkerScheduler: | |||
| """ | |||
| Create worker and schedule data to worker. | |||
| Args: | |||
| writer_id (int): The index of writer. | |||
| """ | |||
| def __init__(self, writer_id): | |||
| # Create the process of write event file | |||
| self.write_lock = mp.Lock() | |||
| # Schedule info for all worker | |||
| # Format: {worker: (step, WorkerStatus)} | |||
| self.schedule_table = {} | |||
| # write id | |||
| self.writer_id = writer_id | |||
| self.has_graph = False | |||
| def dispatch(self, step, data): | |||
| """ | |||
| Select schedule strategy and dispatch data. | |||
| Args: | |||
| step (Number): The number of step index. | |||
| data (Object): The data of recode for summary. | |||
| Retruns: | |||
| bool, run successfully or not. | |||
| """ | |||
| # save the data to global cache , convert the tensor to numpy | |||
| result, size, data = self._data_convert(data) | |||
| if result is False: | |||
| logger.error("The step(%r) summary data(%r) is invalid.", step, size) | |||
| return False | |||
| data_id = SummaryDataManager.summary_data_save(data) | |||
| self._start_worker(step, data_id) | |||
| return True | |||
| def _start_worker(self, step, data_id): | |||
| """ | |||
| Start worker. | |||
| Args: | |||
| step (Number): The index of recode. | |||
| data_id (str): The id of work. | |||
| Return: | |||
| bool, run successfully or not. | |||
| """ | |||
| # assign the worker | |||
| policy = self._make_policy() | |||
| if policy == ScheduleMethod.TEMP_WORKER: | |||
| worker = SummaryDataProcess(step, data_id, self.write_lock, self.writer_id) | |||
| # update the schedule table | |||
| self.schedule_table[worker] = (step, data_id, WorkerStatus.WORKER_INIT) | |||
| # start the worker | |||
| worker.start() | |||
| else: | |||
| logger.error("Do not support the other scheduler policy now.") | |||
| # update the scheduler infor | |||
| self._update_scheduler() | |||
| return True | |||
| def _data_convert(self, data_list): | |||
| """Convert the data.""" | |||
| if data_list is None: | |||
| logger.warning("The step does not have record data.") | |||
| return False, 0, None | |||
| # convert the summary to numpy | |||
| size = 0 | |||
| for v_dict in data_list: | |||
| tag = v_dict["name"] | |||
| data = v_dict["data"] | |||
| # confirm the data is valid | |||
| summary_type, summary_tag = _parse_tag_format(tag) | |||
| if summary_type == SummaryType.INVALID: | |||
| logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data) | |||
| return False, 0, None | |||
| if isinstance(data, Tensor): | |||
| # get the summary type and parse the tag | |||
| v_dict["name"] = summary_tag | |||
| v_dict["type"] = summary_type | |||
| v_dict["data"] = data.asnumpy() | |||
| size += v_dict["data"].size | |||
| else: | |||
| logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data) | |||
| return False, 0, None | |||
| return True, size, data_list | |||
| def _update_scheduler(self): | |||
| """Check the worker status and update schedule table.""" | |||
| workers = list(self.schedule_table.keys()) | |||
| for worker in workers: | |||
| if not worker.is_alive(): | |||
| # update the table | |||
| worker.join() | |||
| del self.schedule_table[worker] | |||
| def close(self): | |||
| """Confirm all worker is end.""" | |||
| workers = self.schedule_table.keys() | |||
| for worker in workers: | |||
| if worker.is_alive(): | |||
| worker.join() | |||
| def _make_policy(self): | |||
| """Select the schedule strategy by data.""" | |||
| # now only support the temp worker | |||
| return ScheduleMethod.TEMP_WORKER | |||
| class SummaryDataProcess(mp.Process): | |||
| """ | |||
| Process that consume the summarydata. | |||
| Args: | |||
| step (int): The index of step. | |||
| data_id (int): The index of summary data. | |||
| write_lock (Lock): The process lock for writer same file. | |||
| writer_id (int): The index of writer. | |||
| """ | |||
| def __init__(self, step, data_id, write_lock, writer_id): | |||
| super(SummaryDataProcess, self).__init__() | |||
| self.daemon = True | |||
| self.writer_id = writer_id | |||
| self.writer = SummaryDataManager.summary_file_get(self.writer_id) | |||
| if self.writer is None: | |||
| logger.error("The writer_id(%r) does not have writer", writer_id) | |||
| self.step = step | |||
| self.data_id = data_id | |||
| self.write_lock = write_lock | |||
| self.name = "SummaryDataConsumer_" + str(self.step) | |||
| def run(self): | |||
| """The consumer is process the step data and exit.""" | |||
| # convert the data to event | |||
| # All exceptions need to be caught and end the queue | |||
| try: | |||
| logger.debug("process(%r) process a data(%r)", self.name, self.step) | |||
| # package the summary event | |||
| summary_event = package_summary_event(self.data_id, self.step) | |||
| # send the event to file | |||
| self._write_summary(summary_event) | |||
| except Exception as e: | |||
| logger.error("Summary data mq consumer exception occurred, value = %r", e) | |||
| def _write_summary(self, summary_event): | |||
| """ | |||
| Write the summary to event file. | |||
| Note: | |||
| The write record format: | |||
| 1 uint64 : data length. | |||
| 2 uint32 : mask crc value of data length. | |||
| 3 bytes : data. | |||
| 4 uint32 : mask crc value of data. | |||
| Args: | |||
| summary_event (Event): The summary event of proto. | |||
| """ | |||
| event_str = summary_event.SerializeToString() | |||
| self.write_lock.acquire() | |||
| self.writer.write_event_to_file(event_str) | |||
| self.writer.flush() | |||
| self.write_lock.release() | |||
| @@ -14,17 +14,22 @@ | |||
| # ============================================================================ | |||
| """Record the summary event.""" | |||
| import os | |||
| import re | |||
| import threading | |||
| from mindspore import log as logger | |||
| from ._summary_scheduler import WorkerScheduler, SummaryDataManager | |||
| from ._summary_adapter import get_event_file_name, package_graph_event | |||
| from ._event_writer import EventRecord | |||
| from .._utils import _make_directory | |||
| from ..._c_expression import Tensor | |||
| from ..._checkparam import _check_str_by_regular | |||
| from .._utils import _make_directory | |||
| from ._event_writer import EventWriter | |||
| from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event | |||
| # for the moment, this lock is for caution's sake, | |||
| # there are actually no any concurrencies happening. | |||
| _summary_lock = threading.Lock() | |||
| # cache the summary data | |||
| _summary_tensor_cache = {} | |||
| _summary_lock = threading.Lock() | |||
| def _cache_summary_tensor_data(summary): | |||
| @@ -34,14 +39,18 @@ def _cache_summary_tensor_data(summary): | |||
| Args: | |||
| summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...]. | |||
| """ | |||
| _summary_lock.acquire() | |||
| if "SummaryRecord" in _summary_tensor_cache: | |||
| for record in summary: | |||
| _summary_tensor_cache["SummaryRecord"].append(record) | |||
| else: | |||
| _summary_tensor_cache["SummaryRecord"] = summary | |||
| _summary_lock.release() | |||
| return True | |||
| with _summary_lock: | |||
| for item in summary: | |||
| _summary_tensor_cache[item['name']] = item['data'] | |||
| return True | |||
| def _get_summary_tensor_data(): | |||
| global _summary_tensor_cache | |||
| with _summary_lock: | |||
| data = _summary_tensor_cache | |||
| _summary_tensor_cache = {} | |||
| return data | |||
| class SummaryRecord: | |||
| @@ -53,7 +62,7 @@ class SummaryRecord: | |||
| It writes the event log to a file by executing the record method. In addition, | |||
| if the SummaryRecord object is created and the summary operator is used in the network, | |||
| even if the record method is not called, the event in the cache will be written to the | |||
| file at the end of execution or when the summary is closed. | |||
| file at the end of execution. Make sure to close the SummaryRecord object at the end. | |||
| Args: | |||
| log_dir (str): The log_dir is a directory location to save the summary. | |||
| @@ -68,9 +77,10 @@ class SummaryRecord: | |||
| RuntimeError: If the log_dir can not be resolved to a canonicalized absolute pathname. | |||
| Examples: | |||
| >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6, | |||
| >>> file_prefix="xxx_", file_suffix="_yyy") | |||
| >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | |||
| >>> pass | |||
| """ | |||
| def __init__(self, | |||
| log_dir, | |||
| queue_max_size=0, | |||
| @@ -101,26 +111,36 @@ class SummaryRecord: | |||
| self.prefix = file_prefix | |||
| self.suffix = file_suffix | |||
| self.network = network | |||
| self.has_graph = False | |||
| self._closed = False | |||
| # create the summary writer file | |||
| self.event_file_name = get_event_file_name(self.prefix, self.suffix) | |||
| if self.log_path[-1:] == '/': | |||
| self.full_file_name = self.log_path + self.event_file_name | |||
| else: | |||
| self.full_file_name = self.log_path + '/' + self.event_file_name | |||
| try: | |||
| self.full_file_name = os.path.realpath(self.full_file_name) | |||
| self.full_file_name = os.path.join(self.log_path, self.event_file_name) | |||
| except Exception as ex: | |||
| raise RuntimeError(ex) | |||
| self.event_writer = EventRecord(self.full_file_name, self.flush_time) | |||
| self.writer_id = SummaryDataManager.summary_file_set(self.event_writer) | |||
| self.worker_scheduler = WorkerScheduler(self.writer_id) | |||
| self.step = 0 | |||
| self._closed = False | |||
| self.network = network | |||
| self.has_graph = False | |||
| self._event_writer = None | |||
| def _init_event_writer(self): | |||
| """Init event writer and write metadata.""" | |||
| event_writer = EventWriter(self.full_file_name, self.flush_time) | |||
| event_writer.write(package_init_event().SerializeToString()) | |||
| return event_writer | |||
| def __enter__(self): | |||
| """Enter the context manager.""" | |||
| if not self._event_writer: | |||
| self._event_writer = self._init_event_writer() | |||
| if self._closed: | |||
| raise ValueError('SummaryRecord has been closed.') | |||
| return self | |||
| def __exit__(self, extype, exvalue, traceback): | |||
| """Exit the context manager.""" | |||
| self.close() | |||
| def record(self, step, train_network=None): | |||
| """ | |||
| @@ -131,9 +151,8 @@ class SummaryRecord: | |||
| train_network (Cell): The network that called the callback. | |||
| Examples: | |||
| >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6, | |||
| >>> file_prefix="xxx_", file_suffix="_yyy") | |||
| >>> summary_record.record(step=2) | |||
| >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | |||
| >>> summary_record.record(step=2) | |||
| Returns: | |||
| bool, whether the record process is successful or not. | |||
| @@ -145,42 +164,37 @@ class SummaryRecord: | |||
| if not isinstance(step, int) or isinstance(step, bool): | |||
| raise ValueError("`step` should be int") | |||
| # Set the current summary of train step | |||
| self.step = step | |||
| if not self._event_writer: | |||
| self._event_writer = self._init_event_writer() | |||
| logger.warning('SummaryRecord should be used as context manager for a with statement.') | |||
| if self.network is not None and self.has_graph is False: | |||
| if self.network is not None and not self.has_graph: | |||
| graph_proto = self.network.get_func_graph_proto() | |||
| if graph_proto is None and train_network is not None: | |||
| graph_proto = train_network.get_func_graph_proto() | |||
| if graph_proto is None: | |||
| logger.error("Failed to get proto for graph") | |||
| else: | |||
| self.event_writer.write_event_to_file( | |||
| package_graph_event(graph_proto).SerializeToString()) | |||
| self.event_writer.flush() | |||
| self._event_writer.write(package_graph_event(graph_proto).SerializeToString()) | |||
| self.has_graph = True | |||
| data = _summary_tensor_cache.get("SummaryRecord") | |||
| if data is None: | |||
| if not _summary_tensor_cache: | |||
| return True | |||
| data = _summary_tensor_cache.get("SummaryRecord") | |||
| if data is None: | |||
| logger.error("The step(%r) does not have record data.", self.step) | |||
| data = _get_summary_tensor_data() | |||
| if not data: | |||
| logger.error("The step(%r) does not have record data.", step) | |||
| return False | |||
| if self.queue_max_size > 0 and len(data) > self.queue_max_size: | |||
| logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data), | |||
| self.queue_max_size) | |||
| # clean the data of cache | |||
| del _summary_tensor_cache["SummaryRecord"] | |||
| # process the data | |||
| self.worker_scheduler.dispatch(self.step, data) | |||
| # count & flush | |||
| self.event_writer.count_event() | |||
| self.event_writer.flush_cycle() | |||
| logger.debug("Send the summary data to scheduler for saving, step = %d", self.step) | |||
| result = self._data_convert(data) | |||
| if not result: | |||
| logger.error("The step(%r) summary data is invalid.", step) | |||
| return False | |||
| self._event_writer.write((result, step)) | |||
| logger.debug("Send the summary data to scheduler for saving, step = %d", step) | |||
| return True | |||
| @property | |||
| @@ -189,14 +203,13 @@ class SummaryRecord: | |||
| Get the full path of the log file. | |||
| Examples: | |||
| >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6, | |||
| >>> file_prefix="xxx_", file_suffix="_yyy") | |||
| >>> print(summary_record.log_dir) | |||
| >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | |||
| >>> print(summary_record.log_dir) | |||
| Returns: | |||
| String, the full path of log file. | |||
| """ | |||
| return self.event_writer.full_file_name | |||
| return self.full_file_name | |||
| def flush(self): | |||
| """ | |||
| @@ -205,39 +218,64 @@ class SummaryRecord: | |||
| Call it to make sure that all pending events have been written to disk. | |||
| Examples: | |||
| >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6, | |||
| >>> file_prefix="xxx_", file_suffix="_yyy") | |||
| >>> summary_record.flush() | |||
| >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | |||
| >>> summary_record.flush() | |||
| """ | |||
| if self._closed: | |||
| logger.error("The record writer is closed and can not flush.") | |||
| else: | |||
| self.event_writer.flush() | |||
| elif self._event_writer: | |||
| self._event_writer.flush() | |||
| def close(self): | |||
| """ | |||
| Flush all events and close summary records. | |||
| Flush all events and close summary records. Please use with statement to autoclose. | |||
| Examples: | |||
| >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6, | |||
| >>> file_prefix="xxx_", file_suffix="_yyy") | |||
| >>> summary_record.close() | |||
| >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: | |||
| >>> pass # summary_record autoclosed | |||
| """ | |||
| if not self._closed: | |||
| self._check_data_before_close() | |||
| self.worker_scheduler.close() | |||
| if not self._closed and self._event_writer: | |||
| # event writer flush and close | |||
| self.event_writer.close() | |||
| self._event_writer.close() | |||
| self._closed = True | |||
| def __del__(self): | |||
| """Process exit is called.""" | |||
| if hasattr(self, "worker_scheduler"): | |||
| if self.worker_scheduler: | |||
| self.close() | |||
| def _check_data_before_close(self): | |||
| "Check whether there is any data in the cache, and if so, call record" | |||
| data = _summary_tensor_cache.get("SummaryRecord") | |||
| if data is not None: | |||
| self.record(self.step) | |||
| def __del__(self) -> None: | |||
| self.close() | |||
| def _data_convert(self, summary): | |||
| """Convert the data.""" | |||
| # convert the summary to numpy | |||
| result = [] | |||
| for name, data in summary.items(): | |||
| # confirm the data is valid | |||
| summary_tag, summary_type = SummaryRecord._parse_from(name) | |||
| if summary_tag is None: | |||
| logger.error("The data type is invalid, name = %r, tensor = %r", name, data) | |||
| return None | |||
| if isinstance(data, Tensor): | |||
| result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type}) | |||
| else: | |||
| logger.error("The data type is invalid, name = %r, tensor = %r", name, data) | |||
| return None | |||
| return result | |||
| @staticmethod | |||
| def _parse_from(name: str = None): | |||
| """ | |||
| Parse the tag and type from name. | |||
| Args: | |||
| name (str): Format: TAG[:TYPE]. | |||
| Returns: | |||
| Tuple, (summary_tag, summary_type). | |||
| """ | |||
| if name is None: | |||
| logger.error("The name is None") | |||
| return None, None | |||
| match = re.match(r'(.+)\[:(.+)\]', name) | |||
| if match: | |||
| return match.groups() | |||
| logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name) | |||
| return None, None | |||
| @@ -53,14 +53,13 @@ def me_train_tensor(net, input_np, label_np, epoch_size=2): | |||
| _network = wrap.WithLossCell(net, loss) | |||
| _train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt)) | |||
| _train_net.set_train() | |||
| summary_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net) | |||
| for epoch in range(0, epoch_size): | |||
| print(f"epoch %d" % (epoch)) | |||
| output = _train_net(Tensor(input_np), Tensor(label_np)) | |||
| summary_writer.record(i) | |||
| print("********output***********") | |||
| print(output.asnumpy()) | |||
| summary_writer.close() | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net) as summary_writer: | |||
| for epoch in range(0, epoch_size): | |||
| print(f"epoch %d" % (epoch)) | |||
| output = _train_net(Tensor(input_np), Tensor(label_np)) | |||
| summary_writer.record(i) | |||
| print("********output***********") | |||
| print(output.asnumpy()) | |||
| def me_infer_tensor(net, input_np): | |||
| @@ -91,15 +91,14 @@ def train_summary_record_scalar_for_1(test_writer, steps, fwd_x, fwd_y): | |||
| def me_scalar_summary(steps, tag=None, value=None): | |||
| test_writer = SummaryRecord(SUMMARY_DIR_ME_TEMP) | |||
| with SummaryRecord(SUMMARY_DIR_ME_TEMP) as test_writer: | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| out_me_dict = train_summary_record_scalar_for_1(test_writer, steps, x, y) | |||
| out_me_dict = train_summary_record_scalar_for_1(test_writer, steps, x, y) | |||
| test_writer.close() | |||
| return out_me_dict | |||
| return out_me_dict | |||
| @pytest.mark.level0 | |||
| @@ -106,18 +106,17 @@ def test_graph_summary_sample(): | |||
| optim = Momentum(net.trainable_params(), 0.1, 0.9) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) | |||
| model.train(2, dataset) | |||
| # step 2: create the Event | |||
| for i in range(1, 5): | |||
| test_writer.record(i) | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer: | |||
| model.train(2, dataset) | |||
| # step 2: create the Event | |||
| for i in range(1, 5): | |||
| test_writer.record(i) | |||
| # step 3: send the event to mq | |||
| # step 3: send the event to mq | |||
| # step 4: accept the event and write the file | |||
| test_writer.close() | |||
| # step 4: accept the event and write the file | |||
| log.debug("finished test_graph_summary_sample") | |||
| log.debug("finished test_graph_summary_sample") | |||
| def test_graph_summary_callback(): | |||
| @@ -127,9 +126,9 @@ def test_graph_summary_callback(): | |||
| optim = Momentum(net.trainable_params(), 0.1, 0.9) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) | |||
| summary_cb = SummaryStep(test_writer, 1) | |||
| model.train(2, dataset, callbacks=summary_cb) | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer: | |||
| summary_cb = SummaryStep(test_writer, 1) | |||
| model.train(2, dataset, callbacks=summary_cb) | |||
| def test_graph_summary_callback2(): | |||
| @@ -139,6 +138,6 @@ def test_graph_summary_callback2(): | |||
| optim = Momentum(net.trainable_params(), 0.1, 0.9) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net) | |||
| summary_cb = SummaryStep(test_writer, 1) | |||
| model.train(2, dataset, callbacks=summary_cb) | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net) as test_writer: | |||
| summary_cb = SummaryStep(test_writer, 1) | |||
| model.train(2, dataset, callbacks=summary_cb) | |||
| @@ -52,12 +52,11 @@ def _wrap_test_data(input_data: Tensor): | |||
| def test_histogram_summary(): | |||
| """Test histogram summary.""" | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||
| test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]])) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| test_writer.close() | |||
| test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]])) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| @@ -68,20 +67,18 @@ def test_histogram_summary(): | |||
| def test_histogram_multi_summary(): | |||
| """Test histogram multiple step.""" | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||
| rng = np.random.RandomState(10) | |||
| size = 50 | |||
| num_step = 5 | |||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||
| for i in range(num_step): | |||
| arr = rng.normal(size=size) | |||
| rng = np.random.RandomState(10) | |||
| size = 50 | |||
| num_step = 5 | |||
| test_data = _wrap_test_data(Tensor(arr)) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=i) | |||
| for i in range(num_step): | |||
| arr = rng.normal(size=size) | |||
| test_writer.close() | |||
| test_data = _wrap_test_data(Tensor(arr)) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=i) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| @@ -93,12 +90,11 @@ def test_histogram_multi_summary(): | |||
| def test_histogram_summary_scalar_tensor(): | |||
| """Test histogram summary, input is a scalar tensor.""" | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||
| test_data = _wrap_test_data(Tensor(1)) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| test_writer.close() | |||
| test_data = _wrap_test_data(Tensor(1)) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| @@ -109,12 +105,11 @@ def test_histogram_summary_scalar_tensor(): | |||
| def test_histogram_summary_empty_tensor(): | |||
| """Test histogram summary, input is an empty tensor.""" | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||
| test_data = _wrap_test_data(Tensor([])) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| test_writer.close() | |||
| test_data = _wrap_test_data(Tensor([])) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| @@ -125,15 +120,14 @@ def test_histogram_summary_empty_tensor(): | |||
| def test_histogram_summary_same_value(): | |||
| """Test histogram summary, input is an ones tensor.""" | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||
| dim1 = 100 | |||
| dim2 = 100 | |||
| dim1 = 100 | |||
| dim2 = 100 | |||
| test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2]))) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| test_writer.close() | |||
| test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2]))) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| @@ -146,15 +140,14 @@ def test_histogram_summary_same_value(): | |||
| def test_histogram_summary_high_dims(): | |||
| """Test histogram summary, input is a 4-dimension tensor.""" | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||
| dim = 10 | |||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||
| dim = 10 | |||
| rng = np.random.RandomState(0) | |||
| tensor_data = rng.normal(size=[dim, dim, dim, dim]) | |||
| test_data = _wrap_test_data(Tensor(tensor_data)) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| test_writer.close() | |||
| rng = np.random.RandomState(0) | |||
| tensor_data = rng.normal(size=[dim, dim, dim, dim]) | |||
| test_data = _wrap_test_data(Tensor(tensor_data)) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| @@ -167,20 +160,19 @@ def test_histogram_summary_high_dims(): | |||
| def test_histogram_summary_nan_inf(): | |||
| """Test histogram summary, input tensor has nan.""" | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||
| dim1 = 100 | |||
| dim2 = 100 | |||
| dim1 = 100 | |||
| dim2 = 100 | |||
| arr = np.ones([dim1, dim2]) | |||
| arr[0][0] = np.nan | |||
| arr[0][1] = np.inf | |||
| arr[0][2] = -np.inf | |||
| test_data = _wrap_test_data(Tensor(arr)) | |||
| arr = np.ones([dim1, dim2]) | |||
| arr[0][0] = np.nan | |||
| arr[0][1] = np.inf | |||
| arr[0][2] = -np.inf | |||
| test_data = _wrap_test_data(Tensor(arr)) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| test_writer.close() | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| @@ -193,12 +185,11 @@ def test_histogram_summary_nan_inf(): | |||
| def test_histogram_summary_all_nan_inf(): | |||
| """Test histogram summary, input tensor has no valid number.""" | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||
| test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf]))) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| test_writer.close() | |||
| test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf]))) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| @@ -74,23 +74,21 @@ def test_image_summary_sample(): | |||
| """ test_image_summary_sample """ | |||
| log.debug("begin test_image_summary_sample") | |||
| # step 0: create the thread | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | |||
| # step 1: create the test data for summary | |||
| # step 1: create the test data for summary | |||
| # step 2: create the Event | |||
| for i in range(1, 5): | |||
| test_data = get_test_data(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| test_writer.flush() | |||
| # step 2: create the Event | |||
| for i in range(1, 5): | |||
| test_data = get_test_data(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| test_writer.flush() | |||
| # step 3: send the event to mq | |||
| # step 3: send the event to mq | |||
| # step 4: accept the event and write the file | |||
| test_writer.close() | |||
| log.debug("finished test_image_summary_sample") | |||
| # step 4: accept the event and write the file | |||
| log.debug("finished test_image_summary_sample") | |||
| class Net(nn.Cell): | |||
| @@ -174,23 +172,21 @@ def test_image_summary_train(): | |||
| log.debug("begin test_image_summary_sample") | |||
| # step 0: create the thread | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") | |||
| # step 1: create the test data for summary | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | |||
| # step 2: create the Event | |||
| # step 1: create the test data for summary | |||
| model = get_model() | |||
| fn = ImageSummaryCallback(test_writer) | |||
| summary_recode = SummaryStep(fn, 1) | |||
| model.train(2, dataset, callbacks=summary_recode) | |||
| # step 2: create the Event | |||
| # step 3: send the event to mq | |||
| model = get_model() | |||
| fn = ImageSummaryCallback(test_writer) | |||
| summary_recode = SummaryStep(fn, 1) | |||
| model.train(2, dataset, callbacks=summary_recode) | |||
| # step 4: accept the event and write the file | |||
| test_writer.close() | |||
| # step 3: send the event to mq | |||
| log.debug("finished test_image_summary_sample") | |||
| # step 4: accept the event and write the file | |||
| log.debug("finished test_image_summary_sample") | |||
| def test_image_summary_data(): | |||
| @@ -209,18 +205,12 @@ def test_image_summary_data(): | |||
| log.debug("begin test_image_summary_sample") | |||
| # step 0: create the thread | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") | |||
| # step 1: create the test data for summary | |||
| # step 2: create the Event | |||
| _cache_summary_tensor_data(test_data_list) | |||
| test_writer.record(1) | |||
| test_writer.flush() | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: | |||
| # step 3: send the event to mq | |||
| # step 1: create the test data for summary | |||
| # step 4: accept the event and write the file | |||
| test_writer.close() | |||
| # step 2: create the Event | |||
| _cache_summary_tensor_data(test_data_list) | |||
| test_writer.record(1) | |||
| log.debug("finished test_image_summary_sample") | |||
| log.debug("finished test_image_summary_sample") | |||
| @@ -65,22 +65,21 @@ def test_scalar_summary_sample(): | |||
| """ test_scalar_summary_sample """ | |||
| log.debug("begin test_scalar_summary_sample") | |||
| # step 0: create the thread | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | |||
| # step 1: create the test data for summary | |||
| # step 1: create the test data for summary | |||
| # step 2: create the Event | |||
| for i in range(1, 500): | |||
| test_data = get_test_data(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| # step 2: create the Event | |||
| for i in range(1, 500): | |||
| test_data = get_test_data(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| # step 3: send the event to mq | |||
| # step 3: send the event to mq | |||
| # step 4: accept the event and write the file | |||
| test_writer.close() | |||
| # step 4: accept the event and write the file | |||
| log.debug("finished test_scalar_summary_sample") | |||
| log.debug("finished test_scalar_summary_sample") | |||
| def get_test_data_shape_1(step): | |||
| @@ -110,22 +109,21 @@ def test_scalar_summary_sample_with_shape_1(): | |||
| """ test_scalar_summary_sample_with_shape_1 """ | |||
| log.debug("begin test_scalar_summary_sample_with_shape_1") | |||
| # step 0: create the thread | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | |||
| # step 1: create the test data for summary | |||
| # step 1: create the test data for summary | |||
| # step 2: create the Event | |||
| for i in range(1, 100): | |||
| test_data = get_test_data_shape_1(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| # step 2: create the Event | |||
| for i in range(1, 100): | |||
| test_data = get_test_data_shape_1(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| # step 3: send the event to mq | |||
| # step 3: send the event to mq | |||
| # step 4: accept the event and write the file | |||
| test_writer.close() | |||
| # step 4: accept the event and write the file | |||
| log.debug("finished test_scalar_summary_sample") | |||
| log.debug("finished test_scalar_summary_sample") | |||
| # Test: test with ge | |||
| @@ -152,26 +150,24 @@ def test_scalar_summary_with_ge(): | |||
| log.debug("begin test_scalar_summary_with_ge") | |||
| # step 0: create the thread | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net = SummaryDemo() | |||
| net.set_train() | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net = SummaryDemo() | |||
| net.set_train() | |||
| # step 2: create the Event | |||
| steps = 100 | |||
| for i in range(1, steps): | |||
| x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| # step 2: create the Event | |||
| steps = 100 | |||
| for i in range(1, steps): | |||
| x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| # step 3: close the writer | |||
| test_writer.close() | |||
| log.debug("finished test_scalar_summary_with_ge") | |||
| log.debug("finished test_scalar_summary_with_ge") | |||
| # test the problem of two consecutive use cases going wrong | |||
| @@ -180,55 +176,52 @@ def test_scalar_summary_with_ge_2(): | |||
| log.debug("begin test_scalar_summary_with_ge_2") | |||
| # step 0: create the thread | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net = SummaryDemo() | |||
| net.set_train() | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: | |||
| # step 2: create the Event | |||
| steps = 100 | |||
| for i in range(1, steps): | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| net = SummaryDemo() | |||
| net.set_train() | |||
| # step 3: close the writer | |||
| test_writer.close() | |||
| # step 2: create the Event | |||
| steps = 100 | |||
| for i in range(1, steps): | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| log.debug("finished test_scalar_summary_with_ge_2") | |||
| log.debug("finished test_scalar_summary_with_ge_2") | |||
| def test_validate(): | |||
| sr = SummaryRecord(SUMMARY_DIR) | |||
| with pytest.raises(ValueError): | |||
| SummaryStep(sr, 0) | |||
| with pytest.raises(ValueError): | |||
| SummaryStep(sr, -1) | |||
| with pytest.raises(ValueError): | |||
| SummaryStep(sr, 1.2) | |||
| with pytest.raises(ValueError): | |||
| SummaryStep(sr, True) | |||
| with pytest.raises(ValueError): | |||
| SummaryStep(sr, "str") | |||
| sr.record(1) | |||
| with pytest.raises(ValueError): | |||
| sr.record(False) | |||
| with pytest.raises(ValueError): | |||
| sr.record(2.0) | |||
| with pytest.raises(ValueError): | |||
| sr.record((1, 3)) | |||
| with pytest.raises(ValueError): | |||
| sr.record([2, 3]) | |||
| with pytest.raises(ValueError): | |||
| sr.record("str") | |||
| with pytest.raises(ValueError): | |||
| sr.record(sr) | |||
| sr.close() | |||
| def test_validate(): | |||
| with SummaryRecord(SUMMARY_DIR) as sr: | |||
| with pytest.raises(ValueError): | |||
| SummaryStep(sr, 0) | |||
| with pytest.raises(ValueError): | |||
| SummaryStep(sr, -1) | |||
| with pytest.raises(ValueError): | |||
| SummaryStep(sr, 1.2) | |||
| with pytest.raises(ValueError): | |||
| SummaryStep(sr, True) | |||
| with pytest.raises(ValueError): | |||
| SummaryStep(sr, "str") | |||
| sr.record(1) | |||
| with pytest.raises(ValueError): | |||
| sr.record(False) | |||
| with pytest.raises(ValueError): | |||
| sr.record(2.0) | |||
| with pytest.raises(ValueError): | |||
| sr.record((1, 3)) | |||
| with pytest.raises(ValueError): | |||
| sr.record([2, 3]) | |||
| with pytest.raises(ValueError): | |||
| sr.record("str") | |||
| with pytest.raises(ValueError): | |||
| sr.record(sr) | |||
| SummaryStep(sr, 1) | |||
| with pytest.raises(ValueError): | |||
| @@ -126,23 +126,21 @@ class HistogramSummaryNet(nn.Cell): | |||
| def run_case(net): | |||
| """ run_case """ | |||
| # step 0: create the thread | |||
| test_writer = SummaryRecord(SUMMARY_DIR) | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net.set_train() | |||
| # step 2: create the Event | |||
| steps = 100 | |||
| for i in range(1, steps): | |||
| x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| # step 3: close the writer | |||
| test_writer.close() | |||
| with SummaryRecord(SUMMARY_DIR) as test_writer: | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net.set_train() | |||
| # step 2: create the Event | |||
| steps = 100 | |||
| for i in range(1, steps): | |||
| x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| # Test 1: use the repeat tag | |||
| @@ -80,19 +80,18 @@ def test_tensor_summary_sample(): | |||
| """ test_tensor_summary_sample """ | |||
| log.debug("begin test_tensor_summary_sample") | |||
| # step 0: create the thread | |||
| test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_TENSOR") | |||
| with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_TENSOR") as test_writer: | |||
| # step 1: create the Event | |||
| for i in range(1, 100): | |||
| test_data = get_test_data(i) | |||
| # step 1: create the Event | |||
| for i in range(1, 100): | |||
| test_data = get_test_data(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| _cache_summary_tensor_data(test_data) | |||
| test_writer.record(i) | |||
| # step 2: accept the event and write the file | |||
| test_writer.close() | |||
| # step 2: accept the event and write the file | |||
| log.debug("finished test_tensor_summary_sample") | |||
| log.debug("finished test_tensor_summary_sample") | |||
| def get_test_data_check(step): | |||
| @@ -131,23 +130,20 @@ def test_tensor_summary_with_ge(): | |||
| log.debug("begin test_tensor_summary_with_ge") | |||
| # step 0: create the thread | |||
| test_writer = SummaryRecord(SUMMARY_DIR) | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net = SummaryDemo() | |||
| net.set_train() | |||
| # step 2: create the Event | |||
| steps = 100 | |||
| for i in range(1, steps): | |||
| x = Tensor(np.array([[i], [i]]).astype(np.float32)) | |||
| y = Tensor(np.array([[i + 1], [i + 1]]).astype(np.float32)) | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| # step 3: close the writer | |||
| test_writer.close() | |||
| log.debug("finished test_tensor_summary_with_ge") | |||
| with SummaryRecord(SUMMARY_DIR) as test_writer: | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| net = SummaryDemo() | |||
| net.set_train() | |||
| # step 2: create the Event | |||
| steps = 100 | |||
| for i in range(1, steps): | |||
| x = Tensor(np.array([[i], [i]]).astype(np.float32)) | |||
| y = Tensor(np.array([[i + 1], [i + 1]]).astype(np.float32)) | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| log.debug("finished test_tensor_summary_with_ge") | |||