|
|
|
@@ -220,7 +220,7 @@ class TrainOneStepWithLossScaleCell(Cell): |
|
|
|
self.depend_parameter_use = ControlDepend(depend_mode=1) |
|
|
|
self.allreduce = P.AllReduce() |
|
|
|
self.parallel_mode = _get_parallel_mode() |
|
|
|
self.grad_reducer = None |
|
|
|
self.grad_reducer = F.identity |
|
|
|
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL] |
|
|
|
if self.reducer_flag: |
|
|
|
mean = _get_mirror_mean() |
|
|
|
@@ -250,9 +250,8 @@ class TrainOneStepWithLossScaleCell(Cell): |
|
|
|
scaling_sens = sens |
|
|
|
grads = self.grad(self.network, weights)(data, label, F.cast(scaling_sens, F.dtype(loss))) |
|
|
|
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads) |
|
|
|
if self.reducer_flag: |
|
|
|
# apply grad reducer on grads |
|
|
|
grads = self.grad_reducer(grads) |
|
|
|
# apply grad reducer on grads |
|
|
|
grads = self.grad_reducer(grads) |
|
|
|
# get the overflow buffer |
|
|
|
if not self.gpu_target: |
|
|
|
self.get_status(init) |
|
|
|
|