From 1b3a0277b76be67e0d064a90ea653efd2490c686 Mon Sep 17 00:00:00 2001
From: guozhijian <5719707+jonyguo@user.noreply.gitee.com>
Date: Mon, 11 May 2020 19:05:59 +0800
Subject: [PATCH] Revert 'Pull Request !713 : Use a resident process to write
 summary files'
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mindspore/train/summary/_event_writer.py      | 136 ++++----
 mindspore/train/summary/_summary_adapter.py   | 203 ++++++++----
 mindspore/train/summary/_summary_scheduler.py | 308 ++++++++++++++++++
 mindspore/train/summary/summary_record.py     | 140 ++++----
 4 files changed, 591 insertions(+), 196 deletions(-)
 create mode 100644 mindspore/train/summary/_summary_scheduler.py

diff --git a/mindspore/train/summary/_event_writer.py b/mindspore/train/summary/_event_writer.py
index fa3b09f5ee..c04308dcbc 100644
--- a/mindspore/train/summary/_event_writer.py
+++ b/mindspore/train/summary/_event_writer.py
@@ -14,77 +14,91 @@
 # ============================================================================
 """Writes events to disk in a logdir."""
 import os
+import time
 import stat
-from collections import deque
-from multiprocessing import Pool, Process, Queue, cpu_count
-
+from mindspore import log as logger
 from ..._c_expression import EventWriter_
-from ._summary_adapter import package_summary_event
+from ._summary_adapter import package_init_event
 
 
-def _pack(result, step):
-    summary_event = package_summary_event(result, step)
-    return summary_event.SerializeToString()
+class _WrapEventWriter(EventWriter_):
+    """
+    Wrap the C++ EventWriter object.
+
+    Args:
+        full_file_name (str): Include directory and file name.
+    """
+    def __init__(self, full_file_name):
+        if full_file_name is not None:
+            EventWriter_.__init__(self, full_file_name)
 
 
-class EventWriter(Process):
+class EventRecord:
     """
-    Creates a `EventWriter` and write event to file.
+    Create an `EventRecord` and write events to file.
 
     Args:
-        filepath (str): Summary event file path and file name.
-        flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120.
+        full_file_name (str): Summary event file path and file name.
+        flush_time (int): Interval, in seconds, at which pending events are flushed to disk. Default: 120.
     """
-
-    def __init__(self, filepath: str, flush_interval: int) -> None:
-        super().__init__()
-        with open(filepath, 'w'):
-            os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
-        self._writer = EventWriter_(filepath)
-        self._queue = Queue(cpu_count() * 2)
-        self.start()
-
-    def run(self):
-
-        with Pool() as pool:
-            deq = deque()
-            while True:
-                while deq and deq[0].ready():
-                    self._writer.Write(deq.popleft().get())
-
-                if not self._queue.empty():
-                    action, data = self._queue.get()
-                    if action == 'WRITE':
-                        if not isinstance(data, (str, bytes)):
-                            deq.append(pool.apply_async(_pack, data))
-                        else:
-                            self._writer.Write(data)
-                    elif action == 'FLUSH':
-                        self._writer.Flush()
-                    elif action == 'END':
-                        break
-            for res in deq:
-                self._writer.Write(res.get())
-
-            self._writer.Shut()
-
-    def write(self, data) -> None:
-        """
-        Write the event to file.
-
-        Args:
-            data (Optional[str, Tuple[list, int]]): The data to write.
-        """
-        self._queue.put(('WRITE', data))
+    def __init__(self, full_file_name: str, flush_time: int = 120):
+        self.full_file_name = full_file_name
+
+        # The first event will be flushed immediately.
+        self.flush_time = flush_time
+        self.next_flush_time = 0
+
+        # create the event writer object
+        self.event_writer = self._create_event_file()
+        self._init_event_file()
+
+        # count the events
+        self.event_count = 0
+
+    def _create_event_file(self):
+        """Create the event file."""
+        with open(self.full_file_name, 'w'):
+            os.chmod(self.full_file_name, stat.S_IWUSR | stat.S_IRUSR)
+
+        # create the C++ event writer object
+        event_writer = _WrapEventWriter(self.full_file_name)
+        return event_writer
+
+    def _init_event_file(self):
+        """Send the init event to file."""
+        self.event_writer.Write((package_init_event()).SerializeToString())
+        self.flush()
+        return True
+
+    def write_event_to_file(self, event_str):
+        """Write the event to file."""
+        self.event_writer.Write(event_str)
+
+    def get_data_count(self):
+        """Return the event count."""
+        return self.event_count
+
+    def flush_cycle(self):
+        """Flush the file on a timer."""
+        self.event_count = self.event_count + 1
+        # Flush the event writer every so often.
+        now = int(time.time())
+        if now > self.next_flush_time:
+            self.flush()
+            # update the flush time
+            self.next_flush_time = now + self.flush_time
+
+    def count_event(self):
+        """Count events."""
+        logger.debug("The written event count is %r", self.event_count)
+        self.event_count = self.event_count + 1
+        return self.event_count
 
     def flush(self):
-        """Flush the writer."""
-        self._queue.put(('FLUSH', None))
-
-    def close(self) -> None:
-        """Close the writer."""
-        self._queue.put(('END', None))
-        self.join()
+        """Flush the event file to disk."""
+        self.event_writer.Flush()
 
-    def __del__(self) -> None:
-        self.close()
+    def close(self):
+        """Flush the event file to disk and close the file."""
+        self.flush()
+        self.event_writer.Shut()
diff --git a/mindspore/train/summary/_summary_adapter.py b/mindspore/train/summary/_summary_adapter.py
index 1cfde39b83..9669d0f054 100644
--- a/mindspore/train/summary/_summary_adapter.py
+++ b/mindspore/train/summary/_summary_adapter.py
@@ -13,17 +13,17 @@
 # limitations under the License.
 # ============================================================================
 """Generate the summary event which conforms to proto format."""
-import socket
 import time
-
+import socket
+import math
+from enum import Enum, unique
 import numpy as np
 from PIL import Image
 
 from mindspore import log as logger
-
-from ..._checkparam import _check_str_by_regular
-from ..anf_ir_pb2 import DataType, ModelProto
 from ..summary_pb2 import Event
+from ..anf_ir_pb2 import ModelProto, DataType
+from ..._checkparam import _check_str_by_regular
 
 # define the MindSpore image format
 MS_IMAGE_TENSOR_FORMAT = 'NCHW'
@@ -32,6 +32,55 @@ EVENT_FILE_NAME_MARK = ".out.events.summary."
 # Set the init event of version and mark
 EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:"
 EVENT_FILE_INIT_VERSION = 1
+# cache the summary data dict
+# {id: SummaryData}
+#      |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
+g_summary_data_dict = {}
+
+def save_summary_data(data_id, data):
+    """Save data to the global summary cache."""
+    global g_summary_data_dict
+    g_summary_data_dict[data_id] = data
+
+
+def del_summary_data(data_id):
+    """Delete data from the global summary cache."""
+    global g_summary_data_dict
+    if data_id in g_summary_data_dict:
+        del g_summary_data_dict[data_id]
+    else:
+        logger.warning("Can't delete the data because data_id(%r) "
+                       "does not have data in g_summary_data_dict", data_id)
+
+def get_summary_data(data_id):
+    """Get data from the global summary cache."""
+    ret = None
+    global g_summary_data_dict
+    if data_id in g_summary_data_dict:
+        ret = g_summary_data_dict.get(data_id)
+    else:
+        logger.warning("The data_id(%r) does not have data in g_summary_data_dict", data_id)
+    return ret
+
+@unique
+class SummaryType(Enum):
+    """
+    Summary type.
+
+    Args:
+        SCALAR (Number): Summary Scalar enum.
+        TENSOR (Number): Summary TENSOR enum.
+        IMAGE (Number): Summary image enum.
+        GRAPH (Number): Summary graph enum.
+        HISTOGRAM (Number): Summary histogram enum.
+        INVALID (Number): Unknown type.
+    """
+    SCALAR = 1      # Scalar summary
+    TENSOR = 2      # Tensor summary
+    IMAGE = 3       # Image summary
+    GRAPH = 4       # Graph summary
+    HISTOGRAM = 5   # Histogram summary
+    INVALID = 0xFF  # unknown type
 
 
 def get_event_file_name(prefix, suffix):
@@ -89,7 +138,7 @@
     return graph_event
 
 
-def package_summary_event(data_list, step):
+def package_summary_event(data_id, step):
     """
     Package the summary to event protobuffer.
 
     Args:
@@ -100,37 +149,50 @@
     Returns:
         Summary, the summary event.
     """
+    data_list = get_summary_data(data_id)
+    if data_list is None:
+        logger.error("The step(%r) does not have record data.", step)
+    del_summary_data(data_id)
     # create the event of summary
     summary_event = Event()
     summary = summary_event.summary
-    summary_event.wall_time = time.time()
-    summary_event.step = int(step)
 
     for value in data_list:
-        summary_type = value["_type"]
-        data = value["data"]
         tag = value["name"]
+        data = value["data"]
+        summary_type = value["type"]
 
-        logger.debug("Now process %r summary, tag = %r", summary_type, tag)
-
-        summary_value = summary.value.add()
-        summary_value.tag = tag
         # get the summary type and parse the tag
-        if summary_type == 'Scalar':
+        if summary_type is SummaryType.SCALAR:
+            logger.debug("Now process Scalar summary, tag = %r", tag)
+            summary_value = summary.value.add()
+            summary_value.tag = tag
             summary_value.scalar_value = _get_scalar_summary(tag, data)
-        elif summary_type == 'Tensor':
+        elif summary_type is SummaryType.TENSOR:
+            logger.debug("Now process Tensor summary, tag = %r", tag)
+            summary_value = summary.value.add()
+            summary_value.tag = tag
             summary_tensor = summary_value.tensor
             _get_tensor_summary(tag, data, summary_tensor)
-        elif summary_type == 'Image':
+        elif summary_type is SummaryType.IMAGE:
+            logger.debug("Now process Image summary, tag = %r", tag)
+            summary_value = summary.value.add()
+            summary_value.tag = tag
             summary_image = summary_value.image
             _get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT)
-        elif summary_type == 'Histogram':
+        elif summary_type is SummaryType.HISTOGRAM:
+            logger.debug("Now process Histogram summary, tag = %r", tag)
+            summary_value = summary.value.add()
+            summary_value.tag = tag
            summary_histogram = summary_value.histogram
             _fill_histogram_summary(tag, data, summary_histogram)
         else:
             # The data is invalid, skip it
-            logger.error("Summary type(%r) is error, tag = %r", summary_type, tag)
+            logger.error("Summary type is invalid, tag = %r", tag)
error, tag = %r", tag) + continue + summary_event.wall_time = time.time() + summary_event.step = int(step) return summary_event @@ -193,11 +255,11 @@ def _get_scalar_summary(tag: str, np_value): # So consider the dim = 1, shape = (1,) tensor is scalar scalar_value = np_value[0] if np_value.shape != (1,): - logger.error("The tensor is not Scalar, tag = %r, Shape = %r", tag, np_value.shape) + logger.error("The tensor is not Scalar, tag = %r, Value = %r", tag, np_value) else: np_list = np_value.reshape(-1).tolist() scalar_value = np_list[0] - logger.error("The value is not Scalar, tag = %r, ndim = %r", tag, np_value.ndim) + logger.error("The value is not Scalar, tag = %r, Value = %r", tag, np_value) logger.debug("The tag(%r) value is: %r", tag, scalar_value) return scalar_value @@ -245,7 +307,8 @@ def _calc_histogram_bins(count): Returns: int, number of histogram bins. """ - max_bins, max_per_bin = 90, 10 + number_per_bucket = 10 + max_bins = 90 if not count: return 1 @@ -255,50 +318,78 @@ def _calc_histogram_bins(count): return 3 if count <= 880: # note that math.ceil(881/10) + 1 equals 90 - return count // max_per_bin + 1 + return int(math.ceil(count / number_per_bucket) + 1) return max_bins -def _fill_histogram_summary(tag: str, np_value: np.ndarray, summary) -> None: +def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None: """ Package the histogram summary. Args: tag (str): Summary tag describe. - np_value (np.ndarray): Summary data. - summary (summary_pb2.Summary.Histogram): Summary histogram data. + np_value (np.array): Summary data. + summary_histogram (summary_pb2.Summary.Histogram): Summary histogram data. """ logger.debug("Set(%r) the histogram summary value", tag) # Default bucket for tensor with no valid data. - ma_value = np.ma.masked_invalid(np_value) - total, valid = np_value.size, ma_value.count() - invalids = [] - for isfn in np.isnan, np.isposinf, np.isneginf: - if total - valid > sum(invalids): - count = np.count_nonzero(isfn(np_value)) - invalids.append(count) - else: - invalids.append(0) + default_bucket_left = -0.5 + default_bucket_width = 1.0 - summary.count = total - summary.nan_count, summary.pos_inf_count, summary.neg_inf_count = invalids - if not valid: - logger.warning('There are no valid values in the ndarray(size=%d, shape=%d)', total, np_value.shape) - # summary.{min, max, sum} are 0s by default, no need to explicitly set - else: - summary.min = ma_value.min() - summary.max = ma_value.max() - summary.sum = ma_value.sum() - bins = _calc_histogram_bins(valid) - range_ = summary.min, summary.max - hists, edges = np.histogram(np_value, bins=bins, range=range_) + if np_value.size == 0: + bucket = summary_histogram.buckets.add() + bucket.left = default_bucket_left + bucket.width = default_bucket_width + bucket.count = 0 + + summary_histogram.nan_count = 0 + summary_histogram.pos_inf_count = 0 + summary_histogram.neg_inf_count = 0 + + summary_histogram.max = 0 + summary_histogram.min = 0 + summary_histogram.sum = 0 + + summary_histogram.count = 0 + + return + + summary_histogram.nan_count = np.count_nonzero(np.isnan(np_value)) + summary_histogram.pos_inf_count = np.count_nonzero(np.isposinf(np_value)) + summary_histogram.neg_inf_count = np.count_nonzero(np.isneginf(np_value)) + summary_histogram.count = np_value.size + + masked_value = np.ma.masked_invalid(np_value) + tensor_max = masked_value.max() + tensor_min = masked_value.min() + tensor_sum = masked_value.sum() + + # No valid value in tensor. 
+    if tensor_max is np.ma.masked:
+        bucket = summary_histogram.buckets.add()
+        bucket.left = default_bucket_left
+        bucket.width = default_bucket_width
+        bucket.count = 0
+
+        summary_histogram.max = np.nan
+        summary_histogram.min = np.nan
+        summary_histogram.sum = 0
+
+        return
+
+    bin_number = _calc_histogram_bins(masked_value.count())
+    counts, edges = np.histogram(np_value, bins=bin_number, range=(tensor_min, tensor_max))
+
+    for ind, count in enumerate(counts):
+        bucket = summary_histogram.buckets.add()
+        bucket.left = edges[ind]
+        bucket.width = edges[ind + 1] - edges[ind]
+        bucket.count = count
 
-    for hist, edge1, edge2 in zip(hists, edges, edges[1:]):
-        bucket = summary.buckets.add()
-        bucket.width = edge2 - edge1
-        bucket.count = hist
-        bucket.left = edge1
+    summary_histogram.max = tensor_max
+    summary_histogram.min = tensor_min
+    summary_histogram.sum = tensor_sum
 
 
 def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
@@ -316,7 +407,7 @@
     """
     logger.debug("Set(%r) the image summary value", tag)
     if np_value.ndim != 4:
-        logger.error("The value is not Image, tag = %r, ndim = %r", tag, np_value.ndim)
+        logger.error("The value is not Image, tag = %r, Value = %r", tag, np_value)
 
     # convert the tensor format
     tensor = _convert_image_format(np_value, input_format)
@@ -378,8 +469,8 @@
     """
     out_tensor = None
     if np_tensor.ndim != len(input_format):
-        logger.error("The tensor with dim(%r) can't convert the format(%r) because dim not same", np_tensor.ndim,
-                     input_format)
+        logger.error("The tensor(%r) can't be converted to format(%r): dimensions do not match",
+                     np_tensor, input_format)
         return out_tensor
 
     input_format = input_format.upper()
@@ -421,7 +512,7 @@
 
     # check the tensor format
     if tensor.ndim != 4 or tensor.shape[1] != 3:
-        logger.error("The image tensor with ndim(%r) and shape(%r) is not 'NCHW' format", tensor.ndim, tensor.shape)
+        logger.error("The image tensor(%r) is not 'NCHW' format", tensor)
         return out_canvas
 
     # expand the N
diff --git a/mindspore/train/summary/_summary_scheduler.py b/mindspore/train/summary/_summary_scheduler.py
new file mode 100644
index 0000000000..3327b02fa7
--- /dev/null
+++ b/mindspore/train/summary/_summary_scheduler.py
@@ -0,0 +1,308 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Schedule the event writer process."""
+import multiprocessing as mp
+from enum import Enum, unique
+from mindspore import log as logger
+from ..._c_expression import Tensor
+from ._summary_adapter import SummaryType, package_summary_event, save_summary_data
+
+# define the type of summary
+FORMAT_SCALAR_STR = "Scalar"
+FORMAT_TENSOR_STR = "Tensor"
+FORMAT_IMAGE_STR = "Image"
+FORMAT_HISTOGRAM_STR = "Histogram"
+FORMAT_BEGIN_SLICE = "[:"
+FORMAT_END_SLICE = "]"
+
+# cache the summary data dict
+# {id: SummaryData}
+#      |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
+g_summary_data_id = 0
+g_summary_data_dict = {}
+# cache the summary data file
+g_summary_writer_id = 0
+g_summary_file = {}
+
+
+@unique
+class ScheduleMethod(Enum):
+    """Schedule method type."""
+    FORMAL_WORKER = 0  # use the formal worker that receives small data via a queue
+    TEMP_WORKER = 1    # use a temp worker that receives large data via the global cache (avoids a copy)
+    CACHE_DATA = 2     # cache data until an idle worker can process it
+
+
+@unique
+class WorkerStatus(Enum):
+    """Worker status."""
+    WORKER_INIT = 0        # data exists but has not been processed
+    WORKER_PROCESSING = 1  # data is being processed
+    WORKER_PROCESSED = 2   # data has been processed
+
+
+def _parse_tag_format(tag: str):
+    """
+    Parse the tag.
+
+    Args:
+        tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor].
+
+    Returns:
+        Tuple, (SummaryType, summary_tag).
+    """
+
+    summary_type = SummaryType.INVALID
+    summary_tag = tag
+    if tag is None:
+        logger.error("The tag is None")
+        return summary_type, summary_tag
+
+    # search the slice
+    slice_begin = FORMAT_BEGIN_SLICE
+    slice_end = FORMAT_END_SLICE
+    index = tag.rfind(slice_begin)
+    if index == -1:
+        logger.error("The tag(%s) does not contain a type slice.", tag)
+        return summary_type, summary_tag
+
+    # slice the tag
+    summary_tag = tag[:index]
+
+    # check the slice end
+    if tag[-1:] != slice_end:
+        logger.error("The tag(%s) has an invalid end format", tag)
+        return summary_type, summary_tag
+
+    # check the type
+    type_str = tag[index + 2: -1]
+    logger.debug("The summary_tag is %r", summary_tag)
+    logger.debug("The type_str value is %r", type_str)
+    if type_str == FORMAT_SCALAR_STR:
+        summary_type = SummaryType.SCALAR
+    elif type_str == FORMAT_TENSOR_STR:
+        summary_type = SummaryType.TENSOR
+    elif type_str == FORMAT_IMAGE_STR:
+        summary_type = SummaryType.IMAGE
+    elif type_str == FORMAT_HISTOGRAM_STR:
+        summary_type = SummaryType.HISTOGRAM
+    else:
+        logger.error("The tag(%s) type is invalid.", tag)
+        summary_type = SummaryType.INVALID
+
+    return summary_type, summary_tag
+
+
+class SummaryDataManager:
+    """Manage the summary global data cache."""
+    def __init__(self):
+        global g_summary_data_dict
+        self.size = len(g_summary_data_dict)
+
+    @classmethod
+    def summary_data_save(cls, data):
+        """Save data to the global summary cache."""
+        global g_summary_data_id
+        data_id = g_summary_data_id
+        save_summary_data(data_id, data)
+        g_summary_data_id += 1
+        return data_id
+
+    @classmethod
+    def summary_file_set(cls, event_writer):
+        """Support multiple event writers."""
+        global g_summary_file, g_summary_writer_id
+        g_summary_writer_id += 1
+        g_summary_file[g_summary_writer_id] = event_writer
+        return g_summary_writer_id
+
+    @classmethod
+    def summary_file_get(cls, writer_id=1):
+        """Get the event writer by writer_id."""
+        ret = None
+        global g_summary_file
+        if writer_id in g_summary_file:
+            ret = g_summary_file.get(writer_id)
+        return ret
+
+
+class WorkerScheduler:
+    """
+    Create workers and schedule data to them.
+
+    Args:
+        writer_id (int): The index of the writer.
+    """
+    def __init__(self, writer_id):
+        # lock shared by the processes that write the event file
+        self.write_lock = mp.Lock()
+        # schedule info for all workers
+        # Format: {worker: (step, WorkerStatus)}
+        self.schedule_table = {}
+        # writer id
+        self.writer_id = writer_id
+        self.has_graph = False
+
+    def dispatch(self, step, data):
+        """
+        Select a schedule strategy and dispatch the data.
+
+        Args:
+            step (Number): The step index.
+            data (Object): The summary record data.
+
+        Returns:
+            bool, run successfully or not.
+        """
+        # save the data to the global cache, converting tensors to numpy
+        result, size, data = self._data_convert(data)
+        if result is False:
+            logger.error("The step(%r) summary data(size %r) is invalid.", step, size)
+            return False
+
+        data_id = SummaryDataManager.summary_data_save(data)
+        self._start_worker(step, data_id)
+        return True
+
+    def _start_worker(self, step, data_id):
+        """
+        Start a worker.
+
+        Args:
+            step (Number): The index of the record.
+            data_id (int): The id of the cached data.
+
+        Returns:
+            bool, run successfully or not.
+        """
+        # assign the worker
+        policy = self._make_policy()
+        if policy == ScheduleMethod.TEMP_WORKER:
+            worker = SummaryDataProcess(step, data_id, self.write_lock, self.writer_id)
+            # update the schedule table
+            self.schedule_table[worker] = (step, data_id, WorkerStatus.WORKER_INIT)
+            # start the worker
+            worker.start()
+        else:
+            logger.error("Other scheduler policies are not supported yet.")
+
+        # update the scheduler info
+        self._update_scheduler()
+        return True
+
+    def _data_convert(self, data_list):
+        """Convert the data."""
+        if data_list is None:
+            logger.warning("The step does not have record data.")
+            return False, 0, None
+
+        # convert the summary to numpy
+        size = 0
+        for v_dict in data_list:
+            tag = v_dict["name"]
+            data = v_dict["data"]
+            # confirm the data is valid
+            summary_type, summary_tag = _parse_tag_format(tag)
+            if summary_type == SummaryType.INVALID:
+                logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
+                return False, 0, None
+            if isinstance(data, Tensor):
+                # get the summary type and parse the tag
+                v_dict["name"] = summary_tag
+                v_dict["type"] = summary_type
+                v_dict["data"] = data.asnumpy()
+                size += v_dict["data"].size
+            else:
+                logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
+                return False, 0, None
+
+        return True, size, data_list
+
+    def _update_scheduler(self):
+        """Check the worker status and update the schedule table."""
+        workers = list(self.schedule_table.keys())
+        for worker in workers:
+            if not worker.is_alive():
+                # update the table
+                worker.join()
+                del self.schedule_table[worker]
+
+    def close(self):
+        """Wait until all workers have finished."""
+        workers = self.schedule_table.keys()
+        for worker in workers:
+            if worker.is_alive():
+                worker.join()
+
+    def _make_policy(self):
+        """Select the schedule strategy by data."""
+        # now only the temp worker is supported
+        return ScheduleMethod.TEMP_WORKER
+
+
+class SummaryDataProcess(mp.Process):
+    """
+    Process that consumes the summary data.
+
+    Args:
+        step (int): The index of the step.
+        data_id (int): The index of the summary data.
+        write_lock (Lock): The process lock for writing the same file.
+        writer_id (int): The index of the writer.
+    """
+    def __init__(self, step, data_id, write_lock, writer_id):
+        super(SummaryDataProcess, self).__init__()
+        self.daemon = True
+        self.writer_id = writer_id
+        self.writer = SummaryDataManager.summary_file_get(self.writer_id)
+        if self.writer is None:
+            logger.error("The writer_id(%r) does not have a writer", writer_id)
+        self.step = step
+        self.data_id = data_id
+        self.write_lock = write_lock
+        self.name = "SummaryDataConsumer_" + str(self.step)
+
+    def run(self):
+        """Process the step data and exit."""
+        # convert the data to an event
+        # all exceptions need to be caught so that the process ends cleanly
+        try:
+            logger.debug("process(%r) processes data of step(%r)", self.name, self.step)
+            # package the summary event
+            summary_event = package_summary_event(self.data_id, self.step)
+            # send the event to file
+            self._write_summary(summary_event)
+        except Exception as e:
+            logger.error("Summary data consumer exception occurred, value = %r", e)
+
+    def _write_summary(self, summary_event):
+        """
+        Write the summary to the event file.
+
+        Note:
+            The write record format:
+            1 uint64 : data length.
+            2 uint32 : mask crc value of data length.
+            3 bytes  : data.
+            4 uint32 : mask crc value of data.
+
+        Args:
+            summary_event (Event): The summary event of proto.
+
+        """
+        event_str = summary_event.SerializeToString()
+        self.write_lock.acquire()
+        self.writer.write_event_to_file(event_str)
+        self.writer.flush()
+        self.write_lock.release()
diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py
index 43baebccf9..4c60dce862 100644
--- a/mindspore/train/summary/summary_record.py
+++ b/mindspore/train/summary/summary_record.py
@@ -14,22 +14,17 @@
 # ============================================================================
 """Record the summary event."""
 import os
-import re
 import threading
-
 from mindspore import log as logger
-
-from ..._c_expression import Tensor
-from ..._checkparam import _check_str_by_regular
+from ._summary_scheduler import WorkerScheduler, SummaryDataManager
+from ._summary_adapter import get_event_file_name, package_graph_event
+from ._event_writer import EventRecord
 from .._utils import _make_directory
-from ._event_writer import EventWriter
-from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event
+from ..._checkparam import _check_str_by_regular
 
-# for the moment, this lock is for caution's sake,
-# there are actually no any concurrencies happening.
-_summary_lock = threading.Lock()
 # cache the summary data
 _summary_tensor_cache = {}
+_summary_lock = threading.Lock()
 
 
 def _cache_summary_tensor_data(summary):
@@ -39,18 +34,14 @@ def _cache_summary_tensor_data(summary):
 
     Args:
         summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].
""" - with _summary_lock: - for item in summary: - _summary_tensor_cache[item['name']] = item['data'] - return True - - -def _get_summary_tensor_data(): - global _summary_tensor_cache - with _summary_lock: - data = _summary_tensor_cache - _summary_tensor_cache = {} - return data + _summary_lock.acquire() + if "SummaryRecord" in _summary_tensor_cache: + for record in summary: + _summary_tensor_cache["SummaryRecord"].append(record) + else: + _summary_tensor_cache["SummaryRecord"] = summary + _summary_lock.release() + return True class SummaryRecord: @@ -80,7 +71,6 @@ class SummaryRecord: >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6, >>> file_prefix="xxx_", file_suffix="_yyy") """ - def __init__(self, log_dir, queue_max_size=0, @@ -111,18 +101,26 @@ class SummaryRecord: self.prefix = file_prefix self.suffix = file_suffix - self.network = network - self.has_graph = False - self._closed = False # create the summary writer file self.event_file_name = get_event_file_name(self.prefix, self.suffix) + if self.log_path[-1:] == '/': + self.full_file_name = self.log_path + self.event_file_name + else: + self.full_file_name = self.log_path + '/' + self.event_file_name + try: - self.full_file_name = os.path.join(self.log_path, self.event_file_name) + self.full_file_name = os.path.realpath(self.full_file_name) except Exception as ex: raise RuntimeError(ex) - self.event_writer = EventWriter(self.full_file_name, self.flush_time) - self.event_writer.write(package_init_event().SerializeToString()) + self.event_writer = EventRecord(self.full_file_name, self.flush_time) + self.writer_id = SummaryDataManager.summary_file_set(self.event_writer) + self.worker_scheduler = WorkerScheduler(self.writer_id) + + self.step = 0 + self._closed = False + self.network = network + self.has_graph = False def record(self, step, train_network=None): """ @@ -147,34 +145,42 @@ class SummaryRecord: if not isinstance(step, int) or isinstance(step, bool): raise ValueError("`step` should be int") # Set the current summary of train step + self.step = step - if self.network is not None and not self.has_graph: + if self.network is not None and self.has_graph is False: graph_proto = self.network.get_func_graph_proto() if graph_proto is None and train_network is not None: graph_proto = train_network.get_func_graph_proto() if graph_proto is None: logger.error("Failed to get proto for graph") else: - self.event_writer.write(package_graph_event(graph_proto).SerializeToString()) + self.event_writer.write_event_to_file( + package_graph_event(graph_proto).SerializeToString()) + self.event_writer.flush() self.has_graph = True - if not _summary_tensor_cache: + data = _summary_tensor_cache.get("SummaryRecord") + if data is None: return True - data = _get_summary_tensor_data() - if not data: - logger.error("The step(%r) does not have record data.", step) + data = _summary_tensor_cache.get("SummaryRecord") + if data is None: + logger.error("The step(%r) does not have record data.", self.step) return False if self.queue_max_size > 0 and len(data) > self.queue_max_size: logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data), self.queue_max_size) + # clean the data of cache + del _summary_tensor_cache["SummaryRecord"] + # process the data - result = self._data_convert(data) - if not result: - logger.error("The step(%r) summary data is invalid.", step) - return False - self.event_writer.write((result, step)) - logger.debug("Send the summary data to scheduler for saving, 
step = %d", step) + self.worker_scheduler.dispatch(self.step, data) + + # count & flush + self.event_writer.count_event() + self.event_writer.flush_cycle() + + logger.debug("Send the summary data to scheduler for saving, step = %d", self.step) return True @property @@ -190,7 +196,7 @@ class SummaryRecord: Returns: String, the full path of log file. """ - return self.full_file_name + return self.event_writer.full_file_name def flush(self): """ @@ -218,44 +224,20 @@ class SummaryRecord: >>> summary_record.close() """ if not self._closed: + self._check_data_before_close() + self.worker_scheduler.close() # event writer flush and close self.event_writer.close() self._closed = True - def _data_convert(self, summary): - """Convert the data.""" - # convert the summary to numpy - result = [] - for name, data in summary.items(): - # confirm the data is valid - summary_tag, summary_type = SummaryRecord._parse_from(name) - if summary_tag is None: - logger.error("The data type is invalid, name = %r, tensor = %r", name, data) - return None - if isinstance(data, Tensor): - result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type}) - else: - logger.error("The data type is invalid, name = %r, tensor = %r", name, data) - return None - - return result - - @staticmethod - def _parse_from(name: str = None): - """ - Parse the tag and type from name. - - Args: - name (str): Format: TAG[:TYPE]. - - Returns: - Tuple, (summary_tag, summary_type). - """ - if name is None: - logger.error("The name is None") - return None, None - match = re.match(r'(.+)\[:(.+)\]', name) - if match: - return match.groups() - logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name) - return None, None + def __del__(self): + """Process exit is called.""" + if hasattr(self, "worker_scheduler"): + if self.worker_scheduler: + self.close() + + def _check_data_before_close(self): + "Check whether there is any data in the cache, and if so, call record" + data = _summary_tensor_cache.get("SummaryRecord") + if data is not None: + self.record(self.step)