Merge pull request !6159 from wangnan39/optim_train_one_step_cell (tag: v1.0.0)
@@ -185,23 +185,21 @@ class TrainOneStepCell(Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
         self.reducer_flag = False
-        self.grad_reducer = None
-        parallel_mode = _get_parallel_mode()
-        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
+        self.grad_reducer = F.identity
+        self.parallel_mode = _get_parallel_mode()
+        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
             mean = _get_gradients_mean()
             degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)

     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         grads = self.grad(self.network, weights)(*inputs, sens)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))
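The net effect of this hunk: `grad_reducer` now defaults to `F.identity`, so `construct` can apply it unconditionally and the `if self.reducer_flag:` branch disappears. Below is a minimal plain-Python sketch of that "identity reducer" pattern; the class and helper names are made up for illustration and are not MindSpore API.

# Illustrative sketch of the identity-reducer pattern above; TrainStep and
# _allreduce_mean are hypothetical stand-ins, not MindSpore classes.
class TrainStep:
    def __init__(self, distributed=False):
        # no-op by default, so the call site needs no branch
        self.grad_reducer = lambda grads: grads
        if distributed:
            # in the real cell this slot holds DistributedGradReducer(self.weights, mean, degree)
            self.grad_reducer = self._allreduce_mean

    def _allreduce_mean(self, grads):
        # placeholder: a real reducer would average gradients across devices
        return grads

    def step(self, grads):
        # unconditional call: identity in stand-alone mode, all-reduce otherwise
        return self.grad_reducer(grads)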
@@ -14,9 +14,8 @@
 # ============================================================================
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from .cell_wrapper import TrainOneStepCell
 from ..cell import Cell
 from ...common import Tensor, RowTensor
 from ...common.parameter import Parameter
@@ -163,7 +162,7 @@ class FixedLossScaleUpdateCell(Cell):
         return overflow


-class TrainOneStepWithLossScaleCell(Cell):
+class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     r"""
     Network training with loss scaling.
@@ -203,15 +202,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
         >>> output = train_network(inputs, label, scaling_sens)
     """

     def __init__(self, network, optimizer, scale_sense):
-        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.network.add_flags(defer_inline=True)
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
         self.hyper_map = C.HyperMap()
         if context.get_context("device_target") == "GPU":
             self.gpu_target = True
@@ -228,13 +220,6 @@
         self.less_equal = LessEqual()
         self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
-        self.parallel_mode = _get_parallel_mode()
-        self.grad_reducer = F.identity
-        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
-        if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
         self.loss_scaling_manager = None
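With the duplicated block above deleted, `TrainOneStepWithLossScaleCell` relies on the parent constructor for `self.weights`, `self.grad`, `self.parallel_mode`, and `self.grad_reducer`. A simplified sketch of that delegation, using toy classes rather than the full MindSpore cells:

# Simplified delegation sketch; these are toy classes, not the MindSpore cells.
class TrainOneStep:
    def __init__(self, network, optimizer, sens=1.0):
        self.network = network
        self.optimizer = optimizer
        self.sens = sens
        self.grad_reducer = lambda grads: grads  # replaced by a distributed reducer when needed

class TrainOneStepWithLossScale(TrainOneStep):
    def __init__(self, network, optimizer, scale_sense):
        # one super() call replaces the network/optimizer/reducer setup the subclass used to repeat
        super().__init__(network, optimizer, sens=None)
        self.scale_sense = scale_sense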
@@ -271,23 +271,7 @@ class BertTrainOneStepCell(nn.Cell):
         sens (Number): The adjust parameter. Default: 1.0.
     """
     def __init__(self, network, optimizer, sens=1.0):
-        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.sens = sens
-        self.reducer_flag = False
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
@@ -322,9 +306,7 @@ class BertTrainOneStepCell(nn.Cell):
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         succ = self.optimizer(grads)
         return F.depend(loss, succ)
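In the Bert cell the ordering is unchanged: gradients are clipped through `hyper_map` first and then passed through the (possibly identity) reducer. A rough plain-Python stand-in for that clip-then-reduce step; `clip_grad` and `train_step` here are illustrative helpers, not the MindSpore ops, and the clip bound is arbitrary.

# Rough stand-in for the clip-then-reduce order above; illustrative only.
def clip_grad(clip_value, grad):
    # value clipping, standing in for clip_grad(GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, grad)
    return max(-clip_value, min(clip_value, grad))

def train_step(grads, grad_reducer, clip_value=1.0):
    # clip every gradient first, then apply the reducer (identity when not distributed)
    clipped = tuple(clip_grad(clip_value, g) for g in grads)
    return grad_reducer(clipped)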
@@ -289,23 +289,7 @@ class BertTrainOneStepCell(nn.Cell):
     """
     def __init__(self, network, optimizer, sens=1.0):
-        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.sens = sens
-        self.reducer_flag = False
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
@@ -340,9 +324,7 @@ class BertTrainOneStepCell(nn.Cell):
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         succ = self.optimizer(grads)
         return F.depend(loss, succ)