Merge pull request !6159 from wangnan39/optim_train_one_step_cell
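Summary, as read from the diff below: TrainOneStepCell now defaults self.grad_reducer to F.identity and applies it unconditionally in construct, building a DistributedGradReducer over self.weights only in data-parallel or hybrid-parallel mode. TrainOneStepWithLossScaleCell and the model-side BertTrainOneStepCell wrappers then inherit from TrainOneStepCell and drop their duplicated parallel-mode and reducer setup.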
@@ -185,23 +185,21 @@ class TrainOneStepCell(Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
         self.reducer_flag = False
-        self.grad_reducer = None
-        parallel_mode = _get_parallel_mode()
-        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
+        self.grad_reducer = F.identity
+        self.parallel_mode = _get_parallel_mode()
+        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
             mean = _get_gradients_mean()
             degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)

     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         grads = self.grad(self.network, weights)(*inputs, sens)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))
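The core of this hunk is a default-to-identity reducer: self.grad_reducer starts as F.identity and is replaced by a DistributedGradReducer over self.weights only in data/hybrid parallel mode, so construct can call it unconditionally instead of branching on reducer_flag. A minimal, framework-agnostic sketch of that pattern (all names are illustrative, not the MindSpore API):

    # Sketch of the "identity-default reducer" pattern from the hunk above.
    # identity() stands in for F.identity and all_reduce for DistributedGradReducer;
    # the real cell operates on MindSpore tensors and cells.

    def identity(grads):
        """No-op reducer: returns the gradients unchanged."""
        return grads

    class OneStepCell:
        def __init__(self, compute_grads, optimizer, distributed=False, all_reduce=identity):
            self.compute_grads = compute_grads        # callable: inputs -> gradients
            self.optimizer = optimizer                # callable: gradients -> update
            # Default to the no-op; swap in the real reducer only when distributed.
            self.grad_reducer = all_reduce if distributed else identity

        def step(self, *inputs):
            grads = self.compute_grads(*inputs)
            grads = self.grad_reducer(grads)          # unconditional: identity or AllReduce
            return self.optimizer(grads)

The design choice is that the stand-alone path pays only a no-op call, while the distributed path no longer needs a flag check inside the compiled construct.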
@@ -14,9 +14,8 @@
 # ============================================================================
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from .cell_wrapper import TrainOneStepCell
 from ..cell import Cell
 from ...common import Tensor, RowTensor
 from ...common.parameter import Parameter
@@ -163,7 +162,7 @@ class FixedLossScaleUpdateCell(Cell):
         return overflow


-class TrainOneStepWithLossScaleCell(Cell):
+class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     r"""
     Network training with loss scaling.
@@ -203,15 +202,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
         >>> output = train_network(inputs, label, scaling_sens)
     """

     def __init__(self, network, optimizer, scale_sense):
-        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.network.add_flags(defer_inline=True)
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
         self.hyper_map = C.HyperMap()
         if context.get_context("device_target") == "GPU":
             self.gpu_target = True
@@ -228,13 +220,6 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.less_equal = LessEqual()
         self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
-        self.parallel_mode = _get_parallel_mode()
-        self.grad_reducer = F.identity
-        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
-        if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
         self.loss_scaling_manager = None
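With the import of TrainOneStepCell added and the DistributedGradReducer/parallel-utils imports dropped, TrainOneStepWithLossScaleCell inherits the reducer setup from the base class instead of rebuilding it. A minimal sketch of that shape, with purely illustrative names (the real classes carry much more state, such as the overflow-detection operators kept above):

    # Sketch of the inheritance refactor; illustrative names, not MindSpore API.

    class OneStepBase:
        def __init__(self, network, optimizer, sens=1.0):
            self.network = network
            self.optimizer = optimizer
            self.sens = sens
            self.grad_reducer = lambda grads: grads   # identity unless a distributed reducer is swapped in

    class OneStepWithLossScale(OneStepBase):
        def __init__(self, network, optimizer, scale_sense):
            # The duplicated setup collapses into one super() call; sens=None because
            # the loss-scale value is supplied separately via scale_sense.
            super().__init__(network, optimizer, sens=None)
            self.scale_sense = scale_sense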
@@ -271,23 +271,7 @@ class BertTrainOneStepCell(nn.Cell):
         sens (Number): The adjust parameter. Default: 1.0.
     """
     def __init__(self, network, optimizer, sens=1.0):
-        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.sens = sens
-        self.reducer_flag = False
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
@@ -322,9 +306,7 @@ class BertTrainOneStepCell(nn.Cell):
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         succ = self.optimizer(grads)
         return F.depend(loss, succ)
@@ -289,23 +289,7 @@ class BertTrainOneStepCell(nn.Cell):
     """
     def __init__(self, network, optimizer, sens=1.0):
-        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.sens = sens
-        self.reducer_flag = False
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
@@ -340,9 +324,7 @@ class BertTrainOneStepCell(nn.Cell):
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         succ = self.optimizer(grads)
         return F.depend(loss, succ)
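The two model-side hunks apply the same simplification to BertTrainOneStepCell in two separate model scripts: the constructor collapses to a single super().__init__ call, and construct applies self.grad_reducer unconditionally after gradient clipping and before the optimizer. A hedged, framework-agnostic sketch of that ordering (all names are illustrative stand-ins for the operators in the diff):

    # Illustrative ordering of the Bert step after this change: clip, then reduce
    # (a no-op in stand-alone mode), then apply the optimizer.

    def clip(grads, limit=1.0):
        """Element-wise clip, standing in for clip_grad with GRADIENT_CLIP_VALUE."""
        return [max(-limit, min(limit, g)) for g in grads]

    def bert_step(compute_grads, grad_reducer, optimizer, *inputs):
        grads = compute_grads(*inputs)      # backward pass scaled by sens
        grads = clip(grads)                 # gradient clipping, as in the Bert cells
        grads = grad_reducer(grads)         # unconditional; identity when not distributed
        return optimizer(grads)             # the optimizer returns the success flag in the diff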