|
|
|
@@ -19,6 +19,8 @@ import numbers |
|
|
|
import numpy as np |
|
|
|
from .._c_expression import ParamInfo |
|
|
|
from . import dtype as mstype |
|
|
|
from .. import context |
|
|
|
from ..parallel._utils import _get_parallel_mode |
|
|
|
from .initializer import initializer |
|
|
|
from .tensor import Tensor |
|
|
|
from .._checkparam import Validator |
|
|
|
@@ -292,7 +294,18 @@ class Parameter(Tensor_): |
|
|
|
|
|
|
|
@comm_fusion.setter |
|
|
|
def comm_fusion(self, comm_fusion_): |
|
|
|
"""Set the fusion type for communication operators corresponding to this parameter.""" |
|
|
|
""" |
|
|
|
In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or |
|
|
|
gradients aggregation are inserted automatically.Set the fusion type for communication operators generated |
|
|
|
for this parameter. Only `Ascend` and `Graph` mode is supported. |
|
|
|
|
|
|
|
Args: |
|
|
|
comm_fusion_ (int): The value of fusion must be greater than or equal to 0. |
|
|
|
When the value of fusion is 0, operators will not be fused together. |
|
|
|
""" |
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode(): |
|
|
|
raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE") |
|
|
|
Validator.check_non_negative_int(comm_fusion_) |
|
|
|
self.param_info.comm_fusion = comm_fusion_ |
|
|
|
|
|
|
|
@property |
|
|
|
|