浏览代码

Prevent the multiprocess from capturing KeyboardInterrupt

tags/v1.2.0-rc1
ougongchang 5 年前
父节点
当前提交
e5529230bf
共有 3 个文件被更改,包括 14 次插入9 次删除
  1. +1
    -1
      mindspore/train/_utils.py
  2. +5
    -0
      mindspore/train/summary/_writer_pool.py
  3. +8
    -8
      tests/ut/python/train/summary/test_summary_collector.py

+ 1
- 1
mindspore/train/_utils.py 查看文件

@@ -202,7 +202,7 @@ def check_value_type(arg_name, arg_value, valid_types):

if not is_valid:
raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
f'bug got {type(arg_value).__name__}.')
f'but got {type(arg_value).__name__}.')


def read_proto(file_name, proto_format="MINDIR"):


+ 5
- 0
mindspore/train/summary/_writer_pool.py 查看文件

@@ -15,6 +15,7 @@
"""Write events to disk in a base directory."""
import os
import time
import signal
from collections import deque

import mindspore.log as logger
@@ -77,6 +78,10 @@ class WriterPool(ctx.Process):
os.environ['GOTO_NUM_THREADS'] = '2'
os.environ['OMP_NUM_THREADS'] = '2'

# Prevent the multiprocess from capturing KeyboardInterrupt,
# which causes the main process to fail to exit.
signal.signal(signal.SIGINT, signal.SIG_IGN)

with ctx.Pool(min(ctx.cpu_count(), 32)) as pool:
deq = deque()
while True:


+ 8
- 8
tests/ut/python/train/summary/test_summary_collector.py 查看文件

@@ -118,7 +118,7 @@ class TestSummaryCollector:
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \
f'bug got {type(collect_freq).__name__}.'
f'but got {type(collect_freq).__name__}.'
assert expected_msg == str(exc.value)

@pytest.mark.parametrize("action", [None, 123, '', '123'])
@@ -128,7 +128,7 @@ class TestSummaryCollector:
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir, keep_default_action=action)
expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \
f"bug got {type(action).__name__}."
f"but got {type(action).__name__}."
assert expected_msg == str(exc.value)

@pytest.mark.parametrize("collect_specified_data", [123])
@@ -139,7 +139,7 @@ class TestSummaryCollector:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)

expected_msg = f"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], " \
f"bug got {type(collect_specified_data).__name__}."
f"but got {type(collect_specified_data).__name__}."

assert expected_msg == str(exc.value)

@@ -159,7 +159,7 @@ class TestSummaryCollector:

param_name = list(collect_specified_data)[0]
expected_msg = f"For `{param_name}` the type should be a valid type of ['str'], " \
f"bug got {type(param_name).__name__}."
f"but got {type(param_name).__name__}."
assert expected_msg == str(exc.value)

@pytest.mark.parametrize("collect_specified_data", [
@@ -183,7 +183,7 @@ class TestSummaryCollector:
param_value = collect_specified_data[param_name]
expected_type = "['bool']" if param_name != 'histogram_regular' else "['str', 'NoneType']"
expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
f'bug got {type(param_value).__name__}.'
f'but got {type(param_value).__name__}.'

assert expected_msg == str(exc.value)

@@ -216,18 +216,18 @@ class TestSummaryCollector:

if not isinstance(custom_lineage_data, dict):
expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \
f"bug got {type(custom_lineage_data).__name__}."
f"but got {type(custom_lineage_data).__name__}."
else:
param_name = list(custom_lineage_data)[0]
param_value = custom_lineage_data[param_name]
if not isinstance(param_name, str):
arg_name = f'custom_lineage_data -> {param_name}'
expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \
f'bug got {type(param_name).__name__}.'
f'but got {type(param_name).__name__}.'
else:
arg_name = f'the value of custom_lineage_data -> {param_name}'
expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \
f'bug got {type(param_value).__name__}.'
f'but got {type(param_value).__name__}.'

assert expected_msg == str(exc.value)



正在加载...
取消
保存