|
|
|
@@ -339,9 +339,9 @@ class TrainOneStepCell(Cell): |
|
|
|
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): |
|
|
|
self.reducer_flag = True |
|
|
|
if self.reducer_flag: |
|
|
|
mean = _get_gradients_mean() |
|
|
|
degree = _get_device_num() |
|
|
|
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree) |
|
|
|
self.mean = _get_gradients_mean() |
|
|
|
self.degree = _get_device_num() |
|
|
|
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree) |
|
|
|
|
|
|
|
def construct(self, *inputs): |
|
|
|
weights = self.weights |
|
|
|
|