
bugfix(side effect): fix adding a wrong control dependency between AllReduce and GetStatus.

tags/v0.2.0-alpha
gong chen, 5 years ago
parent
commit 5d4144de11
2 changed files with 6 additions and 8 deletions
  1. +3 -4  mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py
  2. +3 -4  mindspore/nn/wrap/loss_scale.py

+3 -4  mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py

@@ -370,7 +370,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
             self.reducer_flag = True
-        self.grad_reducer = None
+        self.grad_reducer = F.identity
         if self.reducer_flag:
             mean = context.get_auto_parallel_context("mirror_mean")
             degree = get_group_size()
@@ -428,9 +428,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
         grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
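The change is the same in both files: grad_reducer defaults to the identity function instead of None, and the call site drops the reducer_flag branch, so the gradient reduction sits on a single, unconditional execution path in every parallel mode. A minimal sketch of that pattern in plain Python (identity, make_reducer and train_step are illustrative names, not MindSpore APIs):

def identity(grads):
    """Pass gradients through unchanged (stand-in for F.identity)."""
    return grads

def make_reducer(reducer_flag, distributed_reducer=None):
    """Return the real reducer in data/hybrid parallel mode, identity otherwise."""
    if reducer_flag and distributed_reducer is not None:
        return distributed_reducer
    return identity

def train_step(grads, reducer):
    # apply grad reducer on grads -- always called, no conditional branch,
    # so standalone and distributed modes share one code path
    return reducer(grads)

# Usage: without a distributed reducer, gradients pass through untouched.
print(train_step([1.0, 2.0], make_reducer(reducer_flag=False)))  # [1.0, 2.0]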


+3 -4  mindspore/nn/wrap/loss_scale.py

@@ -220,7 +220,7 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
         self.parallel_mode = _get_parallel_mode()
-        self.grad_reducer = None
+        self.grad_reducer = F.identity
         self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
         if self.reducer_flag:
             mean = _get_mirror_mean()
@@ -250,9 +250,8 @@ class TrainOneStepWithLossScaleCell(Cell):
             scaling_sens = sens
         grads = self.grad(self.network, weights)(data, label, F.cast(scaling_sens, F.dtype(loss)))
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
         # get the overflow buffer
         if not self.gpu_target:
             self.get_status(init)
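As a rough intuition for the commit message, a toy ordering sketch (plain Python; all_reduce, get_status and step are stand-ins, not MindSpore operators): with the reducer invoked unconditionally, the overflow-status read keeps the same position relative to the gradient AllReduce in every mode. The actual control-dependency handling in the graph compiler is more involved than this sketch.

events = []

def all_reduce(grads):
    events.append("AllReduce")
    return grads

def get_status():
    events.append("GetStatus")
    return 0  # 0 stands for "no overflow" in this sketch

def step(grads, reducer):
    grads = reducer(grads)   # identity or all_reduce, invoked unconditionally
    overflow = get_status()  # status read has a fixed position after the reduction
    return grads, overflow

step([1.0], all_reduce)   # distributed mode: real reduction
step([1.0], lambda g: g)  # standalone mode: identity reducer
print(events)  # ['AllReduce', 'GetStatus', 'GetStatus']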

