|
|
|
@@ -268,47 +268,45 @@ class DistributedGradReducer(Cell): |
|
|
|
>>> context.reset_auto_parallel_context() |
|
|
|
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) |
|
|
|
>>> |
|
|
|
>>> |
|
|
|
>>> class TrainingWrapper(nn.Cell): |
|
|
|
>>> def __init__(self, network, optimizer, sens=1.0): |
|
|
|
>>> super(TrainingWrapper, self).__init__(auto_prefix=False) |
|
|
|
>>> self.network = network |
|
|
|
>>> self.network.add_flags(defer_inline=True) |
|
|
|
>>> self.weights = optimizer.parameters |
|
|
|
>>> self.optimizer = optimizer |
|
|
|
>>> self.grad = C.GradOperation(get_by_list=True, sens_param=True) |
|
|
|
>>> self.sens = sens |
|
|
|
>>> self.reducer_flag = False |
|
|
|
>>> self.grad_reducer = None |
|
|
|
>>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode") |
|
|
|
>>> if self.parallel_mode in [ParallelMode.DATA_PARALLEL, |
|
|
|
>>> ParallelMode.HYBRID_PARALLEL]: |
|
|
|
>>> self.reducer_flag = True |
|
|
|
>>> if self.reducer_flag: |
|
|
|
>>> mean = _get_gradients_mean() |
|
|
|
>>> degree = _get_device_num() |
|
|
|
>>> self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) |
|
|
|
>>> |
|
|
|
>>> def construct(self, *args): |
|
|
|
>>> weights = self.weights |
|
|
|
>>> loss = self.network(*args) |
|
|
|
>>> sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) |
|
|
|
>>> grads = self.grad(self.network, weights)(*args, sens) |
|
|
|
>>> if self.reducer_flag: |
|
|
|
>>> # apply grad reducer on grads |
|
|
|
>>> grads = self.grad_reducer(grads) |
|
|
|
>>> return F.depend(loss, self.optimizer(grads)) |
|
|
|
... def __init__(self, network, optimizer, sens=1.0): |
|
|
|
... super(TrainingWrapper, self).__init__(auto_prefix=False) |
|
|
|
... self.network = network |
|
|
|
... self.network.add_flags(defer_inline=True) |
|
|
|
... self.weights = optimizer.parameters |
|
|
|
... self.optimizer = optimizer |
|
|
|
... self.grad = C.GradOperation(get_by_list=True, sens_param=True) |
|
|
|
... self.sens = sens |
|
|
|
... self.reducer_flag = False |
|
|
|
... self.grad_reducer = None |
|
|
|
... self.parallel_mode = context.get_auto_parallel_context("parallel_mode") |
|
|
|
... if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: |
|
|
|
... self.reducer_flag = True |
|
|
|
... if self.reducer_flag: |
|
|
|
... mean = _get_gradients_mean() |
|
|
|
... degree = _get_device_num() |
|
|
|
... self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) |
|
|
|
... |
|
|
|
... def construct(self, *args): |
|
|
|
... weights = self.weights |
|
|
|
... loss = self.network(*args) |
|
|
|
... sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) |
|
|
|
... grads = self.grad(self.network, weights)(*args, sens) |
|
|
|
... if self.reducer_flag: |
|
|
|
... # apply grad reducer on grads |
|
|
|
... grads = self.grad_reducer(grads) |
|
|
|
... return F.depend(loss, self.optimizer(grads)) |
|
|
|
>>> |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
>>> def __init__(self, in_features, out_features): |
|
|
|
>>> super(Net, self).__init__() |
|
|
|
>>> self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), |
|
|
|
>>> name='weight') |
|
|
|
>>> self.matmul = P.MatMul() |
|
|
|
>>> |
|
|
|
>>> def construct(self, x): |
|
|
|
>>> output = self.matmul(x, self.weight) |
|
|
|
>>> return output |
|
|
|
... def __init__(self, in_features, out_features): |
|
|
|
... super(Net, self).__init__() |
|
|
|
... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), |
|
|
|
... name='weight') |
|
|
|
... self.matmul = P.MatMul() |
|
|
|
... |
|
|
|
... def construct(self, x): |
|
|
|
... output = self.matmul(x, self.weight) |
|
|
|
... return output |
|
|
|
>>> |
|
|
|
>>> size, in_features, out_features = 16, 16, 10 |
|
|
|
>>> network = Net(in_features, out_features) |
|
|
|
|