diff --git a/tests/ut/python/parallel/__init__.py b/tests/ut/python/parallel/__init__.py index edd469899e..7f4e7e22ed 100644 --- a/tests/ut/python/parallel/__init__.py +++ b/tests/ut/python/parallel/__init__.py @@ -17,11 +17,12 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._cost_model_context import reset_cost_model_context from mindspore.parallel._utils import _reset_op_id from mindspore.parallel.algo_parameter_config import reset_algo_parameters - +from mindspore.communication._comm_helper import GlobalComm def setup_module(): auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) + GlobalComm.INITED = True reset_cost_model_context() reset_algo_parameters() _reset_op_id() @@ -29,6 +30,7 @@ def setup_module(): def teardown_module(): context.reset_auto_parallel_context() + GlobalComm.INITED = False reset_cost_model_context() reset_algo_parameters() _reset_op_id()