Merge pull request !1069 from 李鸿章/context_manager (tag: v0.3.0-alpha)
@@ -14,91 +14,74 @@
 # ============================================================================
 """Writes events to disk in a logdir."""
 import os
-import time
 import stat
-from mindspore import log as logger
+from collections import deque
+from multiprocessing import Pool, Process, Queue, cpu_count
+
 from ..._c_expression import EventWriter_
-from ._summary_adapter import package_init_event
+from ._summary_adapter import package_summary_event


-class _WrapEventWriter(EventWriter_):
-    """
-    Wrap the c++ EventWriter object.
-
-    Args:
-        full_file_name (str): Include directory and file name.
-    """
-    def __init__(self, full_file_name):
-        if full_file_name is not None:
-            EventWriter_.__init__(self, full_file_name)
+def _pack(result, step):
+    summary_event = package_summary_event(result, step)
+    return summary_event.SerializeToString()


-class EventRecord:
+class EventWriter(Process):
     """
-    Creates a `EventFileWriter` and write event to file.
+    Creates a `EventWriter` and write event to file.

     Args:
-        full_file_name (str): Summary event file path and file name.
-        flush_time (int): The flush seconds to flush the pending events to disk. Default: 120.
+        filepath (str): Summary event file path and file name.
+        flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120.
     """
-    def __init__(self, full_file_name: str, flush_time: int = 120):
-        self.full_file_name = full_file_name
-
-        # The first event will be flushed immediately.
-        self.flush_time = flush_time
-        self.next_flush_time = 0
-
-        # create event write object
-        self.event_writer = self._create_event_file()
-        self._init_event_file()
-
-        # count the events
-        self.event_count = 0
-
-    def _create_event_file(self):
-        """Create the event write file."""
-        with open(self.full_file_name, 'w'):
-            os.chmod(self.full_file_name, stat.S_IWUSR | stat.S_IRUSR)
-
-        # create c++ event write object
-        event_writer = _WrapEventWriter(self.full_file_name)
-        return event_writer
-
-    def _init_event_file(self):
-        """Send the init event to file."""
-        self.event_writer.Write((package_init_event()).SerializeToString())
-        self.flush()
-        return True
-
-    def write_event_to_file(self, event_str):
-        """Write the event to file."""
-        self.event_writer.Write(event_str)
-
-    def get_data_count(self):
-        """Return the event count."""
-        return self.event_count
-
-    def flush_cycle(self):
-        """Flush file by timer."""
-        self.event_count = self.event_count + 1
-        # Flush the event writer every so often.
-        now = int(time.time())
-        if now > self.next_flush_time:
-            self.flush()
-            # update the flush time
-            self.next_flush_time = now + self.flush_time
-
-    def count_event(self):
-        """Count event."""
-        logger.debug("Write the event count is %r", self.event_count)
-        self.event_count = self.event_count + 1
-        return self.event_count
+    def __init__(self, filepath: str, flush_interval: int) -> None:
+        super().__init__()
+        with open(filepath, 'w'):
+            os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
+        self._writer = EventWriter_(filepath)
+        self._queue = Queue(cpu_count() * 2)
+        self.start()
+
+    def run(self):
+        with Pool() as pool:
+            deq = deque()
+            while True:
+                while deq and deq[0].ready():
+                    self._writer.Write(deq.popleft().get())
+
+                if not self._queue.empty():
+                    action, data = self._queue.get()
+                    if action == 'WRITE':
+                        if not isinstance(data, (str, bytes)):
+                            deq.append(pool.apply_async(_pack, data))
+                        else:
+                            self._writer.Write(data)
+                    elif action == 'FLUSH':
+                        self._writer.Flush()
+                    elif action == 'END':
+                        break
+            for res in deq:
+                self._writer.Write(res.get())
+
+            self._writer.Shut()
+
+    def write(self, data) -> None:
+        """
+        Write the event to file.
+
+        Args:
+            data (Optional[str, Tuple[list, int]]): The data to write.
+        """
+        self._queue.put(('WRITE', data))

     def flush(self):
-        """Flush the event file to disk."""
-        self.event_writer.Flush()
+        """Flush the writer."""
+        self._queue.put(('FLUSH', None))

-    def close(self):
-        """Flush the event file to disk and close the file."""
-        self.flush()
-        self.event_writer.Shut()
+    def close(self) -> None:
+        """Close the writer."""
+        self._queue.put(('END', None))
+        self.join()
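The rewritten writer replaces the old timer-and-counter bookkeeping of `EventRecord` with one long-lived process that drains ('WRITE'|'FLUSH'|'END', data) commands from a bounded queue, offloading protobuf serialization to a `Pool` via `_pack`. A minimal, standalone sketch of that command-queue pattern, using blocking gets and a plain file standing in for the C++ `EventWriter_` (file name and payload are hypothetical):

from multiprocessing import Process, Queue, cpu_count


class AsyncFileWriter(Process):
    """Background process that owns the file and drains a command queue."""

    def __init__(self, filepath):
        super().__init__()
        self._filepath = filepath
        self._queue = Queue(cpu_count() * 2)  # bounded: producers block when full
        self.start()

    def run(self):
        # runs in the child process; the file handle never crosses processes
        with open(self._filepath, 'wb') as f:
            while True:
                action, data = self._queue.get()
                if action == 'WRITE':
                    f.write(data)
                elif action == 'FLUSH':
                    f.flush()
                elif action == 'END':
                    break

    def write(self, data):
        self._queue.put(('WRITE', data))

    def flush(self):
        self._queue.put(('FLUSH', None))

    def close(self):
        self._queue.put(('END', None))
        self.join()  # wait until everything queued so far is on disk


if __name__ == '__main__':
    writer = AsyncFileWriter('demo.events')  # hypothetical output file
    writer.write(b'hello')
    writer.flush()
    writer.close()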
@@ -13,17 +13,17 @@
 # limitations under the License.
 # ============================================================================
 """Generate the summary event which conform to proto format."""
-import time
 import socket
-import math
-from enum import Enum, unique
+import time
+
 import numpy as np
 from PIL import Image
+
 from mindspore import log as logger
-from ..summary_pb2 import Event
-from ..anf_ir_pb2 import ModelProto, DataType
 from ..._checkparam import _check_str_by_regular
+from ..anf_ir_pb2 import DataType, ModelProto
+from ..summary_pb2 import Event

 # define the MindSpore image format
 MS_IMAGE_TENSOR_FORMAT = 'NCHW'
@@ -32,55 +32,6 @@ EVENT_FILE_NAME_MARK = ".out.events.summary."
 # Set the init event of version and mark
 EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:"
 EVENT_FILE_INIT_VERSION = 1
-
-# cache the summary data dict
-# {id: SummaryData}
-#       |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
-g_summary_data_dict = {}
-
-
-def save_summary_data(data_id, data):
-    """Save the global summary cache."""
-    global g_summary_data_dict
-    g_summary_data_dict[data_id] = data
-
-
-def del_summary_data(data_id):
-    """Save the global summary cache."""
-    global g_summary_data_dict
-    if data_id in g_summary_data_dict:
-        del g_summary_data_dict[data_id]
-    else:
-        logger.warning("Can't del the data because data_id(%r) "
-                       "does not have data in g_summary_data_dict", data_id)
-
-
-def get_summary_data(data_id):
-    """Save the global summary cache."""
-    ret = None
-    global g_summary_data_dict
-    if data_id in g_summary_data_dict:
-        ret = g_summary_data_dict.get(data_id)
-    else:
-        logger.warning("The data_id(%r) does not have data in g_summary_data_dict", data_id)
-    return ret
-
-
-@unique
-class SummaryType(Enum):
-    """
-    Summary type.
-
-    Args:
-        SCALAR (Number): Summary Scalar enum.
-        TENSOR (Number): Summary TENSOR enum.
-        IMAGE (Number): Summary image enum.
-        GRAPH (Number): Summary graph enum.
-        HISTOGRAM (Number): Summary histogram enum.
-        INVALID (Number): Unknow type.
-    """
-    SCALAR = 1      # Scalar summary
-    TENSOR = 2      # Tensor summary
-    IMAGE = 3       # Image summary
-    GRAPH = 4       # graph
-    HISTOGRAM = 5   # Histogram Summary
-    INVALID = 0xFF  # unknow type


 def get_event_file_name(prefix, suffix):
@@ -138,7 +89,7 @@ def package_graph_event(data):
     return graph_event


-def package_summary_event(data_id, step):
+def package_summary_event(data_list, step):
     """
     Package the summary to event protobuffer.
@@ -149,50 +100,37 @@ def package_summary_event(data_id, step):
     Returns:
         Summary, the summary event.
     """
-    data_list = get_summary_data(data_id)
-    if data_list is None:
-        logger.error("The step(%r) does not have record data.", step)
-    del_summary_data(data_id)
     # create the event of summary
     summary_event = Event()
     summary = summary_event.summary
+    summary_event.wall_time = time.time()
+    summary_event.step = int(step)

     for value in data_list:
-        tag = value["name"]
+        summary_type = value["_type"]
         data = value["data"]
-        summary_type = value["type"]
+        tag = value["name"]

+        logger.debug("Now process %r summary, tag = %r", summary_type, tag)
+
+        summary_value = summary.value.add()
+        summary_value.tag = tag
         # get the summary type and parse the tag
-        if summary_type is SummaryType.SCALAR:
-            logger.debug("Now process Scalar summary, tag = %r", tag)
-            summary_value = summary.value.add()
-            summary_value.tag = tag
+        if summary_type == 'Scalar':
             summary_value.scalar_value = _get_scalar_summary(tag, data)
-        elif summary_type is SummaryType.TENSOR:
-            logger.debug("Now process Tensor summary, tag = %r", tag)
-            summary_value = summary.value.add()
-            summary_value.tag = tag
+        elif summary_type == 'Tensor':
             summary_tensor = summary_value.tensor
             _get_tensor_summary(tag, data, summary_tensor)
-        elif summary_type is SummaryType.IMAGE:
-            logger.debug("Now process Image summary, tag = %r", tag)
-            summary_value = summary.value.add()
-            summary_value.tag = tag
+        elif summary_type == 'Image':
             summary_image = summary_value.image
             _get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT)
-        elif summary_type is SummaryType.HISTOGRAM:
-            logger.debug("Now process Histogram summary, tag = %r", tag)
-            summary_value = summary.value.add()
-            summary_value.tag = tag
+        elif summary_type == 'Histogram':
             summary_histogram = summary_value.histogram
             _fill_histogram_summary(tag, data, summary_histogram)
         else:
             # The data is invalid ,jump the data
-            logger.error("Summary type is error, tag = %r", tag)
-            continue
+            logger.error("Summary type(%r) is error, tag = %r", summary_type, tag)

-    summary_event.wall_time = time.time()
-    summary_event.step = int(step)
     return summary_event
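`package_summary_event` now receives the tag/type/data list directly instead of looking it up in the removed global cache by `data_id`. A hedged sketch of the item shape it consumes after this change; the tags and values are hypothetical, and the `_type` strings are the ones matched in the branches above:

import numpy as np

data_list = [
    {'name': 'loss', 'data': np.array([0.1]), '_type': 'Scalar'},
    {'name': 'fc1.weight', 'data': np.ones((3, 3)), '_type': 'Histogram'},
]
# package_summary_event(data_list, step=10) folds every item into a single
# Event proto, stamping wall_time and step before the loop runs.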
@@ -255,11 +193,11 @@ def _get_scalar_summary(tag: str, np_value):
         # So consider the dim = 1, shape = (1,) tensor is scalar
         scalar_value = np_value[0]
         if np_value.shape != (1,):
-            logger.error("The tensor is not Scalar, tag = %r, Value = %r", tag, np_value)
+            logger.error("The tensor is not Scalar, tag = %r, Shape = %r", tag, np_value.shape)
     else:
         np_list = np_value.reshape(-1).tolist()
         scalar_value = np_list[0]
-        logger.error("The value is not Scalar, tag = %r, Value = %r", tag, np_value)
+        logger.error("The value is not Scalar, tag = %r, ndim = %r", tag, np_value.ndim)

     logger.debug("The tag(%r) value is: %r", tag, scalar_value)
     return scalar_value
@@ -307,8 +245,7 @@ def _calc_histogram_bins(count):
     Returns:
         int, number of histogram bins.
     """
-    number_per_bucket = 10
-    max_bins = 90
+    max_bins, max_per_bin = 90, 10

     if not count:
         return 1
@@ -318,78 +255,50 @@ def _calc_histogram_bins(count):
         return 3
     if count <= 880:
         # note that math.ceil(881/10) + 1 equals 90
-        return int(math.ceil(count / number_per_bucket) + 1)
+        return count // max_per_bin + 1

     return max_bins
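A worked check of the new bin formula for the branch visible in this hunk; the branches for very small counts fall between hunks and are omitted here, so this sketch only covers counts that reach the `count <= 880` test:

max_bins, max_per_bin = 90, 10

def bins_for(count):
    # only the branch shown in the diff; small-count branches elided
    if count <= 880:
        return count // max_per_bin + 1
    return max_bins

assert bins_for(500) == 51        # 500 // 10 + 1
assert bins_for(880) == 89        # one below the cap
assert bins_for(881) == max_bins  # the old formula also capped out at 90 here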
-def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None:
+def _fill_histogram_summary(tag: str, np_value: np.ndarray, summary) -> None:
     """
     Package the histogram summary.

     Args:
         tag (str): Summary tag describe.
-        np_value (np.array): Summary data.
-        summary_histogram (summary_pb2.Summary.Histogram): Summary histogram data.
+        np_value (np.ndarray): Summary data.
+        summary (summary_pb2.Summary.Histogram): Summary histogram data.
     """
     logger.debug("Set(%r) the histogram summary value", tag)
     # Default bucket for tensor with no valid data.
-    default_bucket_left = -0.5
-    default_bucket_width = 1.0
-
-    if np_value.size == 0:
-        bucket = summary_histogram.buckets.add()
-        bucket.left = default_bucket_left
-        bucket.width = default_bucket_width
-        bucket.count = 0
-
-        summary_histogram.nan_count = 0
-        summary_histogram.pos_inf_count = 0
-        summary_histogram.neg_inf_count = 0
-        summary_histogram.max = 0
-        summary_histogram.min = 0
-        summary_histogram.sum = 0
-        summary_histogram.count = 0
-
-        return
-
-    summary_histogram.nan_count = np.count_nonzero(np.isnan(np_value))
-    summary_histogram.pos_inf_count = np.count_nonzero(np.isposinf(np_value))
-    summary_histogram.neg_inf_count = np.count_nonzero(np.isneginf(np_value))
-    summary_histogram.count = np_value.size
-
-    masked_value = np.ma.masked_invalid(np_value)
-    tensor_max = masked_value.max()
-    tensor_min = masked_value.min()
-    tensor_sum = masked_value.sum()
-
-    # No valid value in tensor.
-    if tensor_max is np.ma.masked:
-        bucket = summary_histogram.buckets.add()
-        bucket.left = default_bucket_left
-        bucket.width = default_bucket_width
-        bucket.count = 0
-
-        summary_histogram.max = np.nan
-        summary_histogram.min = np.nan
-        summary_histogram.sum = 0
-        return
-
-    bin_number = _calc_histogram_bins(masked_value.count())
-    counts, edges = np.histogram(np_value, bins=bin_number, range=(tensor_min, tensor_max))
-
-    for ind, count in enumerate(counts):
-        bucket = summary_histogram.buckets.add()
-        bucket.left = edges[ind]
-        bucket.width = edges[ind + 1] - edges[ind]
-        bucket.count = count
-
-    summary_histogram.max = tensor_max
-    summary_histogram.min = tensor_min
-    summary_histogram.sum = tensor_sum
+    ma_value = np.ma.masked_invalid(np_value)
+    total, valid = np_value.size, ma_value.count()
+    invalids = []
+    for isfn in np.isnan, np.isposinf, np.isneginf:
+        if total - valid > sum(invalids):
+            count = np.count_nonzero(isfn(np_value))
+            invalids.append(count)
+        else:
+            invalids.append(0)
+
+    summary.count = total
+    summary.nan_count, summary.pos_inf_count, summary.neg_inf_count = invalids
+    if not valid:
+        logger.warning('There are no valid values in the ndarray(size=%d, shape=%d)', total, np_value.shape)
+        # summary.{min, max, sum} are 0s by default, no need to explicitly set
+    else:
+        summary.min = ma_value.min()
+        summary.max = ma_value.max()
+        summary.sum = ma_value.sum()
+        bins = _calc_histogram_bins(valid)
+        range_ = summary.min, summary.max
+        hists, edges = np.histogram(np_value, bins=bins, range=range_)
+
+        for hist, edge1, edge2 in zip(hists, edges, edges[1:]):
+            bucket = summary.buckets.add()
+            bucket.width = edge2 - edge1
+            bucket.count = hist
+            bucket.left = edge1


 def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
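The rewritten `_fill_histogram_summary` counts nan/+inf/-inf lazily: each scan runs only while some invalid values remain unaccounted for, so a tensor with no negative infinities never pays for the `np.isneginf` pass. A small runnable demonstration of that short-circuit, on a hypothetical input:

import numpy as np

np_value = np.array([1.0, np.nan, np.inf, 2.0])
ma_value = np.ma.masked_invalid(np_value)
total, valid = np_value.size, ma_value.count()

invalids = []
for isfn in (np.isnan, np.isposinf, np.isneginf):
    if total - valid > sum(invalids):  # invalid values still unaccounted for
        invalids.append(int(np.count_nonzero(isfn(np_value))))
    else:
        invalids.append(0)  # everything found already; skip this scan

print(invalids)  # [1, 1, 0] -> nan_count, pos_inf_count, neg_inf_count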
@@ -407,7 +316,7 @@ def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
     """
     logger.debug("Set(%r) the image summary value", tag)
     if np_value.ndim != 4:
-        logger.error("The value is not Image, tag = %r, Value = %r", tag, np_value)
+        logger.error("The value is not Image, tag = %r, ndim = %r", tag, np_value.ndim)

     # convert the tensor format
     tensor = _convert_image_format(np_value, input_format)
@@ -469,8 +378,8 @@ def _convert_image_format(np_tensor, input_format, out_format='HWC'):
     """
     out_tensor = None
     if np_tensor.ndim != len(input_format):
-        logger.error("The tensor(%r) can't convert the format(%r) because dim not same",
-                     np_tensor, input_format)
+        logger.error("The tensor with dim(%r) can't convert the format(%r) because dim not same", np_tensor.ndim,
+                     input_format)
         return out_tensor

     input_format = input_format.upper()
@@ -512,7 +421,7 @@ def _make_canvas_for_imgs(tensor, col_imgs=8):
     # check the tensor format
     if tensor.ndim != 4 or tensor.shape[1] != 3:
-        logger.error("The image tensor(%r) is not 'NCHW' format", tensor)
+        logger.error("The image tensor with ndim(%r) and shape(%r) is not 'NCHW' format", tensor.ndim, tensor.shape)
         return out_canvas

     # expand the N
@@ -1,308 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""Schedule the event writer process."""
-import multiprocessing as mp
-from enum import Enum, unique
-
-from mindspore import log as logger
-
-from ..._c_expression import Tensor
-from ._summary_adapter import SummaryType, package_summary_event, save_summary_data
-
-# define the type of summary
-FORMAT_SCALAR_STR = "Scalar"
-FORMAT_TENSOR_STR = "Tensor"
-FORMAT_IMAGE_STR = "Image"
-FORMAT_HISTOGRAM_STR = "Histogram"
-FORMAT_BEGIN_SLICE = "[:"
-FORMAT_END_SLICE = "]"
-
-# cache the summary data dict
-# {id: SummaryData}
-#       |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
-g_summary_data_id = 0
-g_summary_data_dict = {}
-# cache the summary data file
-g_summary_writer_id = 0
-g_summary_file = {}
-
-
-@unique
-class ScheduleMethod(Enum):
-    """Schedule method type."""
-    FORMAL_WORKER = 0  # use the formal worker that receive small size data by queue
-    TEMP_WORKER = 1    # use the Temp worker that receive big size data by the global value(avoid copy)
-    CACHE_DATA = 2     # Cache data util have idle worker to process it
-
-
-@unique
-class WorkerStatus(Enum):
-    """Worker status."""
-    WORKER_INIT = 0        # data is exist but not process
-    WORKER_PROCESSING = 1  # data is processing
-    WORKER_PROCESSED = 2   # data already processed
-
-
-def _parse_tag_format(tag: str):
-    """
-    Parse the tag.
-
-    Args:
-        tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor].
-
-    Returns:
-        Tuple, (SummaryType, summary_tag).
-    """
-    summary_type = SummaryType.INVALID
-    summary_tag = tag
-    if tag is None:
-        logger.error("The tag is None")
-        return summary_type, summary_tag
-
-    # search the slice
-    slice_begin = FORMAT_BEGIN_SLICE
-    slice_end = FORMAT_END_SLICE
-    index = tag.rfind(slice_begin)
-    if index is -1:
-        logger.error("The tag(%s) have not the key slice.", tag)
-        return summary_type, summary_tag
-
-    # slice the tag
-    summary_tag = tag[:index]
-
-    # check the slice end
-    if tag[-1:] != slice_end:
-        logger.error("The tag(%s) end format is error", tag)
-        return summary_type, summary_tag
-
-    # check the type
-    type_str = tag[index + 2: -1]
-    logger.debug("The summary_tag is = %r", summary_tag)
-    logger.debug("The type_str value is = %r", type_str)
-    if type_str == FORMAT_SCALAR_STR:
-        summary_type = SummaryType.SCALAR
-    elif type_str == FORMAT_TENSOR_STR:
-        summary_type = SummaryType.TENSOR
-    elif type_str == FORMAT_IMAGE_STR:
-        summary_type = SummaryType.IMAGE
-    elif type_str == FORMAT_HISTOGRAM_STR:
-        summary_type = SummaryType.HISTOGRAM
-    else:
-        logger.error("The tag(%s) type is invalid.", tag)
-        summary_type = SummaryType.INVALID
-
-    return summary_type, summary_tag
-
-
-class SummaryDataManager:
-    """Manage the summary global data cache."""
-    def __init__(self):
-        global g_summary_data_dict
-        self.size = len(g_summary_data_dict)
-
-    @classmethod
-    def summary_data_save(cls, data):
-        """Save the global summary cache."""
-        global g_summary_data_id
-        data_id = g_summary_data_id
-        save_summary_data(data_id, data)
-        g_summary_data_id += 1
-        return data_id
-
-    @classmethod
-    def summary_file_set(cls, event_writer):
-        """Support the many event_writer."""
-        global g_summary_file, g_summary_writer_id
-        g_summary_writer_id += 1
-        g_summary_file[g_summary_writer_id] = event_writer
-        return g_summary_writer_id
-
-    @classmethod
-    def summary_file_get(cls, writer_id=1):
-        ret = None
-        global g_summary_file
-        if writer_id in g_summary_file:
-            ret = g_summary_file.get(writer_id)
-        return ret
-
-
-class WorkerScheduler:
-    """
-    Create worker and schedule data to worker.
-
-    Args:
-        writer_id (int): The index of writer.
-    """
-    def __init__(self, writer_id):
-        # Create the process of write event file
-        self.write_lock = mp.Lock()
-        # Schedule info for all worker
-        # Format: {worker: (step, WorkerStatus)}
-        self.schedule_table = {}
-        # write id
-        self.writer_id = writer_id
-        self.has_graph = False
-
-    def dispatch(self, step, data):
-        """
-        Select schedule strategy and dispatch data.
-
-        Args:
-            step (Number): The number of step index.
-            data (Object): The data of recode for summary.
-
-        Retruns:
-            bool, run successfully or not.
-        """
-        # save the data to global cache , convert the tensor to numpy
-        result, size, data = self._data_convert(data)
-        if result is False:
-            logger.error("The step(%r) summary data(%r) is invalid.", step, size)
-            return False
-
-        data_id = SummaryDataManager.summary_data_save(data)
-        self._start_worker(step, data_id)
-        return True
-
-    def _start_worker(self, step, data_id):
-        """
-        Start worker.
-
-        Args:
-            step (Number): The index of recode.
-            data_id (str): The id of work.
-
-        Return:
-            bool, run successfully or not.
-        """
-        # assign the worker
-        policy = self._make_policy()
-        if policy == ScheduleMethod.TEMP_WORKER:
-            worker = SummaryDataProcess(step, data_id, self.write_lock, self.writer_id)
-            # update the schedule table
-            self.schedule_table[worker] = (step, data_id, WorkerStatus.WORKER_INIT)
-            # start the worker
-            worker.start()
-        else:
-            logger.error("Do not support the other scheduler policy now.")
-
-        # update the scheduler infor
-        self._update_scheduler()
-        return True
-
-    def _data_convert(self, data_list):
-        """Convert the data."""
-        if data_list is None:
-            logger.warning("The step does not have record data.")
-            return False, 0, None
-
-        # convert the summary to numpy
-        size = 0
-        for v_dict in data_list:
-            tag = v_dict["name"]
-            data = v_dict["data"]
-            # confirm the data is valid
-            summary_type, summary_tag = _parse_tag_format(tag)
-            if summary_type == SummaryType.INVALID:
-                logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
-                return False, 0, None
-            if isinstance(data, Tensor):
-                # get the summary type and parse the tag
-                v_dict["name"] = summary_tag
-                v_dict["type"] = summary_type
-                v_dict["data"] = data.asnumpy()
-                size += v_dict["data"].size
-            else:
-                logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
-                return False, 0, None
-
-        return True, size, data_list
-
-    def _update_scheduler(self):
-        """Check the worker status and update schedule table."""
-        workers = list(self.schedule_table.keys())
-        for worker in workers:
-            if not worker.is_alive():
-                # update the table
-                worker.join()
-                del self.schedule_table[worker]
-
-    def close(self):
-        """Confirm all worker is end."""
-        workers = self.schedule_table.keys()
-        for worker in workers:
-            if worker.is_alive():
-                worker.join()
-
-    def _make_policy(self):
-        """Select the schedule strategy by data."""
-        # now only support the temp worker
-        return ScheduleMethod.TEMP_WORKER
-
-
-class SummaryDataProcess(mp.Process):
-    """
-    Process that consume the summarydata.
-
-    Args:
-        step (int): The index of step.
-        data_id (int): The index of summary data.
-        write_lock (Lock): The process lock for writer same file.
-        writer_id (int): The index of writer.
-    """
-    def __init__(self, step, data_id, write_lock, writer_id):
-        super(SummaryDataProcess, self).__init__()
-        self.daemon = True
-        self.writer_id = writer_id
-        self.writer = SummaryDataManager.summary_file_get(self.writer_id)
-        if self.writer is None:
-            logger.error("The writer_id(%r) does not have writer", writer_id)
-        self.step = step
-        self.data_id = data_id
-        self.write_lock = write_lock
-        self.name = "SummaryDataConsumer_" + str(self.step)
-
-    def run(self):
-        """The consumer is process the step data and exit."""
-        # convert the data to event
-        # All exceptions need to be caught and end the queue
-        try:
-            logger.debug("process(%r) process a data(%r)", self.name, self.step)
-            # package the summary event
-            summary_event = package_summary_event(self.data_id, self.step)
-            # send the event to file
-            self._write_summary(summary_event)
-        except Exception as e:
-            logger.error("Summary data mq consumer exception occurred, value = %r", e)
-
-    def _write_summary(self, summary_event):
-        """
-        Write the summary to event file.
-
-        Note:
-            The write record format:
-            1 uint64 : data length.
-            2 uint32 : mask crc value of data length.
-            3 bytes  : data.
-            4 uint32 : mask crc value of data.
-
-        Args:
-            summary_event (Event): The summary event of proto.
-        """
-        event_str = summary_event.SerializeToString()
-        self.write_lock.acquire()
-        self.writer.write_event_to_file(event_str)
-        self.writer.flush()
-        self.write_lock.release()
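For contrast with the new `EventWriter` above: the deleted scheduler spawned one short-lived `SummaryDataProcess` per recorded step and serialized file access with a shared `mp.Lock`. A miniature of that per-step pattern (output file name hypothetical), which the PR retires in favor of one long-lived writer process:

import multiprocessing as mp

def _write_step(lock, step):
    with lock:  # the old design guarded the shared event file with a lock
        with open('events.out', 'a') as f:
            f.write('step %d\n' % step)

if __name__ == '__main__':
    lock = mp.Lock()
    workers = [mp.Process(target=_write_step, args=(lock, s)) for s in range(3)]
    for w in workers:
        w.start()
    for w in workers:  # mirrors WorkerScheduler.close() joining live workers
        w.join()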
@@ -14,17 +14,22 @@
 # ============================================================================
 """Record the summary event."""
 import os
+import re
 import threading
 from mindspore import log as logger
-from ._summary_scheduler import WorkerScheduler, SummaryDataManager
-from ._summary_adapter import get_event_file_name, package_graph_event
-from ._event_writer import EventRecord
-from .._utils import _make_directory
+from ..._c_expression import Tensor
 from ..._checkparam import _check_str_by_regular
+from .._utils import _make_directory
+from ._event_writer import EventWriter
+from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event

+# for the moment, this lock is for caution's sake,
+# there are actually no any concurrencies happening.
+_summary_lock = threading.Lock()
 # cache the summary data
 _summary_tensor_cache = {}
-_summary_lock = threading.Lock()


 def _cache_summary_tensor_data(summary):
@@ -34,14 +39,18 @@ def _cache_summary_tensor_data(summary):
     Args:
         summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].
     """
-    _summary_lock.acquire()
-    if "SummaryRecord" in _summary_tensor_cache:
-        for record in summary:
-            _summary_tensor_cache["SummaryRecord"].append(record)
-    else:
-        _summary_tensor_cache["SummaryRecord"] = summary
-    _summary_lock.release()
-    return True
+    with _summary_lock:
+        for item in summary:
+            _summary_tensor_cache[item['name']] = item['data']
+        return True
+
+
+def _get_summary_tensor_data():
+    global _summary_tensor_cache
+    with _summary_lock:
+        data = _summary_tensor_cache
+        _summary_tensor_cache = {}
+        return data
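Note the semantic change in the cache: entries are now keyed by tag, so recording the same tag twice before a drain keeps only the latest value (the old code appended everything to a single list), and `_get_summary_tensor_data` swaps the whole dict out under the lock so the caller gets a consistent snapshot. A runnable sketch with hypothetical tags:

import threading

_summary_lock = threading.Lock()
_summary_tensor_cache = {}

def _cache_summary_tensor_data(summary):
    with _summary_lock:
        for item in summary:
            _summary_tensor_cache[item['name']] = item['data']
        return True

def _get_summary_tensor_data():
    global _summary_tensor_cache
    with _summary_lock:
        data = _summary_tensor_cache
        _summary_tensor_cache = {}  # reset atomically with the read
        return data

_cache_summary_tensor_data([{'name': 'loss[:Scalar]', 'data': 0.9}])
_cache_summary_tensor_data([{'name': 'loss[:Scalar]', 'data': 0.7}])
print(_get_summary_tensor_data())  # {'loss[:Scalar]': 0.7} - last write wins
print(_get_summary_tensor_data())  # {} - cache was emptied by the drain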
 class SummaryRecord:
@@ -53,7 +62,7 @@ class SummaryRecord:
     It writes the event log to a file by executing the record method. In addition,
     if the SummaryRecord object is created and the summary operator is used in the network,
     even if the record method is not called, the event in the cache will be written to the
-    file at the end of execution or when the summary is closed.
+    file at the end of execution. Make sure to close the SummaryRecord object at the end.

     Args:
         log_dir (str): The log_dir is a directory location to save the summary.
@@ -68,9 +77,10 @@ class SummaryRecord:
         RuntimeError: If the log_dir can not be resolved to a canonicalized absolute pathname.

     Examples:
-        >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
-        >>>                                file_prefix="xxx_", file_suffix="_yyy")
+        >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
+        >>>     pass
     """
+
     def __init__(self,
                  log_dir,
                  queue_max_size=0,
@@ -101,26 +111,36 @@
         self.prefix = file_prefix
         self.suffix = file_suffix
+        self.network = network
+        self.has_graph = False
+        self._closed = False

         # create the summary writer file
         self.event_file_name = get_event_file_name(self.prefix, self.suffix)
-        if self.log_path[-1:] == '/':
-            self.full_file_name = self.log_path + self.event_file_name
-        else:
-            self.full_file_name = self.log_path + '/' + self.event_file_name
-
         try:
-            self.full_file_name = os.path.realpath(self.full_file_name)
+            self.full_file_name = os.path.join(self.log_path, self.event_file_name)
         except Exception as ex:
             raise RuntimeError(ex)
-        self.event_writer = EventRecord(self.full_file_name, self.flush_time)
-        self.writer_id = SummaryDataManager.summary_file_set(self.event_writer)
-        self.worker_scheduler = WorkerScheduler(self.writer_id)
-        self.step = 0
-        self._closed = False
-        self.network = network
-        self.has_graph = False
+        self._event_writer = None
+
+    def _init_event_writer(self):
+        """Init event writer and write metadata."""
+        event_writer = EventWriter(self.full_file_name, self.flush_time)
+        event_writer.write(package_init_event().SerializeToString())
+        return event_writer
+
+    def __enter__(self):
+        """Enter the context manager."""
+        if not self._event_writer:
+            self._event_writer = self._init_event_writer()
+        if self._closed:
+            raise ValueError('SummaryRecord has been closed.')
+        return self
+
+    def __exit__(self, extype, exvalue, traceback):
+        """Exit the context manager."""
+        self.close()
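The context-manager protocol added here is the heart of the PR: `__enter__` lazily creates the event writer, `__exit__` delegates to `close()`, and `close()` is guarded so it runs at most once. A stripped-down, runnable sketch of the same lifecycle, with a plain file standing in for `EventWriter` (file name hypothetical):

class Recorder:
    def __init__(self):
        self._event_writer = None
        self._closed = False

    def __enter__(self):
        if not self._event_writer:
            self._event_writer = open('events.out', 'w')  # stands in for EventWriter
        if self._closed:
            raise ValueError('Recorder has been closed.')
        return self

    def __exit__(self, extype, exvalue, traceback):
        self.close()

    def close(self):
        if not self._closed and self._event_writer:
            self._event_writer.close()
            self._closed = True


with Recorder() as rec:
    pass  # the writer is created on entry and closed automatically on exit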
     def record(self, step, train_network=None):
         """
@@ -131,9 +151,8 @@ class SummaryRecord:
             train_network (Cell): The network that called the callback.

         Examples:
-            >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
-            >>>                                file_prefix="xxx_", file_suffix="_yyy")
-            >>> summary_record.record(step=2)
+            >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
+            >>>     summary_record.record(step=2)

         Returns:
             bool, whether the record process is successful or not.
@@ -145,42 +164,37 @@ class SummaryRecord:
         if not isinstance(step, int) or isinstance(step, bool):
             raise ValueError("`step` should be int")
         # Set the current summary of train step
-        self.step = step
+        if not self._event_writer:
+            self._event_writer = self._init_event_writer()
+            logger.warning('SummaryRecord should be used as context manager for a with statement.')

-        if self.network is not None and self.has_graph is False:
+        if self.network is not None and not self.has_graph:
             graph_proto = self.network.get_func_graph_proto()
             if graph_proto is None and train_network is not None:
                 graph_proto = train_network.get_func_graph_proto()
             if graph_proto is None:
                 logger.error("Failed to get proto for graph")
             else:
-                self.event_writer.write_event_to_file(
-                    package_graph_event(graph_proto).SerializeToString())
-                self.event_writer.flush()
+                self._event_writer.write(package_graph_event(graph_proto).SerializeToString())
                 self.has_graph = True
-                data = _summary_tensor_cache.get("SummaryRecord")
-                if data is None:
+                if not _summary_tensor_cache:
                     return True

-        data = _summary_tensor_cache.get("SummaryRecord")
-        if data is None:
-            logger.error("The step(%r) does not have record data.", self.step)
+        data = _get_summary_tensor_data()
+        if not data:
+            logger.error("The step(%r) does not have record data.", step)
             return False
         if self.queue_max_size > 0 and len(data) > self.queue_max_size:
             logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data),
                          self.queue_max_size)

-        # clean the data of cache
-        del _summary_tensor_cache["SummaryRecord"]
-
         # process the data
-        self.worker_scheduler.dispatch(self.step, data)
-        # count & flush
-        self.event_writer.count_event()
-        self.event_writer.flush_cycle()
-
-        logger.debug("Send the summary data to scheduler for saving, step = %d", self.step)
+        result = self._data_convert(data)
+        if not result:
+            logger.error("The step(%r) summary data is invalid.", step)
+            return False
+        self._event_writer.write((result, step))
+        logger.debug("Send the summary data to scheduler for saving, step = %d", step)
         return True

     @property
@@ -189,14 +203,13 @@ class SummaryRecord:
         Get the full path of the log file.

         Examples:
-            >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
-            >>>                                file_prefix="xxx_", file_suffix="_yyy")
-            >>> print(summary_record.log_dir)
+            >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
+            >>>     print(summary_record.log_dir)

         Returns:
             String, the full path of log file.
         """
-        return self.event_writer.full_file_name
+        return self.full_file_name

     def flush(self):
         """
@@ -205,39 +218,64 @@ class SummaryRecord:
         Call it to make sure that all pending events have been written to disk.

         Examples:
-            >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
-            >>>                                file_prefix="xxx_", file_suffix="_yyy")
-            >>> summary_record.flush()
+            >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
+            >>>     summary_record.flush()
         """
         if self._closed:
             logger.error("The record writer is closed and can not flush.")
-        else:
-            self.event_writer.flush()
+        elif self._event_writer:
+            self._event_writer.flush()

     def close(self):
         """
-        Flush all events and close summary records.
+        Flush all events and close summary records. Please use with statement to autoclose.

         Examples:
-            >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
-            >>>                                file_prefix="xxx_", file_suffix="_yyy")
-            >>> summary_record.close()
+            >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
+            >>>     pass  # summary_record autoclosed
         """
-        if not self._closed:
-            self._check_data_before_close()
-            self.worker_scheduler.close()
+        if not self._closed and self._event_writer:
             # event writer flush and close
-            self.event_writer.close()
+            self._event_writer.close()
             self._closed = True

-    def __del__(self):
-        """Process exit is called."""
-        if hasattr(self, "worker_scheduler"):
-            if self.worker_scheduler:
-                self.close()
-
-    def _check_data_before_close(self):
-        "Check whether there is any data in the cache, and if so, call record"
-        data = _summary_tensor_cache.get("SummaryRecord")
-        if data is not None:
-            self.record(self.step)
+    def __del__(self) -> None:
+        self.close()
+
+    def _data_convert(self, summary):
+        """Convert the data."""
+        # convert the summary to numpy
+        result = []
+        for name, data in summary.items():
+            # confirm the data is valid
+            summary_tag, summary_type = SummaryRecord._parse_from(name)
+            if summary_tag is None:
+                logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
+                return None
+            if isinstance(data, Tensor):
+                result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type})
+            else:
+                logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
+                return None
+
+        return result
+
+    @staticmethod
+    def _parse_from(name: str = None):
+        """
+        Parse the tag and type from name.
+
+        Args:
+            name (str): Format: TAG[:TYPE].
+
+        Returns:
+            Tuple, (summary_tag, summary_type).
+        """
+        if name is None:
+            logger.error("The name is None")
+            return None, None
+        match = re.match(r'(.+)\[:(.+)\]', name)
+        if match:
+            return match.groups()
+        logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name)
+        return None, None
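The new `_parse_from` collapses the old scheduler's manual `rfind`/slicing into one regular expression. A quick runnable check of how it splits `TAG[:TYPE]` names (sample names hypothetical):

import re

print(re.match(r'(.+)\[:(.+)\]', 'loss[:Scalar]').groups())
# ('loss', 'Scalar')
print(re.match(r'(.+)\[:(.+)\]', 'conv1.weight[:Histogram]').groups())
# ('conv1.weight', 'Histogram')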
@@ -53,14 +53,13 @@ def me_train_tensor(net, input_np, label_np, epoch_size=2):
     _network = wrap.WithLossCell(net, loss)
     _train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt))
     _train_net.set_train()
-    summary_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net)
-    for epoch in range(0, epoch_size):
-        print(f"epoch %d" % (epoch))
-        output = _train_net(Tensor(input_np), Tensor(label_np))
-        summary_writer.record(i)
-        print("********output***********")
-        print(output.asnumpy())
-    summary_writer.close()
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net) as summary_writer:
+        for epoch in range(0, epoch_size):
+            print(f"epoch %d" % (epoch))
+            output = _train_net(Tensor(input_np), Tensor(label_np))
+            summary_writer.record(i)
+            print("********output***********")
+            print(output.asnumpy())


 def me_infer_tensor(net, input_np):
@@ -91,15 +91,14 @@ def train_summary_record_scalar_for_1(test_writer, steps, fwd_x, fwd_y):


 def me_scalar_summary(steps, tag=None, value=None):
-    test_writer = SummaryRecord(SUMMARY_DIR_ME_TEMP)
-
-    x = Tensor(np.array([1.1]).astype(np.float32))
-    y = Tensor(np.array([1.2]).astype(np.float32))
-
-    out_me_dict = train_summary_record_scalar_for_1(test_writer, steps, x, y)
-
-    test_writer.close()
-    return out_me_dict
+    with SummaryRecord(SUMMARY_DIR_ME_TEMP) as test_writer:
+        x = Tensor(np.array([1.1]).astype(np.float32))
+        y = Tensor(np.array([1.2]).astype(np.float32))
+
+        out_me_dict = train_summary_record_scalar_for_1(test_writer, steps, x, y)
+
+        return out_me_dict


 @pytest.mark.level0
@@ -106,18 +106,17 @@ def test_graph_summary_sample():
     optim = Momentum(net.trainable_params(), 0.1, 0.9)
     context.set_context(mode=context.GRAPH_MODE)
     model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network)
-    model.train(2, dataset)
-    # step 2: create the Event
-    for i in range(1, 5):
-        test_writer.record(i)
-
-    # step 3: send the event to mq
-
-    # step 4: accept the event and write the file
-    test_writer.close()
-
-    log.debug("finished test_graph_summary_sample")
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
+        model.train(2, dataset)
+        # step 2: create the Event
+        for i in range(1, 5):
+            test_writer.record(i)
+
+        # step 3: send the event to mq
+
+        # step 4: accept the event and write the file
+
+        log.debug("finished test_graph_summary_sample")


 def test_graph_summary_callback():
@@ -127,9 +126,9 @@ def test_graph_summary_callback():
     optim = Momentum(net.trainable_params(), 0.1, 0.9)
     context.set_context(mode=context.GRAPH_MODE)
     model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network)
-    summary_cb = SummaryStep(test_writer, 1)
-    model.train(2, dataset, callbacks=summary_cb)
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
+        summary_cb = SummaryStep(test_writer, 1)
+        model.train(2, dataset, callbacks=summary_cb)


 def test_graph_summary_callback2():
@@ -139,6 +138,6 @@ def test_graph_summary_callback2():
     optim = Momentum(net.trainable_params(), 0.1, 0.9)
     context.set_context(mode=context.GRAPH_MODE)
     model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net)
-    summary_cb = SummaryStep(test_writer, 1)
-    model.train(2, dataset, callbacks=summary_cb)
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net) as test_writer:
+        summary_cb = SummaryStep(test_writer, 1)
+        model.train(2, dataset, callbacks=summary_cb)
| @@ -52,12 +52,11 @@ def _wrap_test_data(input_data: Tensor): | |||||
| def test_histogram_summary(): | def test_histogram_summary(): | ||||
| """Test histogram summary.""" | """Test histogram summary.""" | ||||
| with tempfile.TemporaryDirectory() as tmp_dir: | with tempfile.TemporaryDirectory() as tmp_dir: | ||||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||||
| test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]])) | |||||
| _cache_summary_tensor_data(test_data) | |||||
| test_writer.record(step=1) | |||||
| test_writer.close() | |||||
| test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]])) | |||||
| _cache_summary_tensor_data(test_data) | |||||
| test_writer.record(step=1) | |||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | reader = SummaryReader(file_name) | ||||
| @@ -68,20 +67,18 @@ def test_histogram_summary(): | |||||
| def test_histogram_multi_summary(): | def test_histogram_multi_summary(): | ||||
| """Test histogram multiple step.""" | """Test histogram multiple step.""" | ||||
| with tempfile.TemporaryDirectory() as tmp_dir: | with tempfile.TemporaryDirectory() as tmp_dir: | ||||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||||
| rng = np.random.RandomState(10) | |||||
| size = 50 | |||||
| num_step = 5 | |||||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||||
| for i in range(num_step): | |||||
| arr = rng.normal(size=size) | |||||
| rng = np.random.RandomState(10) | |||||
| size = 50 | |||||
| num_step = 5 | |||||
| test_data = _wrap_test_data(Tensor(arr)) | |||||
| _cache_summary_tensor_data(test_data) | |||||
| test_writer.record(step=i) | |||||
| for i in range(num_step): | |||||
| arr = rng.normal(size=size) | |||||
| test_writer.close() | |||||
| test_data = _wrap_test_data(Tensor(arr)) | |||||
| _cache_summary_tensor_data(test_data) | |||||
| test_writer.record(step=i) | |||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | reader = SummaryReader(file_name) | ||||
| @@ -93,12 +90,11 @@ def test_histogram_multi_summary(): | |||||
| def test_histogram_summary_scalar_tensor(): | def test_histogram_summary_scalar_tensor(): | ||||
| """Test histogram summary, input is a scalar tensor.""" | """Test histogram summary, input is a scalar tensor.""" | ||||
| with tempfile.TemporaryDirectory() as tmp_dir: | with tempfile.TemporaryDirectory() as tmp_dir: | ||||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||||
| test_data = _wrap_test_data(Tensor(1)) | |||||
| _cache_summary_tensor_data(test_data) | |||||
| test_writer.record(step=1) | |||||
| test_writer.close() | |||||
| test_data = _wrap_test_data(Tensor(1)) | |||||
| _cache_summary_tensor_data(test_data) | |||||
| test_writer.record(step=1) | |||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | reader = SummaryReader(file_name) | ||||
| @@ -109,12 +105,11 @@ def test_histogram_summary_scalar_tensor(): | |||||
| def test_histogram_summary_empty_tensor(): | def test_histogram_summary_empty_tensor(): | ||||
| """Test histogram summary, input is an empty tensor.""" | """Test histogram summary, input is an empty tensor.""" | ||||
| with tempfile.TemporaryDirectory() as tmp_dir: | with tempfile.TemporaryDirectory() as tmp_dir: | ||||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||||
| test_data = _wrap_test_data(Tensor([])) | |||||
| _cache_summary_tensor_data(test_data) | |||||
| test_writer.record(step=1) | |||||
| test_writer.close() | |||||
| test_data = _wrap_test_data(Tensor([])) | |||||
| _cache_summary_tensor_data(test_data) | |||||
| test_writer.record(step=1) | |||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | reader = SummaryReader(file_name) | ||||
| @@ -125,15 +120,14 @@ def test_histogram_summary_empty_tensor(): | |||||
| def test_histogram_summary_same_value(): | def test_histogram_summary_same_value(): | ||||
| """Test histogram summary, input is an ones tensor.""" | """Test histogram summary, input is an ones tensor.""" | ||||
| with tempfile.TemporaryDirectory() as tmp_dir: | with tempfile.TemporaryDirectory() as tmp_dir: | ||||
| test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") | |||||
| with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: | |||||
| dim1 = 100 | |||||
| dim2 = 100 | |||||
| dim1 = 100 | |||||
| dim2 = 100 | |||||
| test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2]))) | |||||
| _cache_summary_tensor_data(test_data) | |||||
| test_writer.record(step=1) | |||||
| test_writer.close() | |||||
| test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2]))) | |||||
| _cache_summary_tensor_data(test_data) | |||||
| test_writer.record(step=1) | |||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | reader = SummaryReader(file_name) | ||||
@@ -146,15 +140,14 @@ def test_histogram_summary_same_value():
 def test_histogram_summary_high_dims():
     """Test histogram summary, input is a 4-dimension tensor."""
     with tempfile.TemporaryDirectory() as tmp_dir:
-        test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
-        dim = 10
-        rng = np.random.RandomState(0)
-        tensor_data = rng.normal(size=[dim, dim, dim, dim])
-        test_data = _wrap_test_data(Tensor(tensor_data))
-        _cache_summary_tensor_data(test_data)
-        test_writer.record(step=1)
-        test_writer.close()
+        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
+            dim = 10
+            rng = np.random.RandomState(0)
+            tensor_data = rng.normal(size=[dim, dim, dim, dim])
+            test_data = _wrap_test_data(Tensor(tensor_data))
+            _cache_summary_tensor_data(test_data)
+            test_writer.record(step=1)
         file_name = os.path.join(tmp_dir, test_writer.event_file_name)
         reader = SummaryReader(file_name)
@@ -167,20 +160,19 @@ def test_histogram_summary_high_dims():
 def test_histogram_summary_nan_inf():
     """Test histogram summary, input tensor has nan."""
     with tempfile.TemporaryDirectory() as tmp_dir:
-        test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
-        dim1 = 100
-        dim2 = 100
-        arr = np.ones([dim1, dim2])
-        arr[0][0] = np.nan
-        arr[0][1] = np.inf
-        arr[0][2] = -np.inf
-        test_data = _wrap_test_data(Tensor(arr))
-        _cache_summary_tensor_data(test_data)
-        test_writer.record(step=1)
-        test_writer.close()
+        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
+            dim1 = 100
+            dim2 = 100
+            arr = np.ones([dim1, dim2])
+            arr[0][0] = np.nan
+            arr[0][1] = np.inf
+            arr[0][2] = -np.inf
+            test_data = _wrap_test_data(Tensor(arr))
+            _cache_summary_tensor_data(test_data)
+            test_writer.record(step=1)
         file_name = os.path.join(tmp_dir, test_writer.event_file_name)
         reader = SummaryReader(file_name)
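The nan/inf cases pin down a real requirement on the histogram packer: non-finite values must be tallied without poisoning the bucket range, since a naive `min`/`max` over the raw array would itself be nan or inf. A minimal sketch of that separation, assuming numpy-side preprocessing rather than the PR's actual implementation:

```python
import numpy as np

def split_finite(arr):
    """Separate bucketable finite values from nan/inf tallies."""
    flat = np.asarray(arr).ravel()
    finite = flat[np.isfinite(flat)]            # safe to min/max and bucket
    nan_count = int(np.isnan(flat).sum())       # reported alongside the buckets
    pos_inf_count = int(np.isposinf(flat).sum())
    neg_inf_count = int(np.isneginf(flat).sum())
    return finite, nan_count, pos_inf_count, neg_inf_count
```

When nothing is finite, as in `test_histogram_summary_all_nan_inf` below, `finite` comes back empty and the packer has to fall back to some default bucket range instead of crashing.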
@@ -193,12 +185,11 @@ def test_histogram_summary_nan_inf():
 def test_histogram_summary_all_nan_inf():
     """Test histogram summary, input tensor has no valid number."""
     with tempfile.TemporaryDirectory() as tmp_dir:
-        test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
-        test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf])))
-        _cache_summary_tensor_data(test_data)
-        test_writer.record(step=1)
-        test_writer.close()
+        with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
+            test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf])))
+            _cache_summary_tensor_data(test_data)
+            test_writer.record(step=1)
         file_name = os.path.join(tmp_dir, test_writer.event_file_name)
         reader = SummaryReader(file_name)
@@ -74,23 +74,21 @@ def test_image_summary_sample():
     """ test_image_summary_sample """
     log.debug("begin test_image_summary_sample")
     # step 0: create the thread
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE")
-    # step 1: create the test data for summary
-    # step 2: create the Event
-    for i in range(1, 5):
-        test_data = get_test_data(i)
-        _cache_summary_tensor_data(test_data)
-        test_writer.record(i)
-        test_writer.flush()
-    # step 3: send the event to mq
-    # step 4: accept the event and write the file
-    test_writer.close()
-    log.debug("finished test_image_summary_sample")
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
+        # step 1: create the test data for summary
+        # step 2: create the Event
+        for i in range(1, 5):
+            test_data = get_test_data(i)
+            _cache_summary_tensor_data(test_data)
+            test_writer.record(i)
+            test_writer.flush()
+        # step 3: send the event to mq
+        # step 4: accept the event and write the file
+    log.debug("finished test_image_summary_sample")


 class Net(nn.Cell):
@@ -174,23 +172,21 @@ def test_image_summary_train():
     log.debug("begin test_image_summary_sample")
     # step 0: create the thread
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE")
-    # step 1: create the test data for summary
-    model = get_model()
-    fn = ImageSummaryCallback(test_writer)
-    summary_recode = SummaryStep(fn, 1)
-    model.train(2, dataset, callbacks=summary_recode)
-    # step 2: create the Event
-    # step 3: send the event to mq
-    # step 4: accept the event and write the file
-    test_writer.close()
-    log.debug("finished test_image_summary_sample")
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
+        # step 1: create the test data for summary
+        # step 2: create the Event
+        model = get_model()
+        fn = ImageSummaryCallback(test_writer)
+        summary_recode = SummaryStep(fn, 1)
+        model.train(2, dataset, callbacks=summary_recode)
+        # step 3: send the event to mq
+        # step 4: accept the event and write the file
+    log.debug("finished test_image_summary_sample")


 def test_image_summary_data():
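`test_image_summary_train` hands the writer to `model.train` via `ImageSummaryCallback` wrapped in a `SummaryStep`. The callback is defined elsewhere in the test sources; from its use here it is roughly the following shape (a sketch, and the `train_network` parameter is an assumption about `record`'s signature):

```python
class ImageSummaryCallback:
    """Sketch of the test helper: forward each wrapped step to the shared writer."""

    def __init__(self, summary_record):
        self._summary_record = summary_record

    def record(self, step, train_network=None):
        self._summary_record.record(step, train_network)
        self._summary_record.flush()
```

Because the writer is now created by the `with` statement, the callback only borrows it; closing it remains the context manager's job.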
@@ -209,18 +205,12 @@ def test_image_summary_data():
     log.debug("begin test_image_summary_sample")
     # step 0: create the thread
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE")
-    # step 1: create the test data for summary
-    # step 2: create the Event
-    _cache_summary_tensor_data(test_data_list)
-    test_writer.record(1)
-    test_writer.flush()
-    # step 3: send the event to mq
-    # step 4: accept the event and write the file
-    test_writer.close()
-    log.debug("finished test_image_summary_sample")
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
+        # step 1: create the test data for summary
+        # step 2: create the Event
+        _cache_summary_tensor_data(test_data_list)
+        test_writer.record(1)
+    log.debug("finished test_image_summary_sample")
@@ -65,22 +65,21 @@ def test_scalar_summary_sample():
     """ test_scalar_summary_sample """
     log.debug("begin test_scalar_summary_sample")
     # step 0: create the thread
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR")
-    # step 1: create the test data for summary
-    # step 2: create the Event
-    for i in range(1, 500):
-        test_data = get_test_data(i)
-        _cache_summary_tensor_data(test_data)
-        test_writer.record(i)
-    # step 3: send the event to mq
-    # step 4: accept the event and write the file
-    test_writer.close()
-    log.debug("finished test_scalar_summary_sample")
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
+        # step 1: create the test data for summary
+        # step 2: create the Event
+        for i in range(1, 500):
+            test_data = get_test_data(i)
+            _cache_summary_tensor_data(test_data)
+            test_writer.record(i)
+        # step 3: send the event to mq
+        # step 4: accept the event and write the file
+    log.debug("finished test_scalar_summary_sample")


 def get_test_data_shape_1(step):
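These scalar tests move data through the module-level cache: `get_test_data(i)` builds tagged tensors and `_cache_summary_tensor_data` stages them for the next `test_writer.record(i)`. From that usage, the cache appears to accept a list of name/data dicts; a hypothetical generator in that shape (the dict keys and the `[:Scalar]` tag suffix are assumptions drawn from this test family, not confirmed API):

```python
def get_test_data(step):
    # Hypothetical helper: each entry pairs a summary tag with a scalar tensor.
    # The "[:Scalar]" suffix is assumed to mark the summary type of the tag.
    return [
        {"name": "x1[:Scalar]", "data": Tensor(np.array([step + 1]).astype(np.float32))},
        {"name": "x2[:Scalar]", "data": Tensor(np.array([step * 2]).astype(np.float32))},
    ]
```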
@@ -110,22 +109,21 @@ def test_scalar_summary_sample_with_shape_1():
     """ test_scalar_summary_sample_with_shape_1 """
     log.debug("begin test_scalar_summary_sample_with_shape_1")
     # step 0: create the thread
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR")
-    # step 1: create the test data for summary
-    # step 2: create the Event
-    for i in range(1, 100):
-        test_data = get_test_data_shape_1(i)
-        _cache_summary_tensor_data(test_data)
-        test_writer.record(i)
-    # step 3: send the event to mq
-    # step 4: accept the event and write the file
-    test_writer.close()
-    log.debug("finished test_scalar_summary_sample")
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
+        # step 1: create the test data for summary
+        # step 2: create the Event
+        for i in range(1, 100):
+            test_data = get_test_data_shape_1(i)
+            _cache_summary_tensor_data(test_data)
+            test_writer.record(i)
+        # step 3: send the event to mq
+        # step 4: accept the event and write the file
+    log.debug("finished test_scalar_summary_sample")


 # Test: test with ge
@@ -152,26 +150,24 @@ def test_scalar_summary_with_ge():
     log.debug("begin test_scalar_summary_with_ge")
     # step 0: create the thread
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR")
-    # step 1: create the network for summary
-    x = Tensor(np.array([1.1]).astype(np.float32))
-    y = Tensor(np.array([1.2]).astype(np.float32))
-    net = SummaryDemo()
-    net.set_train()
-    # step 2: create the Event
-    steps = 100
-    for i in range(1, steps):
-        x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
-        y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
-        net(x, y)
-        test_writer.record(i)
-    # step 3: close the writer
-    test_writer.close()
-    log.debug("finished test_scalar_summary_with_ge")
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
+        # step 1: create the network for summary
+        x = Tensor(np.array([1.1]).astype(np.float32))
+        y = Tensor(np.array([1.2]).astype(np.float32))
+        net = SummaryDemo()
+        net.set_train()
+        # step 2: create the Event
+        steps = 100
+        for i in range(1, steps):
+            x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
+            y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
+            net(x, y)
+            test_writer.record(i)
+    log.debug("finished test_scalar_summary_with_ge")


 # test the problem of two consecutive use cases going wrong
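The GE variants record from inside the graph: `SummaryDemo` (defined earlier in the test file) presumably wraps the `ScalarSummary` primitive so each forward pass emits tagged values, which `test_writer.record(i)` then drains. A sketch under that assumption; the op choice and tag names are illustrative, not the file's exact cell:

```python
class SummaryDemo(nn.Cell):
    """Sketch: emit scalar summaries for both inputs, then return their sum."""

    def __init__(self):
        super(SummaryDemo, self).__init__()
        self.summary = P.ScalarSummary()
        self.add = P.TensorAdd()

    def construct(self, x, y):
        self.summary("x", x)
        self.summary("y", y)
        return self.add(x, y)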
@@ -180,55 +176,52 @@ def test_scalar_summary_with_ge_2():
     log.debug("begin test_scalar_summary_with_ge_2")
     # step 0: create the thread
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR")
-    # step 1: create the network for summary
-    x = Tensor(np.array([1.1]).astype(np.float32))
-    y = Tensor(np.array([1.2]).astype(np.float32))
-    net = SummaryDemo()
-    net.set_train()
-    # step 2: create the Event
-    steps = 100
-    for i in range(1, steps):
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
+        # step 1: create the network for summary
         x = Tensor(np.array([1.1]).astype(np.float32))
         y = Tensor(np.array([1.2]).astype(np.float32))
-        net(x, y)
-        test_writer.record(i)
-    # step 3: close the writer
-    test_writer.close()
+        net = SummaryDemo()
+        net.set_train()
+        # step 2: create the Event
+        steps = 100
+        for i in range(1, steps):
+            x = Tensor(np.array([1.1]).astype(np.float32))
+            y = Tensor(np.array([1.2]).astype(np.float32))
+            net(x, y)
+            test_writer.record(i)
     log.debug("finished test_scalar_summary_with_ge_2")


-def test_validate():
-    sr = SummaryRecord(SUMMARY_DIR)
-    with pytest.raises(ValueError):
-        SummaryStep(sr, 0)
-    with pytest.raises(ValueError):
-        SummaryStep(sr, -1)
-    with pytest.raises(ValueError):
-        SummaryStep(sr, 1.2)
-    with pytest.raises(ValueError):
-        SummaryStep(sr, True)
-    with pytest.raises(ValueError):
-        SummaryStep(sr, "str")
-    sr.record(1)
-    with pytest.raises(ValueError):
-        sr.record(False)
-    with pytest.raises(ValueError):
-        sr.record(2.0)
-    with pytest.raises(ValueError):
-        sr.record((1, 3))
-    with pytest.raises(ValueError):
-        sr.record([2, 3])
-    with pytest.raises(ValueError):
-        sr.record("str")
-    with pytest.raises(ValueError):
-        sr.record(sr)
-    sr.close()
+def test_validate():
+    with SummaryRecord(SUMMARY_DIR) as sr:
+        with pytest.raises(ValueError):
+            SummaryStep(sr, 0)
+        with pytest.raises(ValueError):
+            SummaryStep(sr, -1)
+        with pytest.raises(ValueError):
+            SummaryStep(sr, 1.2)
+        with pytest.raises(ValueError):
+            SummaryStep(sr, True)
+        with pytest.raises(ValueError):
+            SummaryStep(sr, "str")
+        sr.record(1)
+        with pytest.raises(ValueError):
+            sr.record(False)
+        with pytest.raises(ValueError):
+            sr.record(2.0)
+        with pytest.raises(ValueError):
+            sr.record((1, 3))
+        with pytest.raises(ValueError):
+            sr.record([2, 3])
+        with pytest.raises(ValueError):
+            sr.record("str")
+        with pytest.raises(ValueError):
+            sr.record(sr)

     SummaryStep(sr, 1)
     with pytest.raises(ValueError):
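`test_validate` is stricter than a naive type check: in Python, `bool` is a subclass of `int`, so `isinstance(True, int)` is true, yet `SummaryStep(sr, True)` and `sr.record(False)` are both expected to raise. Whatever validator backs this presumably rejects bools explicitly and requires a positive value; a sketch of such a check (not the PR's actual code):

```python
def _check_positive_int(value, arg_name):
    # Reject bool first: isinstance(True, int) is True in Python,
    # but the test expects SummaryStep(sr, True) / sr.record(False) to raise.
    if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
        raise ValueError("`%s` should be a positive int, but got %r" % (arg_name, value))
    return value
```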
@@ -126,23 +126,21 @@ class HistogramSummaryNet(nn.Cell):
 def run_case(net):
     """ run_case """
     # step 0: create the thread
-    test_writer = SummaryRecord(SUMMARY_DIR)
-    # step 1: create the network for summary
-    x = Tensor(np.array([1.1]).astype(np.float32))
-    y = Tensor(np.array([1.2]).astype(np.float32))
-    net.set_train()
-    # step 2: create the Event
-    steps = 100
-    for i in range(1, steps):
-        x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
-        y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
-        net(x, y)
-        test_writer.record(i)
-    # step 3: close the writer
-    test_writer.close()
+    with SummaryRecord(SUMMARY_DIR) as test_writer:
+        # step 1: create the network for summary
+        x = Tensor(np.array([1.1]).astype(np.float32))
+        y = Tensor(np.array([1.2]).astype(np.float32))
+        net.set_train()
+        # step 2: create the Event
+        steps = 100
+        for i in range(1, steps):
+            x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
+            y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
+            net(x, y)
+            test_writer.record(i)


 # Test 1: use the repeat tag
@@ -80,19 +80,18 @@ def test_tensor_summary_sample():
     """ test_tensor_summary_sample """
     log.debug("begin test_tensor_summary_sample")
     # step 0: create the thread
-    test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_TENSOR")
-    # step 1: create the Event
-    for i in range(1, 100):
-        test_data = get_test_data(i)
-        _cache_summary_tensor_data(test_data)
-        test_writer.record(i)
-    # step 2: accept the event and write the file
-    test_writer.close()
-    log.debug("finished test_tensor_summary_sample")
+    with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_TENSOR") as test_writer:
+        # step 1: create the Event
+        for i in range(1, 100):
+            test_data = get_test_data(i)
+            _cache_summary_tensor_data(test_data)
+            test_writer.record(i)
+        # step 2: accept the event and write the file
+    log.debug("finished test_tensor_summary_sample")


 def get_test_data_check(step):
@@ -131,23 +130,20 @@ def test_tensor_summary_with_ge():
     log.debug("begin test_tensor_summary_with_ge")
     # step 0: create the thread
-    test_writer = SummaryRecord(SUMMARY_DIR)
-    # step 1: create the network for summary
-    x = Tensor(np.array([1.1]).astype(np.float32))
-    y = Tensor(np.array([1.2]).astype(np.float32))
-    net = SummaryDemo()
-    net.set_train()
-    # step 2: create the Event
-    steps = 100
-    for i in range(1, steps):
-        x = Tensor(np.array([[i], [i]]).astype(np.float32))
-        y = Tensor(np.array([[i + 1], [i + 1]]).astype(np.float32))
-        net(x, y)
-        test_writer.record(i)
-    # step 3: close the writer
-    test_writer.close()
-    log.debug("finished test_tensor_summary_with_ge")
+    with SummaryRecord(SUMMARY_DIR) as test_writer:
+        # step 1: create the network for summary
+        x = Tensor(np.array([1.1]).astype(np.float32))
+        y = Tensor(np.array([1.2]).astype(np.float32))
+        net = SummaryDemo()
+        net.set_train()
+        # step 2: create the Event
+        steps = 100
+        for i in range(1, steps):
+            x = Tensor(np.array([[i], [i]]).astype(np.float32))
+            y = Tensor(np.array([[i + 1], [i + 1]]).astype(np.float32))
+            net(x, y)
+            test_writer.record(i)
+    log.debug("finished test_tensor_summary_with_ge")