From 3240b2d8e1eb79c79ef80406ee817d52a6627db6 Mon Sep 17 00:00:00 2001 From: ougongchang Date: Tue, 2 Feb 2021 15:49:22 +0800 Subject: [PATCH] Add device id to summary file name To prevent data-write conflicts in multi-card scenarios, the summary file name on each card is suffixed with the device_id --- mindspore/train/callback/_dataset_graph.py | 8 +++++++- mindspore/train/callback/_summary_collector.py | 2 ++ mindspore/train/summary/_summary_adapter.py | 16 +++++++++++++--- mindspore/train/summary/_writer_pool.py | 5 +++++ mindspore/train/summary/summary_record.py | 6 +++--- 5 files changed, 30 insertions(+), 7 deletions(-) diff --git a/mindspore/train/callback/_dataset_graph.py b/mindspore/train/callback/_dataset_graph.py index 761e47dd9b..e7188be75b 100644 --- a/mindspore/train/callback/_dataset_graph.py +++ b/mindspore/train/callback/_dataset_graph.py @@ -33,7 +33,13 @@ class DatasetGraph: DatasetGraph, a object of lineage_pb2.DatasetGraph. """ dataset_package = import_module('mindspore.dataset') - dataset_dict = dataset_package.serialize(dataset) + try: + dataset_dict = dataset_package.serialize(dataset) + except (TypeError, OSError) as exc: + logger.warning("Summary can not collect dataset graph, there is an error in dataset internal, " + "detail: %s.", str(exc)) + return None + dataset_graph_proto = lineage_pb2.DatasetGraph() if not isinstance(dataset_dict, dict): logger.warning("The dataset graph serialized from dataset object is not a dict. 
" diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index 7296c9868d..da2f7da470 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -518,6 +518,8 @@ class SummaryCollector(Callback): train_dataset = cb_params.train_dataset dataset_graph = DatasetGraph() graph_bytes = dataset_graph.package_dataset_graph(train_dataset) + if graph_bytes is None: + return self._record.add_value('dataset_graph', 'train_dataset', graph_bytes) def _collect_graphs(self, cb_params): diff --git a/mindspore/train/summary/_summary_adapter.py b/mindspore/train/summary/_summary_adapter.py index 8e39d2ee5b..bd71e8ed60 100644 --- a/mindspore/train/summary/_summary_adapter.py +++ b/mindspore/train/summary/_summary_adapter.py @@ -20,6 +20,8 @@ import numpy as np from PIL import Image from mindspore import log as logger +from mindspore import context +from mindspore.communication.management import get_rank from ..._checkparam import Validator from ..anf_ir_pb2 import DataType, ModelProto @@ -53,10 +55,18 @@ def get_event_file_name(prefix, suffix, time_second): file_name = "" hostname = platform.node() - if prefix is not None: - file_name = file_name + prefix + device_num = context.get_auto_parallel_context('device_num') + device_id = context.get_context('device_id') + if device_num > 1: + # Notice: + # In GPU distribute training scene, get_context('device_id') will not work, + # so we use get_rank instead of get_context. + device_id = get_rank() + + file_name = f'{file_name}{EVENT_FILE_NAME_MARK}{time_second}.{device_id}.{hostname}' - file_name = file_name + EVENT_FILE_NAME_MARK + time_second + "." 
+ hostname + if prefix is not None: + file_name = prefix + file_name if suffix is not None: file_name = file_name + suffix diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py index 3a40eeefa2..da0c0b255e 100644 --- a/mindspore/train/summary/_writer_pool.py +++ b/mindspore/train/summary/_writer_pool.py @@ -97,6 +97,11 @@ class WriterPool(ctx.Process): with ctx.Pool(min(ctx.cpu_count(), 32)) as pool: deq = deque() while True: + if not self._writers: + logger.warning("Can not find any writer to write summary data, " + "so SummaryRecord will not record data.") + break + while deq and deq[0].ready(): for plugin, data in deq.popleft().get(): self._write(plugin, data) diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index 3c29b41431..1d52a62bba 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -112,8 +112,8 @@ class SummaryRecord: network (Cell): Obtain a pipeline through network for saving graph summary. Default: None. max_file_size (int, optional): The maximum size of each file that can be written to disk (in bytes). \ Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`. - raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs - in recording data. Default: False, this means that error logs are printed and no exception is thrown. + raise_exception (bool, optional): Sets whether to throw an exception when a RuntimeError or OSError exception + occurs in recording data. Default: False, this means that error logs are printed and no exception is thrown. export_options (Union[None, dict]): Perform custom operations on the export data. Default: None, it means there is no export data. Note that the size of export files is not limited by the max_file_size. 
@@ -177,7 +177,7 @@ class SummaryRecord: if self._export_options is not None: export_dir = "export_{}".format(time_second) - filename_dict = dict(summary=self.full_file_name, + filename_dict = dict(summary=self.event_file_name, lineage=get_event_file_name(self.prefix, '_lineage', time_second), explainer=get_event_file_name(self.prefix, '_explain', time_second), exporter=export_dir)