|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Schedule the event writer process."""
- import multiprocessing as mp
- from enum import Enum, unique
- from mindspore import log as logger
- from ..._c_expression import Tensor
- from ._summary_adapter import SummaryType, package_summary_event, save_summary_data
-
# Tag-type suffix markers: a tag is expected to look like "<name>[:<Type>]".
FORMAT_SCALAR_STR = "Scalar"
FORMAT_TENSOR_STR = "Tensor"
FORMAT_IMAGE_STR = "Image"
FORMAT_HISTOGRAM_STR = "Histogram"
# Delimiters of the type slice inside a tag ("[:" opens it, "]" closes it).
FORMAT_BEGIN_SLICE = "[:"
FORMAT_END_SLICE = "]"

# Global cache of summary data, keyed by an auto-incremented id:
# {id: SummaryData}
# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
g_summary_data_id = 0
g_summary_data_dict = {}
# Global registry of event writers, keyed by an auto-incremented writer id.
g_summary_writer_id = 0
g_summary_file = {}
-
-
@unique
class ScheduleMethod(Enum):
    """Strategy used to dispatch summary data to a worker."""
    FORMAL_WORKER = 0  # formal worker that receives small data through a queue
    TEMP_WORKER = 1    # temp worker that reads big data from the global cache (avoids a copy)
    CACHE_DATA = 2     # cache the data until an idle worker is available to process it
-
-
@unique
class WorkerStatus(Enum):
    """Lifecycle state of a summary worker."""
    WORKER_INIT = 0        # data exists but has not been processed yet
    WORKER_PROCESSING = 1  # data is being processed
    WORKER_PROCESSED = 2   # data has already been processed
-
-
def _parse_tag_format(tag: str):
    """
    Parse a summary tag of the form "<name>[:<Type>]".

    Args:
        tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor] xxx[:Histogram].

    Returns:
        Tuple, (SummaryType, summary_tag). On any parse error the type is
        SummaryType.INVALID and the tag is returned as far as parsing got.
    """
    summary_type = SummaryType.INVALID
    summary_tag = tag
    if tag is None:
        logger.error("The tag is None")
        return summary_type, summary_tag

    # locate the last "[:" marker, which opens the type slice
    index = tag.rfind(FORMAT_BEGIN_SLICE)
    # BUG FIX: was `if index is -1:` — identity comparison against an int
    # literal is implementation-dependent and must be an equality test.
    if index == -1:
        logger.error("The tag(%s) have not the key slice.", tag)
        return summary_type, summary_tag

    # everything before "[:" is the user-visible tag name
    summary_tag = tag[:index]

    # the tag must close with "]"
    if tag[-1:] != FORMAT_END_SLICE:
        logger.error("The tag(%s) end format is error", tag)
        return summary_type, summary_tag

    # the type string sits between "[:" and the trailing "]"
    # (len(FORMAT_BEGIN_SLICE) instead of the former hard-coded 2)
    type_str = tag[index + len(FORMAT_BEGIN_SLICE):-1]
    logger.debug("The summary_tag is = %r", summary_tag)
    logger.debug("The type_str value is = %r", type_str)
    type_map = {
        FORMAT_SCALAR_STR: SummaryType.SCALAR,
        FORMAT_TENSOR_STR: SummaryType.TENSOR,
        FORMAT_IMAGE_STR: SummaryType.IMAGE,
        FORMAT_HISTOGRAM_STR: SummaryType.HISTOGRAM,
    }
    summary_type = type_map.get(type_str, SummaryType.INVALID)
    if summary_type is SummaryType.INVALID:
        logger.error("The tag(%s) type is invalid.", tag)

    return summary_type, summary_tag
-
-
class SummaryDataManager:
    """Manage the module-level summary data cache and writer registry."""

    def __init__(self):
        # snapshot of how many data entries the global cache currently holds
        self.size = len(g_summary_data_dict)

    @classmethod
    def summary_data_save(cls, data):
        """Store ``data`` in the global summary cache and return its assigned id."""
        global g_summary_data_id
        assigned_id = g_summary_data_id
        save_summary_data(assigned_id, data)
        g_summary_data_id += 1
        return assigned_id

    @classmethod
    def summary_file_set(cls, event_writer):
        """Register an event writer and return its writer id (supports many writers)."""
        global g_summary_writer_id
        g_summary_writer_id += 1
        g_summary_file[g_summary_writer_id] = event_writer
        return g_summary_writer_id

    @classmethod
    def summary_file_get(cls, writer_id=1):
        """Return the event writer registered under ``writer_id``, or None if unknown."""
        return g_summary_file.get(writer_id)
-
-
class WorkerScheduler:
    """
    Create workers and schedule summary data to them.

    Args:
        writer_id (int): The index of the writer the workers should use.
    """
    def __init__(self, writer_id):
        # Cross-process lock so workers never interleave writes to the event file.
        self.write_lock = mp.Lock()
        # Schedule info for all workers.
        # Format: {worker: (step, data_id, WorkerStatus)}
        # (FIX: the previous comment documented a 2-tuple, but a 3-tuple is stored.)
        self.schedule_table = {}
        # Identifies the target event writer in the global writer registry.
        self.writer_id = writer_id
        self.has_graph = False

    def dispatch(self, step, data):
        """
        Select a schedule strategy and dispatch the data to a worker.

        Args:
            step (Number): The number of the step index.
            data (list): The record data for summary, a list of
                {"name": tag, "data": Tensor} dicts.

        Returns:
            bool, run successfully or not.
        """
        # save the data to the global cache, converting tensors to numpy
        result, size, data = self._data_convert(data)
        if not result:  # was the unidiomatic `result is False`
            logger.error("The step(%r) summary data(%r) is invalid.", step, size)
            return False

        data_id = SummaryDataManager.summary_data_save(data)
        self._start_worker(step, data_id)
        return True

    def _start_worker(self, step, data_id):
        """
        Start a worker process for one step of cached summary data.

        Args:
            step (Number): The index of the record.
            data_id (Number): The id of the cached data.

        Returns:
            bool, run successfully or not.
        """
        # assign the worker according to the scheduling policy
        policy = self._make_policy()
        if policy == ScheduleMethod.TEMP_WORKER:
            worker = SummaryDataProcess(step, data_id, self.write_lock, self.writer_id)
            # register in the schedule table before starting so the worker is tracked
            self.schedule_table[worker] = (step, data_id, WorkerStatus.WORKER_INIT)
            worker.start()
        else:
            logger.error("Do not support the other scheduler policy now.")

        # reap any workers that have already finished
        self._update_scheduler()
        return True

    def _data_convert(self, data_list):
        """Convert each Tensor record to numpy; return (ok, total_size, data_list)."""
        if data_list is None:
            logger.warning("The step does not have record data.")
            return False, 0, None

        # convert the summary tensors to numpy, accumulating the element count
        size = 0
        for v_dict in data_list:
            tag = v_dict["name"]
            data = v_dict["data"]
            # parse the tag and confirm the declared type is valid
            summary_type, summary_tag = _parse_tag_format(tag)
            if summary_type == SummaryType.INVALID:
                logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
                return False, 0, None
            if not isinstance(data, Tensor):
                logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
                return False, 0, None
            # replace the Tensor with its numpy value and record the parsed tag/type
            v_dict["name"] = summary_tag
            v_dict["type"] = summary_type
            v_dict["data"] = data.asnumpy()
            size += v_dict["data"].size

        return True, size, data_list

    def _update_scheduler(self):
        """Join finished workers and drop them from the schedule table."""
        # iterate over a snapshot because entries are deleted while scanning
        for worker in list(self.schedule_table):
            if not worker.is_alive():
                worker.join()
                del self.schedule_table[worker]

    def close(self):
        """Block until every scheduled worker has finished."""
        for worker in list(self.schedule_table):
            if worker.is_alive():
                worker.join()

    def _make_policy(self):
        """Select the schedule strategy; only the temp worker is supported now."""
        return ScheduleMethod.TEMP_WORKER
-
-
class SummaryDataProcess(mp.Process):
    """
    Process that consumes one step of summary data and writes it to file.

    Args:
        step (int): The index of the step.
        data_id (int): The index of the summary data in the global cache.
        write_lock (Lock): The cross-process lock for writing the same file.
        writer_id (int): The index of the writer.
    """
    def __init__(self, step, data_id, write_lock, writer_id):
        super(SummaryDataProcess, self).__init__()
        self.daemon = True
        self.writer_id = writer_id
        # look up the registered event writer; may legitimately be absent
        self.writer = SummaryDataManager.summary_file_get(self.writer_id)
        if self.writer is None:
            logger.error("The writer_id(%r) does not have writer", writer_id)
        self.step = step
        self.data_id = data_id
        self.write_lock = write_lock
        self.name = "SummaryDataConsumer_" + str(self.step)

    def run(self):
        """Process the step data and exit.

        All exceptions are caught and logged so a worker failure cannot
        propagate out of the child process.
        """
        try:
            logger.debug("process(%r) process a data(%r)", self.name, self.step)
            # package the cached data into a summary event proto
            summary_event = package_summary_event(self.data_id, self.step)
            # send the event to the file
            self._write_summary(summary_event)
        except Exception as e:
            logger.error("Summary data mq consumer exception occurred, value = %r", e)

    def _write_summary(self, summary_event):
        """
        Write the summary event to the event file.

        Note:
            The write record format:
            1 uint64 : data length.
            2 uint32 : mask crc value of data length.
            3 bytes  : data.
            4 uint32 : mask crc value of data.

        Args:
            summary_event (Event): The summary event of proto.
        """
        event_str = summary_event.SerializeToString()
        # BUG FIX: use the lock as a context manager so it is released even if
        # write_event_to_file() or flush() raises. The previous bare
        # acquire()/release() pair left the cross-process lock held forever on
        # an exception, deadlocking every other worker.
        with self.write_lock:
            self.writer.write_event_to_file(event_str)
            self.writer.flush()
|