@@ -15,12 +15,14 @@
"""Utils of auto parallel"""
import numpy as np
from mindspore import log as logger
from mindspore._c_expression import reset_op_id
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype
from mindspore.common import dtype as mstype
from mindspore.communication.management import get_group_size, get_rank
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.common.seed import get_seed


def _get_parallel_mode():
@@ -136,16 +138,11 @@ def _get_global_rank():
def _get_parameter_broadcast():
    """Get the parameter broadcast."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    # Stand-alone training runs on a single device, so there is nothing to broadcast.
    if parallel_mode == "stand_alone":
        parameter_broadcast = False
        return parameter_broadcast

    # An explicit user setting takes precedence; otherwise broadcast by default in
    # data_parallel/hybrid_parallel and skip it in the other modes.
    if auto_parallel_context().get_parameter_broadcast_is_set() is True:
        parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
    elif parallel_mode in ("data_parallel", "hybrid_parallel"):
        parameter_broadcast = True
    else:
        parameter_broadcast = False

    # Without broadcast and without a shared global seed, each device would start
    # from different initial parameter values, so remind the user to set a seed.
    if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed() is None:
        logger.warning("You are suggested to use mindspore.common.set_seed() to share"
                       " parameters among devices.")

    return parameter_broadcast
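
# A minimal usage sketch (illustrative, assuming training setup code consults this
# helper to decide between broadcasting rank-0 parameters and relying on a shared
# global seed, as the warning above suggests):
#
#     from mindspore.common import set_seed
#
#     set_seed(1)                      # same initialization on every device, or ...
#     if _get_parameter_broadcast():   # ... broadcast the rank-0 values instead
#         ...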