Browse Source

!19097 update parallel api

Merge pull request !19097 from gziyan/update_parallel_api
tags/v1.3.0
i-robot Gitee 4 years ago
parent
commit
195a6dd09d
2 changed files with 15 additions and 5 deletions
  1. +11
    -5
      mindspore/common/parameter.py
  2. +4
    -0
      mindspore/nn/cell.py

+ 11
- 5
mindspore/common/parameter.py View File

@@ -79,7 +79,7 @@ class Parameter(Tensor_):
default_input (Union[Tensor, int, float, numpy.ndarray, list]): Parameter data, to be set initialized.
name (str): Name of the child parameter. Default: None.
requires_grad (bool): True if the parameter requires gradient. Default: True.
layerwise_parallel (bool): When layerwise_parallel is true in data parallel mode,
layerwise_parallel (bool): When layerwise_parallel is true in data/hybrid parallel mode,
broadcast and gradient communication would not be applied to parameters. Default: False.
parallel_optimizer (bool): It is used to filter the weight shard operation in semi auto or auto parallel
mode. It works only when the parallel optimizer is enabled in `mindspore.context.set_auto_parallel_context()`.
@@ -385,7 +385,10 @@ class Parameter(Tensor_):

@property
def layerwise_parallel(self):
"""Return whether the parameter is layerwise parallel."""
"""
When layerwise_parallel is true in data/hybrid parallel mode, broadcast and gradient communication would not
be applied to parameters.
"""
return self.param_info.layerwise_parallel

@layerwise_parallel.setter
@@ -396,7 +399,10 @@ class Parameter(Tensor_):

@property
def parallel_optimizer(self):
"""Return whether the parameter requires weight shard for parallel optimizer."""
"""
It is used to filter the weight shard operation in semi auto or auto parallel mode. It works only
when the parallel optimizer is enabled in `mindspore.context.set_auto_parallel_context()`.
"""
return self.param_info.parallel_optimizer

@parallel_optimizer.setter
@@ -554,8 +560,8 @@ class Parameter(Tensor_):
if layout is not None:
if not isinstance(layout, tuple):
raise TypeError("The layout should be tuple, but got layout is {}.".format(layout))
if len(layout) < 3:
raise ValueError("The length of layout must be larger than 2, but got layout is {}.".format(layout))
if len(layout) < 6:
raise ValueError("The length of layout must be larger than 5, but got layout is {}.".format(layout))
slice_index = int(_get_slice_index(layout[0], layout[1]))
init_data_args += (slice_index, layout[2], layout[5])



+ 4
- 0
mindspore/nn/cell.py View File

@@ -208,6 +208,10 @@ class Cell(Cell_):

@property
def parameter_layout_dict(self):
"""
`parameter_layout_dict` represents the tensor layout of a parameter, which is inferred by shard strategy and
distributed operator information.
"""
return self._parameter_layout_dict

@property


Loading…
Cancel
Save