diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index b3d4bc8a0e..f0688a9b47 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -164,7 +164,7 @@ class Adam(Optimizer):
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
 
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py
index a40d6737cb..b840b89241 100644
--- a/mindspore/nn/optim/ftrl.py
+++ b/mindspore/nn/optim/ftrl.py
@@ -73,7 +73,7 @@ class FTRL(Optimizer):
     Note:
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
 
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
@@ -124,7 +124,7 @@ class FTRL(Optimizer):
         linear = self.linear
         lr = self.learning_rate
         if self.weight_decay > 0.0:
-            grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
+            grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
 
         grads = self.scale_grad(grads)
         success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py
index 48d33bf798..7d53aad488 100644
--- a/mindspore/nn/optim/lazyadam.py
+++ b/mindspore/nn/optim/lazyadam.py
@@ -94,8 +94,7 @@ class LazyAdam(Optimizer):
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the
         original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
-        continuous development. The sparse behavior is currently performed on the CPU, weight decay is
-        not supported.
+        continuous development. The sparse behavior is currently performed on the CPU.
 
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index a811edcabc..5b13d7cfbd 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -195,12 +195,12 @@ class Optimizer(Cell):
         params = self.parameters
         if self.is_group:
             if self.exec_weight_decay:
-                gradients = self.hyper_map(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
-                                           params, gradients)
+                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
+                                      params, gradients)
         else:
             if self.weight_decay > 0:
-                gradients = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
-                                           params, gradients)
+                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
+                                      params, gradients)
 
         return gradients
 
@@ -479,10 +479,20 @@ class Optimizer(Cell):
 op_add = P.AddN()
+op_gather = P.GatherV2()
 
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
 
 
+@_apply_decay.register("Number", "Bool", "Tensor", "Tuple")
+def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
+    """Get grad with weight_decay."""
+    if if_apply:
+        weight = op_gather(weight, gradient[0], 0)
+        return gradient[0], op_add((weight * weight_decay, gradient[1])), gradient[2]
+    return gradient
+
+
 @_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py
index 380720404a..795cf8ab05 100644
--- a/mindspore/nn/optim/proximal_ada_grad.py
+++ b/mindspore/nn/optim/proximal_ada_grad.py
@@ -60,7 +60,7 @@ class ProximalAdagrad(Optimizer):
     Note:
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
        `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
 
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
diff --git a/tests/ut/python/nn/optim/test_adam.py b/tests/ut/python/nn/optim/test_adam.py
index 3fd18b9664..b435bf65b9 100644
--- a/tests/ut/python/nn/optim/test_adam.py
+++ b/tests/ut/python/nn/optim/test_adam.py
@@ -107,7 +107,7 @@ def test_sparse_adam_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0)
+    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
diff --git a/tests/ut/python/nn/optim/test_ftrl.py b/tests/ut/python/nn/optim/test_ftrl.py
index f0f094c177..de59dfdbad 100644
--- a/tests/ut/python/nn/optim/test_ftrl.py
+++ b/tests/ut/python/nn/optim/test_ftrl.py
@@ -71,6 +71,6 @@ def test_spares_ftrl_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = FTRL(net.trainable_params(), loss_scale=2.0)
+    optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
diff --git a/tests/ut/python/nn/optim/test_lazyadam.py b/tests/ut/python/nn/optim/test_lazyadam.py
index 713fffc50d..ce66b404e2 100644
--- a/tests/ut/python/nn/optim/test_lazyadam.py
+++ b/tests/ut/python/nn/optim/test_lazyadam.py
@@ -75,7 +75,7 @@ def test_spares_lazy_adam_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, loss_scale=2.0)
+    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
diff --git a/tests/ut/python/nn/optim/test_proximal_ada_grad.py b/tests/ut/python/nn/optim/test_proximal_ada_grad.py
index a43a4ad23d..c7e6d3f88a 100644
--- a/tests/ut/python/nn/optim/test_proximal_ada_grad.py
+++ b/tests/ut/python/nn/optim/test_proximal_ada_grad.py
@@ -57,7 +57,7 @@ def test_proximal_ada_grad():
     net = Net()
     net.set_train()
     loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = ProximalAdagrad(net.trainable_params())
+    optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
@@ -70,6 +70,6 @@ def test_spares_proximal_ada_grad_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = ProximalAdagrad(net.trainable_params(), loss_scale=2.0)
+    optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
diff --git a/tests/ut/python/nn/optim/test_rmsprop.py b/tests/ut/python/nn/optim/test_rmsprop.py
index 2e3fc90f5f..683220eefe 100644
--- a/tests/ut/python/nn/optim/test_rmsprop.py
+++ b/tests/ut/python/nn/optim/test_rmsprop.py
@@ -57,7 +57,7 @@ def test_rmsprop_compile():
 def test_rmsprop_e():
     net = Net()
     with pytest.raises(ValueError):
-        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1)
+        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1, weight_decay=0.9)
     with pytest.raises(TypeError):
-        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1)
+        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1, weight_decay=0.9)
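
Note on the new sparse branch of _apply_decay in optimizer.py: the sketch below is plain NumPy, not the MindSpore API. It assumes, for illustration only, that a sparse gradient is a (indices, values, dense_shape) tuple matching the "Tuple" registration above, and it uses weight[indices] and ordinary addition where the patch uses op_gather and op_add; the name apply_decay_sparse is hypothetical.

    import numpy as np

    def apply_decay_sparse(weight_decay, if_apply, weight, gradient):
        """NumPy sketch of _tensor_apply_decay_with_sparse.

        Only the weight rows addressed by the gradient's indices are decayed,
        which is why the patch gathers the weight before adding the decay term.
        """
        indices, values, dense_shape = gradient
        if if_apply:
            gathered = weight[indices]                    # stands in for op_gather(weight, indices, 0)
            values = values + weight_decay * gathered     # stands in for op_add((weight * weight_decay, values))
            return indices, values, dense_shape
        return gradient

    # Tiny check: a 4x2 weight and a sparse gradient touching rows 0 and 2.
    weight = np.arange(8, dtype=np.float32).reshape(4, 2)
    grad = (np.array([0, 2]), np.ones((2, 2), dtype=np.float32), (4, 2))
    _, new_vals, _ = apply_decay_sparse(0.9, True, weight, grad)
    print(new_vals)   # [[1.  1.9] [4.6 5.5]]: ones plus 0.9 * weight rows 0 and 2

This is the dense registration's behaviour (gradient + weight_decay * weight) restricted to the touched rows, which is why the test changes above can now pass weight_decay=0.9 together with NetWithSparseGatherV2.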