@@ -27,6 +27,7 @@ from .lars import LARS
 from .ftrl import FTRL
 from .rmsprop import RMSProp
 from .proximal_ada_grad import ProximalAdagrad
+from .lazyadam import LazyAdam
 
-__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay',
+__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam',
            'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad']
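
With the export added above, `LazyAdam` becomes importable alongside the existing optimizers. A minimal usage sketch, assuming a build that includes this change (`nn.Dense` is used only as a stand-in for any network with trainable parameters):

import mindspore.nn as nn
from mindspore.nn.optim import LazyAdam

# Any Cell with trainable parameters will do; nn.Dense is only a placeholder here.
net = nn.Dense(4, 2)
optim = LazyAdam(net.trainable_params(), learning_rate=1e-3)
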
@@ -101,10 +101,21 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
 
-@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
-                   "Tensor")
-def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1,
-                             moment2):
+@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
+                   "Tensor", "Tensor", "Tensor")
+def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
+                         moment1, moment2):
+    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
+    success = True
+    success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
+                                           eps, gradient[1], gradient[0]))
+    return success
+
+
+@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
+                   "Tensor", "Tensor", "Tensor")
+def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
+                             moment1, moment2):
     """Apply adam optimizer to the weight parameter using Tensor."""
     success = True
     success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                     eps, gradient))
     return success
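
The two `adam_opt.register` overloads above act as an overload set: the branch is selected by the argument types, so a sparse gradient arriving as a "Tuple" (indices in `gradient[0]`, values in `gradient[1]`, as the sparse branch unpacks them) is routed to the sparse kernel, while a plain "Tensor" gradient keeps using the dense P.Adam kernel. As a rough analogy only, not MindSpore's actual dispatch mechanism and with hypothetical callables, the selection behaves like this plain-Python sketch:

# Analogy only: type-driven dispatch in the spirit of C.MultitypeFuncGraph.
# `dense_kernel` and `sparse_kernel` are hypothetical stand-ins for P.Adam / P.SparseApplyAdam.
def adam_dispatch(dense_kernel, sparse_kernel, gradient, params, moment1, moment2):
    if isinstance(gradient, tuple):                 # sparse gradient: (indices, values)
        indices, values = gradient
        return sparse_kernel(params, moment1, moment2, values, indices)
    return dense_kernel(params, moment1, moment2, gradient)   # dense gradient: one tensor
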
@@ -144,6 +155,10 @@ class Adam(Optimizer):
         To improve parameter groups performance, the customized order of parameters can be supported.
 
+        The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network and the
+        `sparse_grad` of `Parameter` is set to True. The sparse feature is under continuous development. The
+        sparse behavior is currently performed on the CPU; weight decay and loss scale are not supported.
+
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
             the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
@@ -231,12 +246,9 @@ class Adam(Optimizer):
         self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
 
         self.hyper_map = C.HyperMap()
+        self.map_ = C.Map()
         self.opt = P.Adam(use_locking, use_nesterov)
-        self.pow = P.Pow()
-        self.sqrt = P.Sqrt()
-        self.one = Tensor(np.array([1.0]).astype(np.float32))
-        self.realdiv = P.RealDiv()
+        self.sparse_opt = P.SparseApplyAdam()
 
     def construct(self, gradients):
         params = self.parameters
@@ -251,13 +263,13 @@
         beta2_power = self.beta2_power * self.beta2
         self.beta2_power = beta2_power
         if self.is_group_lr:
-            success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
-                                               self.beta2, self.eps),
-                                     lr, gradients, params, moment1, moment2)
+            success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
+                                          self.beta1, self.beta2, self.eps),
+                                lr, gradients, params, moment1, moment2)
         else:
-            success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
-                                               self.beta2, self.eps, lr),
-                                     gradients, params, moment1, moment2)
+            success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
+                                          self.beta1, self.beta2, self.eps, lr),
+                                gradients, params, moment1, moment2)
         return success
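
To make the role of `beta1_power` and `beta2_power` concrete: `construct` multiplies them by `beta1` and `beta2` on every call, so after step t they hold beta1**t and beta2**t, which the kernel uses for bias correction. A minimal NumPy sketch of one dense Adam step under that convention (an illustration of the standard update rule, not the P.Adam kernel itself):

import numpy as np

def adam_step(w, m, v, g, lr, beta1, beta2, eps, beta1_power, beta2_power):
    """One dense Adam update; beta*_power hold beta1**t and beta2**t for the current step t."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    scaled_lr = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)   # bias-corrected step size
    w = w - scaled_lr * m / (np.sqrt(v) + eps)
    return w, m, v

# Two consecutive steps, accumulating the beta powers the same way construct() does.
w, m, v = np.ones(3), np.zeros(3), np.zeros(3)
beta1_power, beta2_power = 1.0, 1.0
for g in (np.array([0.1, -0.2, 0.3]), np.array([0.05, 0.0, -0.1])):
    beta1_power *= 0.9
    beta2_power *= 0.999
    w, m, v = adam_step(w, m, v, g, lr=0.01, beta1=0.9, beta2=0.999,
                        eps=1e-8, beta1_power=beta1_power, beta2_power=beta2_power)
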
@@ -23,8 +23,18 @@ from .optimizer import Optimizer, apply_decay, grad_scale
 ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
 
-@ftrl_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
+@ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
+                   "Tensor")
+def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
+    """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
+    success = True
+    success = F.depend(success, sparse_opt(weight, moment, linear, gradient[1], gradient[0]))
+    return success
+
+
+@ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
+                   "Tensor")
+def _tensor_run_opt(opt, sparse_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
     """Apply ftrl optimizer to the weight parameter."""
     success = True
     success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
     return success
@@ -67,6 +77,11 @@ class FTRL(Optimizer):
     <https://arxiv.org/abs/1002.4908>`_. Refer to paper `Ad Click Prediction: a View from the Trenches
     <https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.
 
+    Note:
+        The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network and the
+        `sparse_grad` of `Parameter` is set to True. The sparse feature is under continuous development. The
+        sparse behavior is currently performed on the CPU; weight decay and loss scale are not supported.
+
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
             should be Parameter.
@@ -109,8 +124,9 @@ class FTRL(Optimizer):
         self.weight_decay = weight_decay
         self.decay_tf = tuple((lambda: True)() for x in self.parameters)
         self.hyper_map = C.HyperMap()
+        self.map_ = C.Map()
         self.opt = P.ApplyFtrl(use_locking=use_locking)
-        self.one = Tensor(1, mstype.int32)
+        self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
 
     def construct(self, grads):
         params = self.parameters
@@ -121,6 +137,6 @@ class FTRL(Optimizer):
         if self.reciprocal_scale != 1.0:
             grads = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), grads)
         lr = self.learning_rate
-        success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power),
-                                 linear, grads, params, moments)
+        success = self.map_(F.partial(ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
+                            linear, grads, params, moments)
         return success
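
For reference, the dense branch hands `linear` (the z accumulator) and `moment` (the squared-gradient accumulator) to P.ApplyFtrl together with the gradient and the l1/l2/lr_power hyperparameters. A minimal NumPy sketch of the usual FTRL-Proximal step those quantities take part in, per the referenced papers (an illustration only, not the operator's kernel; `lr_power` is typically -0.5):

import numpy as np

def ftrl_step(w, moment, linear, g, lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5):
    """One FTRL-Proximal update in its common formulation (illustrative sketch)."""
    moment_new = moment + g * g
    sigma = (moment_new ** -lr_power - moment ** -lr_power) / lr
    linear = linear + g - sigma * w
    quadratic = moment_new ** -lr_power / lr + 2.0 * l2
    # Weights whose accumulated |linear| stays below l1 are clipped to exactly zero.
    w = np.where(np.abs(linear) > l1,
                 (np.sign(linear) * l1 - linear) / quadratic,
                 0.0)
    return w, moment_new, linear

# A small positive initial accumulator avoids a zero denominator in the first step.
w, moment, linear = np.zeros(4), np.full(4, 0.1), np.zeros(4)
w, moment, linear = ftrl_step(w, moment, linear, g=np.array([0.2, -0.1, 0.0, 0.3]))
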
@@ -0,0 +1,202 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""lazy adam"""
+from mindspore.common import dtype as mstype
+from mindspore.common.initializer import initializer
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
+from .optimizer import Optimizer
+
+lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
+
+
+@lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
+                        "Tensor", "Tensor", "Tensor")
+def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
+                         moment1, moment2):
+    """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
+    success = True
+    success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
+                                           eps, gradient[1], gradient[0]))
+    return success
+
+
+@lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
+                        "Tensor", "Tensor", "Tensor")
+def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
+                             moment1, moment2):
+    """Apply adam optimizer to the weight parameter using Tensor."""
+    success = True
+    success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
+                                    eps, gradient))
+    return success
+
+
+def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
+    """Check the type of inputs."""
+    validator.check_value_type("beta1", beta1, [float], prim_name)
+    validator.check_value_type("beta2", beta2, [float], prim_name)
+    validator.check_value_type("eps", eps, [float], prim_name)
+    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
+    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
+    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
+    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
+    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
+
+
+class LazyAdam(Optimizer):
+    r"""
+    Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.
+
+    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
+
+    The updating formulas are as follows,
+
+    .. math::
+        \begin{array}{ll} \\
+            m = \beta_1 * m + (1 - \beta_1) * g \\
+            v = \beta_2 * v + (1 - \beta_2) * g * g \\
+            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
+            w = w - l * \frac{m}{\sqrt{v} + \epsilon}
+        \end{array}
+
+    :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
+    :math:`g` represents `gradients`, :math:`l` represents the scaling factor `lr`, :math:`\beta_1, \beta_2` represent
+    `beta1` and `beta2`, :math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t` represent
+    `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
+    :math:`\epsilon` represents `eps`.
+
+    Note:
+        The LazyAdam optimizer supports separating parameter groups. Different parameter groups can set different
+        `learning_rate` and `weight_decay`.
+
+        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
+        value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
+        applied on the parameters if `weight_decay` > 0 and 'beta' and 'gamma' are not in the name of parameters.
+
+        The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network and the
+        `sparse_grad` of `Parameter` is set to True. Note that the sparse behavior is not equivalent to the original
+        Adam algorithm, as only the parameters at the current indices are updated. The sparse feature is under
+        continuous development. The sparse behavior is currently performed on the CPU; weight decay and loss scale
+        are not supported.
+
+    Args:
+        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
+            the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
+            "lr" and "weight_decay" are the keys that can be parsed.
+
+            - params: Required. The value should be a list of `Parameter`.
+
+            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
+              If not, the `learning_rate` in the API will be used.
+
+            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
+              will be used. If not, the `weight_decay` in the API will be used.
+
+        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is an
+            Iterable or a Tensor with one dimension, a dynamic learning rate is used and the i-th step takes the
+            i-th value as the learning rate. When the learning_rate is a float or a zero-dimensional Tensor, a fixed
+            learning rate is used. Other cases are not supported. Default: 1e-3.
+        beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
+            Default: 0.9.
+        beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
+            Default: 0.999.
+        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+            Default: 1e-8.
+        use_locking (bool): Whether to enable a lock to protect updating variable tensors.
+            If True, updating of the var, m, and v tensors will be protected by a lock.
+            If False, the result is unpredictable. Default: False.
+        use_nesterov (bool): Whether to use the Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
+            If True, updates the gradients using NAG.
+            If False, updates the gradients without using NAG. Default: False.
+        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
+        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1.
+            Default: 1.0.
+
+    Inputs:
+        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
+
+    Outputs:
+        Tensor[bool], the value is True.
+
+    Examples:
+        >>> net = Net()
+        >>> #1) All parameters use the same learning rate and weight decay
+        >>> optim = nn.LazyAdam(params=net.trainable_params())
+        >>>
+        >>> #2) Use parameter groups and set different values
+        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
+        >>>                 {'params': no_conv_params}]
+        >>> optim = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
+        >>> # the parameters in conv_params will use a learning rate of 0.01 and a weight decay of 0.01
+        >>> # the parameters in no_conv_params do not set a learning rate or weight decay, so they will use the
+        >>> # learning rate of 0.1 and the weight decay of 0.0 from the API.
+        >>>
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
+    """
+
+    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
+                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
+        super(LazyAdam, self).__init__(learning_rate, params, weight_decay, loss_scale)
+        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
+        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
+        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
+        validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
+        validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+
+        self.beta1 = Tensor(beta1, mstype.float32)
+        self.beta2 = Tensor(beta2, mstype.float32)
+        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
+        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
+        self.eps = eps
+        self.use_nesterov = use_nesterov
+        self.use_locking = use_locking
+
+        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
+        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
+
+        self.hyper_map = C.HyperMap()
+        self.map_ = C.Map()
+        self.opt = P.Adam(use_locking, use_nesterov)
+        self.sparse_opt = P.SparseApplyLazyAdam(use_locking, use_nesterov)
+
+    def construct(self, gradients):
+        gradients = self.decay_weight(gradients)
+        gradients = self.scale_grad(gradients)
+        lr = self.get_lr()
+
+        self.beta1_power = self.beta1_power * self.beta1
+        self.beta2_power = self.beta2_power * self.beta2
+
+        if self.is_group_lr:
+            success = self.map_(F.partial(lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
+                                          self.beta2_power, self.beta1, self.beta2, self.eps),
+                                lr, gradients, self.parameters, self.moment1, self.moment2)
+        else:
+            success = self.map_(F.partial(lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
+                                          self.beta2_power, self.beta1, self.beta2, self.eps, lr),
+                                gradients, self.parameters, self.moment1, self.moment2)
+        return success
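
The note above points out that the sparse behavior is not equivalent to the original Adam algorithm: a dense step decays `m` and `v` for every row, while the lazy variant only touches the rows addressed by the current indices. A small NumPy sketch of that difference, written against the (indices, values) gradient layout consumed by `_run_opt_with_sparse` (illustrative only, not the P.SparseApplyLazyAdam kernel):

import numpy as np

def lazy_adam_sparse_step(w, m, v, indices, values, lr, beta1, beta2, eps,
                          beta1_power, beta2_power):
    """Update only the rows listed in `indices`; every other row keeps stale m, v and w."""
    scaled_lr = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)
    for i, g in zip(indices, values):
        m[i] = beta1 * m[i] + (1 - beta1) * g
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        w[i] = w[i] - scaled_lr * m[i] / (np.sqrt(v[i]) + eps)
    return w, m, v

# Rows 0 and 1 were gathered this step, so row 2 of the [3, 1, 2] weight is left untouched.
w = np.ones((3, 1, 2))
m = np.zeros_like(w)
v = np.zeros_like(w)
indices = np.array([0, 1])
values = 0.1 * np.ones((2, 1, 2))                  # gradient values for the gathered rows
w, m, v = lazy_adam_sparse_step(w, m, v, indices, values, lr=0.1, beta1=0.9,
                                beta2=0.999, eps=1e-8, beta1_power=0.9, beta2_power=0.999)
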
@@ -21,7 +21,7 @@ from mindspore import Tensor, Parameter
 import mindspore.common.dtype as mstype
 from mindspore.common.api import _executor
 from mindspore.nn import TrainOneStepCell, WithLossCell
-from mindspore.nn.optim import AdamWeightDecay, AdamWeightDecayDynamicLR
+from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
 from mindspore.ops import operations as P
@@ -50,6 +50,19 @@ class NetWithoutWeight(nn.Cell):
         return x
 
 
+class NetWithSparseGatherV2(nn.Cell):
+    """ NetWithSparseGatherV2 definition """
+    def __init__(self):
+        super(NetWithSparseGatherV2, self).__init__()
+        self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", sparse_grad=True)
+        self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2")
+        self.axis = 0
+        self.gather = P.SparseGatherV2()
+
+    def construct(self, indices, label):
+        return self.gather(self.weight1, indices, self.axis) + self.weight2
+
+
 def test_adamwithoutparam():
     net = NetWithoutWeight()
     net.set_train()
@@ -72,6 +85,33 @@ def test_adamw_compile():
     _executor.compile(train_network, inputs, label)
 
 
+def test_adam_compile():
+    """ test adam compile """
+    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
+    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Adam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_sparse_adam_compile():
+    """ test sparse adam compile """
+    indices = Tensor(np.array([0, 1]).astype(np.int32))
+    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
+    net = NetWithSparseGatherV2()
+    net.set_train()
+    optimizer = Adam(net.trainable_params(), learning_rate=0.1)
+    train_network = TrainOneStepCell(net, optimizer)
+    _executor.compile(train_network, indices, label)
+
+
 def test_AdamWeightDecay_beta1():
     net = Net()
     print("**********", net.get_parameters())
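
In these tests `weight1` is declared with `sparse_grad=True` and is only read through P.SparseGatherV2, so its gradient reaches the optimizer as an (indices, values) pair rather than a dense [3, 1, 2] tensor. A hedged NumPy sketch of what that pair looks like for the `indices = [0, 1]` lookup used here (the layout is inferred from how `gradient[0]` and `gradient[1]` are consumed in the optimizer hunks above):

import numpy as np

# Forward: rows 0 and 1 of the [3, 1, 2] weight are gathered.
indices = np.array([0, 1], dtype=np.int32)

# Backward: only the gathered rows receive gradient, so it can be carried
# as (indices, values) instead of a mostly-zero [3, 1, 2] tensor.
values = np.ones((2, 1, 2), dtype=np.float32)      # d(loss)/d(gathered rows)
sparse_grad = (indices, values)

# Equivalent dense gradient, for comparison: row 2 stays all zeros.
dense_grad = np.zeros((3, 1, 2), dtype=np.float32)
dense_grad[indices] = values
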
@@ -37,6 +37,19 @@ class Net(nn.Cell):
         return x
 
 
+class NetWithSparseGatherV2(nn.Cell):
+    """ NetWithSparseGatherV2 definition """
+    def __init__(self):
+        super(NetWithSparseGatherV2, self).__init__()
+        self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", sparse_grad=True)
+        self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2")
+        self.axis = 0
+        self.gather = P.SparseGatherV2()
+
+    def construct(self, indices, label):
+        return self.gather(self.weight1, indices, self.axis) + self.weight2
+
+
 def test_ftrl():
     """ test_ftrl """
     inputs = Tensor(np.ones([1, 64]).astype(np.float32))
@@ -48,3 +61,15 @@ def test_ftrl():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
+
+
+def test_sparse_ftrl_compile():
+    """ test sparse ftrl compile """
+    indices = Tensor(np.array([0, 1]).astype(np.int32))
+    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
+    net = NetWithSparseGatherV2()
+    net.set_train()
+    optimizer = FTRL(net.trainable_params())
+    train_network = TrainOneStepCell(net, optimizer)
+    _executor.compile(train_network, indices, label)
@@ -0,0 +1,88 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test lazy adam """
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor, Parameter
+from mindspore.common.api import _executor
+from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.nn.optim import LazyAdam
+from mindspore.ops import operations as P
+
+
+class Net(nn.Cell):
+    """ Net definition """
+    def __init__(self):
+        super(Net, self).__init__()
+        self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
+        self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
+        self.matmul = P.MatMul()
+        self.biasAdd = P.BiasAdd()
+
+    def construct(self, x):
+        x = self.biasAdd(self.matmul(x, self.weight), self.bias)
+        return x
+
+
+class NetWithSparseGatherV2(nn.Cell):
+    """ NetWithSparseGatherV2 definition """
+    def __init__(self):
+        super(NetWithSparseGatherV2, self).__init__()
+        self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", sparse_grad=True)
+        self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2")
+        self.axis = 0
+        self.gather = P.SparseGatherV2()
+
+    def construct(self, indices, label):
+        return self.gather(self.weight1, indices, self.axis) + self.weight2
+
+
+def test_lazy_adam_compile():
+    """ test lazy adam compile """
+    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
+    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_sparse_lazy_adam_compile():
+    """ test sparse lazy adam compile """
+    indices = Tensor(np.array([0, 1]).astype(np.int32))
+    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
+    net = NetWithSparseGatherV2()
+    net.set_train()
+    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1)
+    train_network = TrainOneStepCell(net, optimizer)
+    _executor.compile(train_network, indices, label)
+
+
+def test_lazy_adam_error():
+    net = Net()
+    with pytest.raises(ValueError):
+        LazyAdam(net.get_parameters(), learning_rate=-0.1)
+
+    with pytest.raises(TypeError):
+        LazyAdam(net.get_parameters(), learning_rate=0.1, beta1=2)