| @@ -88,7 +88,7 @@ class WithGradCell(Cell): | |||
| Run in PyNative mode. | |||
| Args: | |||
| network (Cell): The target network to wrap. | |||
| network (Cell): The target network to wrap. The network only supports single output. | |||
| loss_fn (Cell): Primitive loss function used to compute gradients. Default: None. | |||
| sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitive for backpropagation, the type and shape | |||
| should be same as the `network` output. If None, we will fill one to a same type shape of | |||
| @@ -143,7 +143,7 @@ class TrainOneStepCell(Cell): | |||
| parallel modes are available for training. | |||
| Args: | |||
| network (Cell): The training network. | |||
| network (Cell): The training network. The network only supports single output. | |||
| optimizer (Cell): Optimizer for updating the weights. | |||
| sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. | |||
| @@ -49,6 +49,7 @@ grad_overflow = P.FloatStatus() | |||
| def _tensor_grad_overflow(grad): | |||
| return grad_overflow(grad) | |||
| class DynamicLossScaleUpdateCell(Cell): | |||
| r""" | |||
| Dynamic Loss scale update cell. | |||
| @@ -168,27 +169,26 @@ class TrainOneStepWithLossScaleCell(Cell): | |||
| This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update | |||
| Cell as args. The loss scale value can be updated in both host side or device side. The | |||
| TrainOneStepWithLossScaleCell will be compiled to be graph which takes `data`, `label`, `sens` as input | |||
| data. The `sens` is acting as loss scaling value. If you want to update it on host side, the value should | |||
| be provided. If `sens` is not given, the loss scale update logic should be provied by `scale_update_cell`. | |||
| If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored. | |||
| TrainOneStepWithLossScaleCell will be compiled to be graph which takes `*inputs` as input data. | |||
| The Tensor type of `scale_sense` is acting as loss scaling value. If you want to update it on host side, | |||
| the value should be provided. If the Tensor type of `scale_sense` is not given, the loss scale update logic | |||
| should be provied by Cell type of `scale_sense`. If Cell type of `scale_sense` is not None and Tensor type | |||
| of `scale_sense` is provided, the Cell type of `scale_sense` will be ignored. | |||
| Args: | |||
| network (Cell): The training network. | |||
| network (Cell): The training network. The network only supports single output. | |||
| optimizer (Cell): Optimizer for updating the weights. | |||
| scale_update_cell(Cell): The loss scaling update logic cell. Default: None. | |||
| scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value | |||
| is Tensor type, Tensor with shape :math:`()`. Default: None. | |||
| Inputs: | |||
| - **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | |||
| - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | |||
| - **scaling_sens** (Tensor) - Tensor of shape :math:`()`. | |||
| - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. | |||
| Outputs: | |||
| Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value. | |||
| - **loss** (Tensor) - Tensor with shape :math:`()`. | |||
| - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool. | |||
| - **loss_scale** (Tensor) - Tensor with shape :math:`()`. | |||
| Examples: | |||
| >>> net_with_loss = Net() | |||
| @@ -203,7 +203,7 @@ class TrainOneStepWithLossScaleCell(Cell): | |||
| >>> output = train_network(inputs, label, scaling_sens) | |||
| """ | |||
| def __init__(self, network, optimizer, scale_update_cell=None): | |||
| def __init__(self, network, optimizer, scale_sense=None): | |||
| super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.network.set_grad() | |||
| @@ -236,29 +236,29 @@ class TrainOneStepWithLossScaleCell(Cell): | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE | |||
| self.loss_scale = None | |||
| self.loss_scaling_manager = scale_update_cell | |||
| if scale_update_cell: | |||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | |||
| name="loss_scale") | |||
| self.scale_sense = None | |||
| self.loss_scaling_manager = None | |||
| if isinstance(scale_sense, Cell): | |||
| self.loss_scaling_manager = scale_sense | |||
| self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), | |||
| name="scale_sense") | |||
| if isinstance(scale_sense, Tensor): | |||
| self.scale_sense = Parameter(scale_sense, name='scale_sense') | |||
| @C.add_flags(has_effect=True) | |||
| def construct(self, data, label, sens=None): | |||
| def construct(self, *inputs): | |||
| weights = self.weights | |||
| loss = self.network(data, label) | |||
| loss = self.network(*inputs) | |||
| init = False | |||
| if not self.gpu_target: | |||
| # init overflow buffer | |||
| init = self.alloc_status() | |||
| # clear overflow buffer | |||
| self.clear_status(init) | |||
| if sens is None: | |||
| scaling_sens = self.loss_scale | |||
| else: | |||
| scaling_sens = sens | |||
| scaling_sens = self.scale_sense | |||
| scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) | |||
| grads = self.grad(self.network, weights)(data, label, scaling_sens_filled) | |||
| grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled) | |||
| grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads) | |||
| # apply grad reducer on grads | |||
| grads = self.grad_reducer(grads) | |||
| @@ -279,8 +279,8 @@ class TrainOneStepWithLossScaleCell(Cell): | |||
| else: | |||
| cond = self.less_equal(self.base, flag_sum) | |||
| overflow = cond | |||
| if sens is None: | |||
| overflow = self.loss_scaling_manager(self.loss_scale, cond) | |||
| if self.loss_scaling_manager is not None: | |||
| overflow = self.loss_scaling_manager(self.scale_sense, cond) | |||
| # if there is no overflow, do optimize | |||
| if overflow: | |||
| opt = False | |||
| @@ -288,3 +288,9 @@ class TrainOneStepWithLossScaleCell(Cell): | |||
| opt = self.optimizer(grads) | |||
| ret = (loss, cond, scaling_sens) | |||
| return F.depend(ret, opt) | |||
| def set_sense_scale(self, sens): | |||
| """If the user has set the sens in the training process and wants to reassign the value, he can call | |||
| this function again to make modification, and sens needs to be of type Tensor.""" | |||
| if self.scale_sense and isinstance(sens, Tensor): | |||
| self.self.scale_sense.set_data(sens) | |||
| @@ -182,7 +182,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): | |||
| "are supported in current version. If you use `O2` option, please" | |||
| "use `loss_scale_manager=None` or `FixedLossScaleManager`") | |||
| network = nn.TrainOneStepWithLossScaleCell(network, optimizer, | |||
| scale_update_cell=update_cell).set_train() | |||
| scale_sense=update_cell).set_train() | |||
| return network | |||
| network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train() | |||
| return network | |||
| @@ -34,7 +34,6 @@ from ..nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| from ..context import ParallelMode | |||
| from ..parallel._utils import _need_to_full, _to_full_tensor | |||
| from ..parallel._cost_model_context import _set_multi_subgraphs | |||
| from ..common import dtype as mstype | |||
| from .dataset_helper import DatasetHelper, connect_network_with_dataset | |||
| from . import amp | |||
| @@ -489,11 +488,6 @@ class Model: | |||
| "return two elements, but got {}".format(len_element)) | |||
| cb_params.cur_step_num += 1 | |||
| overflow = False | |||
| if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): | |||
| scaling_sens = self._get_scaling_sens() | |||
| next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),) | |||
| cb_params.train_dataset_element = next_element | |||
| list_callback.step_begin(run_context) | |||
| outputs = self._train_network(*next_element) | |||
| @@ -148,7 +148,6 @@ class MSELoss(nn.Cell): | |||
| def test_loss_scale_fp16_lr_overflow(): | |||
| inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||
| label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) | |||
| lr = Tensor(np.ones([1], np.float32) * 0.1) | |||
| net = NetFP16(16, 16) | |||
| net.set_train() | |||
| @@ -157,9 +156,11 @@ def test_loss_scale_fp16_lr_overflow(): | |||
| optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) | |||
| output_1 = train_network(inputs, label, scaling_sens) | |||
| output_2 = train_network(inputs, label, scaling_sens) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, | |||
| scale_sense=Tensor(np.full((1), np.finfo(np.float32).max), | |||
| dtype=mstype.float32)) | |||
| output_1 = train_network(inputs, label) | |||
| output_2 = train_network(inputs, label) | |||
| assert output_1[0].asnumpy() == output_2[0].asnumpy() | |||
| assert output_1[1].asnumpy() == output_2[1].asnumpy() == True | |||
| @@ -188,16 +189,17 @@ def test_loss_scale_fp16_model_train_overflow(): | |||
| def test_loss_scale_fp16_opt_rmsprop_overflow(): | |||
| inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||
| label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full(1, np.finfo(np.float32).max), dtype=mstype.float32) | |||
| net = NetFP16(16, 16) | |||
| net.set_train() | |||
| loss = MSELoss() | |||
| optimizer = RMSProp(net.trainable_params(), learning_rate=0.1) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) | |||
| output_1 = train_network(inputs, label, scaling_sens) | |||
| output_2 = train_network(inputs, label, scaling_sens) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, | |||
| scale_sense=Tensor(np.full(1, np.finfo(np.float32).max), | |||
| dtype=mstype.float32)) | |||
| output_1 = train_network(inputs, label) | |||
| output_2 = train_network(inputs, label) | |||
| assert output_1[0].asnumpy() == output_2[0].asnumpy() | |||
| assert output_1[1].asnumpy() == output_2[1].asnumpy() == True | |||
| @@ -208,7 +210,6 @@ def test_loss_scale_fp16_opt_rmsprop_overflow(): | |||
| def test_loss_scale_fp16_overflow(): | |||
| inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||
| label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) | |||
| net = NetFP16(16, 16) | |||
| net.set_train() | |||
| @@ -216,8 +217,10 @@ def test_loss_scale_fp16_overflow(): | |||
| optimizer = Lamb(net.trainable_params(), learning_rate=0.01) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| net_with_loss.set_grad() | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) | |||
| output_1 = train_network(inputs, label, scaling_sens) | |||
| output_2 = train_network(inputs, label, scaling_sens) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, | |||
| scale_sense=Tensor(np.full((1), np.finfo(np.float32).max), | |||
| dtype=mstype.float32)) | |||
| output_1 = train_network(inputs, label) | |||
| output_2 = train_network(inputs, label) | |||
| assert output_1[0].asnumpy() == output_2[0].asnumpy() | |||
| assert output_1[1].asnumpy() == output_2[1].asnumpy() == True | |||
| @@ -177,7 +177,7 @@ def test_compile_grad_error(): | |||
| net_with_loss = WithLossCell(net, loss) | |||
| scale_manager = DynamicLossScaleManager() | |||
| update_cell = scale_manager.get_update_cell() | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell) | |||
| train_network.set_train() | |||
| with pytest.raises(TypeError) as e: | |||
| train_network(inputs, label) | |||
| @@ -100,70 +100,71 @@ class MSELoss(nn.Cell): | |||
| def test_momentum_compile(): | |||
| inputs = Tensor(np.ones([15, 1]).astype(np.float32)) | |||
| label = Tensor(np.zeros([15, 1]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32) | |||
| net = Net(1, 1) | |||
| loss = MSELoss() | |||
| optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, | |||
| scale_sense=Tensor(np.full((1), 1.0), dtype=mstype.float32)) | |||
| train_network.set_train() | |||
| output = train_network(inputs, label, scaling_sens) | |||
| output = train_network(inputs, label) | |||
| print("the result is ", output) | |||
| def test_compile_fp16_not_overflow(): | |||
| inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||
| label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32) | |||
| net = NetFP16(16, 16) | |||
| loss = MSELoss() | |||
| optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, | |||
| scale_sense=Tensor(np.full((1), 1.0), dtype=mstype.float32)) | |||
| train_network.set_train() | |||
| output = train_network(inputs, label, scaling_sens) | |||
| output = train_network(inputs, label) | |||
| print("the result is ", output) | |||
| def test_compile_fp16_lr_overflow(): | |||
| inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||
| label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) | |||
| lr = Tensor(np.ones([1], np.float32) * 0.1) | |||
| net = NetFP16(16, 16) | |||
| loss = MSELoss() | |||
| optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, | |||
| scale_sense=Tensor(np.full((1), np.finfo(np.float32).max), | |||
| dtype=mstype.float32)) | |||
| train_network.set_train() | |||
| output = train_network(inputs, label, scaling_sens) | |||
| output = train_network(inputs, label) | |||
| print("the result is ", output) | |||
| def test_compile_fp16_overflow(): | |||
| inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||
| label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) | |||
| net = NetFP16(16, 16) | |||
| loss = MSELoss() | |||
| optimizer = Lamb(net.trainable_params(), learning_rate=0.01) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, | |||
| scale_sense=Tensor(np.full((1), np.finfo(np.float32).max), | |||
| dtype=mstype.float32)) | |||
| train_network.set_train() | |||
| output = train_network(inputs, label, scaling_sens) | |||
| output = train_network(inputs, label) | |||
| print("the result is ", output) | |||
| def test_compile_fp16_lr_overflow_with_lossscale_update(): | |||
| inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||
| label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) | |||
| lr = Tensor(np.ones([1], np.float32) * 0.1) | |||
| net = NetFP16(16, 16) | |||
| loss = MSELoss() | |||
| @@ -172,9 +173,9 @@ def test_compile_fp16_lr_overflow_with_lossscale_update(): | |||
| net_with_loss = WithLossCell(net, loss) | |||
| scale_manager = DynamicLossScaleManager() | |||
| manager = scale_manager.get_update_cell() | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager) | |||
| train_network.set_train() | |||
| output = train_network(inputs, label, scaling_sens) | |||
| output = train_network(inputs, label) | |||
| print("the result is ", output) | |||
| @@ -209,7 +210,6 @@ def test_compile_f16_model_train_fixed(): | |||
| def test_compile_fp16_lr_overflow_fixed_feed(): | |||
| inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||
| label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) | |||
| lr = Tensor(np.ones([1], np.float32) * 0.1) | |||
| net = NetFP16(16, 16) | |||
| loss = MSELoss() | |||
| @@ -218,16 +218,15 @@ def test_compile_fp16_lr_overflow_fixed_feed(): | |||
| net_with_loss = WithLossCell(net, loss) | |||
| scale_manager = FixedLossScaleManager() | |||
| update_cell = scale_manager.get_update_cell() | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell) | |||
| train_network.set_train() | |||
| output = train_network(inputs, label, scaling_sens) | |||
| output = train_network(inputs, label) | |||
| print("the result is ", output) | |||
| def test_compile_fp16_lr_overflow_dynamic_feed(): | |||
| inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||
| label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) | |||
| lr = Tensor(np.ones([1], np.float32) * 0.1) | |||
| net = NetFP16(16, 16) | |||
| loss = MSELoss() | |||
| @@ -236,9 +235,9 @@ def test_compile_fp16_lr_overflow_dynamic_feed(): | |||
| net_with_loss = WithLossCell(net, loss) | |||
| scale_manager = DynamicLossScaleManager() | |||
| update_cell = scale_manager.get_update_cell() | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell) | |||
| train_network.set_train() | |||
| output = train_network(inputs, label, scaling_sens) | |||
| output = train_network(inputs, label) | |||
| print("the result is ", output) | |||
| @@ -253,7 +252,7 @@ def test_compile_fp16_lr_overflow_fixed_graph(): | |||
| net_with_loss = WithLossCell(net, loss) | |||
| scale_manager = FixedLossScaleManager(drop_overflow_update=True) | |||
| update_cell = scale_manager.get_update_cell() | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell) | |||
| train_network.set_train() | |||
| output = train_network(inputs, label) | |||
| print("the result is ", output) | |||
| @@ -270,7 +269,7 @@ def test_compile_fp16_lr_overflow_dynamic_graph(): | |||
| net_with_loss = WithLossCell(net, loss) | |||
| scale_manager = DynamicLossScaleManager() | |||
| update_cell = scale_manager.get_update_cell() | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell) | |||
| train_network.set_train() | |||
| output = train_network(inputs, label) | |||
| print("the result is ", output) | |||
| @@ -279,7 +278,6 @@ def test_compile_fp16_lr_overflow_dynamic_graph(): | |||
| def adam_compile(loss_scale=1.0): | |||
| inputs = Tensor(np.ones([15, 1]).astype(np.float32)) | |||
| label = Tensor(np.zeros([15, 1]).astype(np.float32)) | |||
| scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32) | |||
| net = Net(1, 1) | |||
| loss = MSELoss() | |||
| @@ -287,14 +285,17 @@ def adam_compile(loss_scale=1.0): | |||
| use_nesterov=False, weight_decay=0.0, loss_scale=loss_scale) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) | |||
| train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, | |||
| scale_sense=Tensor(np.full((1), 1.0), dtype=mstype.float32)) | |||
| train_network.set_train() | |||
| output = train_network(inputs, label, scaling_sens) | |||
| output = train_network(inputs, label) | |||
| print("the result is ", output) | |||
| def test_adam_compile(): | |||
| adam_compile() | |||
| def test_adam_loss_scale_compile(): | |||
| """ test setting loss_scale to 1e-40 """ | |||
| adam_compile(loss_scale=1e-40) | |||