|
|
|
@@ -79,7 +79,7 @@ class Parameter(Tensor_):
         default_input (Union[Tensor, int, float, numpy.ndarray, list]): Parameter data, to be set initialized.
         name (str): Name of the child parameter. Default: None.
         requires_grad (bool): True if the parameter requires gradient. Default: True.
-        layerwise_parallel (bool): When layerwise_parallel is true in data parallel mode,
+        layerwise_parallel (bool): When layerwise_parallel is true in data/hybrid parallel mode,
             broadcast and gradients communication would not be applied to parameters. Default: False.
         parallel_optimizer (bool): It is used to filter the weight shard operation in semi auto or auto parallel
             mode. It works only when enable parallel optimizer in `mindspore.context.set_auto_parallel_context()`.
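The Args above document the `Parameter` constructor. A minimal usage sketch of the documented keywords (shape, dtype, and name are illustrative, and `parallel_optimizer` is assumed to be accepted at construction because the Args list includes it):

    import numpy as np
    from mindspore import Tensor, Parameter

    # Construct a trainable parameter; layerwise_parallel defaults to False
    # per the Args above, so passing it explicitly is optional.
    weight = Parameter(Tensor(np.ones((2, 3), np.float32)),
                       name="weight",
                       requires_grad=True,
                       layerwise_parallel=False,
                       parallel_optimizer=True)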
|
|
|
@@ -385,7 +385,10 @@ class Parameter(Tensor_):
 
     @property
     def layerwise_parallel(self):
-        """Return whether the parameter is layerwise parallel."""
+        """
+        When layerwise_parallel is true in data/hybrid parallel mode, broadcast and gradients communication would not
+        be applied to parameters.
+        """
         return self.param_info.layerwise_parallel
 
     @layerwise_parallel.setter
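With the expanded docstring, the getter now explains the flag's effect rather than restating its name. A short sketch of reading and toggling it through the property pair (the parameter itself is hypothetical):

    import numpy as np
    from mindspore import Tensor, Parameter

    w = Parameter(Tensor(np.ones((4,), np.float32)), name="w")
    print(w.layerwise_parallel)  # False by default, per the Args docstring
    w.layerwise_parallel = True  # goes through the @layerwise_parallel.setter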
|
|
|
@@ -396,7 +399,10 @@ class Parameter(Tensor_):
 
     @property
    def parallel_optimizer(self):
-        """Return whether the parameter requires weight shard for parallel optimizer."""
+        """
+        It is used to filter the weight shard operation in semi auto or auto parallel mode. It works only
+        when enable parallel optimizer in `mindspore.context.set_auto_parallel_context()`.
+        """
         return self.param_info.parallel_optimizer
 
     @parallel_optimizer.setter
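The same treatment for parallel_optimizer: the docstring now states when the flag is consulted. A hedged sketch of the interplay it describes, assuming the `enable_parallel_optimizer` key of `mindspore.context.set_auto_parallel_context()` is what "enable parallel optimizer" refers to:

    import numpy as np
    from mindspore import Tensor, Parameter, context

    # Turn on optimizer weight sharding globally; only then does the
    # per-parameter flag below have any effect.
    context.set_auto_parallel_context(enable_parallel_optimizer=True)

    w = Parameter(Tensor(np.ones((4,), np.float32)), name="w")
    w.parallel_optimizer = False  # exclude this parameter from weight sharding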
|
|
|
@@ -554,8 +560,8 @@ class Parameter(Tensor_):
         if layout is not None:
             if not isinstance(layout, tuple):
                 raise TypeError("The layout should be tuple, but got layout is {}.".format(layout))
-            if len(layout) < 3:
-                raise ValueError("The length of layout must be larger than 2, but got layout is {}.".format(layout))
+            if len(layout) < 6:
+                raise ValueError("The length of layout must be larger than 5, but got layout is {}.".format(layout))
             slice_index = int(_get_slice_index(layout[0], layout[1]))
             init_data_args += (slice_index, layout[2], layout[5])
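The old bound only guaranteed that `layout[2]` exists, yet the line above reads `layout[5]`, so a tuple of length 3 to 5 would pass validation and then fail with an IndexError. Requiring at least six elements closes that gap. A plain-Python illustration (the element names are assumed for readability, not taken from this diff):

    # Hypothetical 6-element layout tuple: (dev_matrix, tensor_map,
    # slice_shape, field_size, uniform_split, opt_shard_group).
    layout = ([2, 4], [1, -1], [16, 32], 0, True, "")
    assert len(layout) >= 6  # the new check; shorter tuples now raise ValueError
    slice_shape, opt_shard_group = layout[2], layout[5]  # indices used by init_data_args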
|
|
|
|
|
|
|
|