@@ -254,6 +254,8 @@ class DistributedGradReducer(Cell):
         >>> from mindspore.context import ParallelMode
         >>> from mindspore import nn
         >>> from mindspore import ParameterTuple
+        >>> from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
+        >>>                                        _get_parallel_mode)
         >>>
         >>> device_id = int(os.environ["DEVICE_ID"])
         >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
@@ -279,11 +281,8 @@ class DistributedGradReducer(Cell):
         >>>                                   ParallelMode.HYBRID_PARALLEL]:
         >>>             self.reducer_flag = True
         >>>         if self.reducer_flag:
-        >>>             mean = context.get_auto_parallel_context("gradients_mean")
-        >>>             if mean.get_device_num_is_set():
-        >>>                 degree = context.get_auto_parallel_context("device_num")
-        >>>             else:
-        >>>                 degree = get_group_size()
+        >>>             mean = _get_gradients_mean()
+        >>>             degree = _get_device_num()
         >>>             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
         >>>
         >>>     def construct(self, *args):