diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py
index 0d0cad7530..96416b6bcc 100644
--- a/mindspore/train/callback/_summary_collector.py
+++ b/mindspore/train/callback/_summary_collector.py
@@ -111,7 +111,7 @@ class SummaryCollector(Callback):
             and float. Default: None, it means there is no custom data.
         collect_tensor_freq (Optional[int]): The same semantics as the `collect_freq`, but controls TensorSummary only.
             Because TensorSummary data is too large to be compared with other summary data, this parameter is used to
-            reduce its collection. By default, The maximum number of steps for collecting TensorSummary data is 21,
+            reduce its collection. By default, the maximum number of steps for collecting TensorSummary data is 20,
             but it will not exceed the number of steps for collecting other summary data. Default: None, which means
             to follow the behavior as described above. For example, given `collect_freq=10`, when the total steps is
             600, TensorSummary will be collected 20 steps, while other summary data 61 steps,
diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py
index 888c5a90a4..cdd77d2edf 100644
--- a/mindspore/train/summary/_writer_pool.py
+++ b/mindspore/train/summary/_writer_pool.py
@@ -16,7 +16,6 @@
 import os
 import time
 from collections import deque
-from multiprocessing import Pool, Process, Queue, cpu_count
 
 import mindspore.log as logger
 
@@ -24,6 +23,12 @@ from ._lineage_adapter import serialize_to_lineage_event
 from ._summary_adapter import package_graph_event, package_summary_event
 from ._summary_writer import LineageWriter, SummaryWriter
 
+try:
+    from multiprocessing import get_context
+    ctx = get_context('forkserver')
+except ValueError:
+    import multiprocessing as ctx
+
 
 def _pack_data(datadict, wall_time):
     """Pack data according to which plugin."""
@@ -42,7 +47,7 @@ def _pack_data(datadict, wall_time):
     return result
 
 
-class WriterPool(Process):
+class WriterPool(ctx.Process):
     """
     Use a set of pooled resident processes for writing a list of file.
 
@@ -54,12 +59,12 @@ class WriterPool(Process):
     def __init__(self, base_dir, max_file_size, **filedict) -> None:
         super().__init__()
         self._base_dir, self._filedict = base_dir, filedict
-        self._queue, self._writers_ = Queue(cpu_count() * 2), None
+        self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None
         self._max_file_size = max_file_size
         self.start()
 
     def run(self):
-        with Pool(min(cpu_count(), 32)) as pool:
+        with ctx.Pool(min(ctx.cpu_count(), 32)) as pool:
             deq = deque()
             while True:
                 while deq and deq[0].ready():
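
Reviewer note on the docstring fix: the cap of 20 is consistent with the worked example already in the docstring. A hedged sketch of the arithmetic it describes follows; the helper name and the hard-coded cap are assumptions for illustration, not the actual SummaryCollector internals.

```python
def effective_tensor_freq(total_steps, collect_freq, max_tensor_collections=20):
    """Hypothetical helper: tensor collection interval under a ~20-collection cap."""
    # Other summary data: collected on the first step and then every
    # `collect_freq` steps -> 600 // 10 + 1 = 61 collections in the example.
    other_collections = total_steps // collect_freq + 1
    # TensorSummary: stretch the interval so at most ~20 collections happen,
    # but never collect tensors more often than other summary data.
    tensor_freq = max(collect_freq, total_steps // max_tensor_collections)
    return tensor_freq, other_collections

freq, others = effective_tensor_freq(600, 10)
assert (freq, others) == (30, 61)  # tensors every 30 steps -> 20 collections
```

This reproduces the docstring's example exactly (20 tensor collections vs. 61 for other summary data), so 20, not 21, is the right number to document.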
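
Reviewer note on the `_writer_pool.py` change: it prefers the `forkserver` start method (fork-ing a large, multi-threaded training process can deadlock or bloat children) and falls back to the `multiprocessing` module itself when `get_context` raises `ValueError` on platforms without `forkserver`. A minimal, runnable sketch of the same pattern, with a hypothetical `Worker` standing in for `WriterPool`:

```python
# Prefer the 'forkserver' start method; fall back to the multiprocessing
# module, whose Process/Queue/Pool/cpu_count share the context's surface.
try:
    from multiprocessing import get_context
    ctx = get_context('forkserver')
except ValueError:
    import multiprocessing as ctx


class Worker(ctx.Process):  # hypothetical stand-in for WriterPool
    def __init__(self):
        super().__init__()
        # Context factories mirror the module-level ones, so this works
        # under either binding of `ctx`.
        self._queue = ctx.Queue(ctx.cpu_count() * 2)

    def run(self):
        with ctx.Pool(min(ctx.cpu_count(), 32)) as pool:
            print(pool.apply(sum, ([1, 2, 3],)))  # trivial round-trip


if __name__ == '__main__':
    w = Worker()
    w.start()
    w.join()
```

One thing worth double-checking in review: unlike `fork`, `forkserver` (like `spawn`) pickles the `Process` instance for the child, so everything `WriterPool.__init__` assigns before `self.start()` must be picklable.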