|
|
|
@@ -3600,6 +3600,88 @@ class SparseApplyAdagrad(PrimitiveWithInfer): |
|
|
|
return var_type, accum_type |
|
|
|
|
|
|
|
|
|
|
|
class SparseApplyAdagradV2(PrimitiveWithInfer): |
|
|
|
r""" |
|
|
|
Update relevant entries according to the adagrad scheme. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
accum += grad * grad |
|
|
|
.. math:: |
|
|
|
var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon} |
|
|
|
|
|
|
|
Args: |
|
|
|
lr (float): Learning rate. |
|
|
|
epsilon (float): A small value added for numerical stability. |
|
|
|
use_locking (bool): If `True`, updating of the `var` and `accum` tensors will be protected by a lock. Default: False.
|
|
|
update_slots (bool): If `True`, `accum` will be updated as well as `var`; otherwise only `var` is updated. Default: True.
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **var** (Parameter) - Variable to be updated. The type must be float32. |
|
|
|
- **accum** (Parameter) - Accumulation to be updated. The shape must be the same as `var`'s shape,
|
|
|
the type must be float32. |
|
|
|
- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape except the first dimension,
|
|
|
the type must be float32. |
|
|
|
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. |
|
|
|
The shape of `indices` must be the same as `grad` in the first dimension; the type must be int32.
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tuple of 2 Tensors, the updated parameters.
|
|
|
|
|
|
|
- **var** (Tensor) - The same shape and data type as `var`. |
|
|
|
- **accum** (Tensor) - The same shape and data type as `accum`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> import numpy as np |
|
|
|
>>> import mindspore.nn as nn |
|
|
|
>>> from mindspore import Tensor, Parameter |
|
|
|
>>> from mindspore.ops import operations as P |
|
|
|
>>> import mindspore.common.dtype as mstype |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
>>> def __init__(self): |
|
|
|
>>> super(Net, self).__init__() |
|
|
|
>>> self.sparse_apply_adagrad_v2 = P.SparseApplyAdagradV2(lr=1e-8, epsilon=1e-6) |
|
|
|
>>> self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var") |
|
|
|
>>> self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum") |
|
|
|
>>> |
|
|
|
>>> def construct(self, grad, indices): |
|
|
|
>>> out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices) |
|
|
|
>>> return out |
|
|
|
>>> net = Net() |
|
|
|
>>> grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32)) |
|
|
|
>>> indices = Tensor([0, 1, 2], mstype.int32) |
|
|
|
>>> result = net(grad, indices) |
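>>> # A minimal NumPy sketch of the row-wise update semantics, for illustration only
>>> # (not the operator's implementation). `sparse_adagrad_v2_ref` is a hypothetical
>>> # helper: only the rows of `var` and `accum` selected by `indices` are modified,
>>> # following the two formulas above.
>>> def sparse_adagrad_v2_ref(var, accum, grad, indices, lr, epsilon, update_slots=True):
>>>     for i, idx in enumerate(indices):
>>>         if update_slots:
>>>             accum[idx] += grad[i] * grad[i]
>>>         var[idx] -= lr * grad[i] / (np.sqrt(accum[idx]) + epsilon)
>>>     return var, accum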
|
|
|
""" |
|
|
|
|
|
|
|
__mindspore_signature__ = ( |
|
|
|
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) |
|
|
|
) |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, lr, epsilon, use_locking=False, update_slots=True): |
|
|
|
self.lr = validator.check_value_type("lr", lr, [float], self.name) |
|
|
|
self.epsilon = validator.check_value_type("epsilon", epsilon, [float], self.name) |
|
|
|
self.use_locking = validator.check_value_type("update_slots", update_slots, [bool], self.name) |
|
|
|
self.update_slots = validator.check_value_type("use_locking", use_locking, [bool], self.name) |
|
|
|
|
|
|
|
def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape): |
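# Shape contract: `var` and `accum` must have identical shapes; `grad` must match `var`
# in every dimension except the first; `indices` must be a 1-D tensor whose length
# equals the first dimension of `grad`.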
|
|
|
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) |
|
|
|
validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) |
|
|
|
if len(var_shape) > 1: |
|
|
|
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) |
|
|
|
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) |
|
|
|
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) |
|
|
|
return var_shape, accum_shape |
|
|
|
|
|
|
|
def infer_dtype(self, var_type, accum_type, grad_type, indices_type): |
|
|
|
args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} |
|
|
|
validator.check_tensor_type_same(args, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) |
|
|
|
return var_type, accum_type |
|
|
|
|
|
|
|
|
|
|
|
class ApplyProximalAdagrad(PrimitiveWithInfer): |
|
|
|
r""" |
|
|
|
Update relevant entries according to the proximal adagrad algorithm. |
|
|
|
@@ -3664,7 +3746,8 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, use_locking=False): |
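# Register the primitive's input/output names; both the updated `var` and `accum` are returned.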
|
|
|
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'],



outputs=['var', 'accum'])
|
|
|
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) |
|
|
|
|
|
|
|
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape): |
|
|
|
@@ -4371,6 +4454,98 @@ class SparseApplyFtrl(PrimitiveWithInfer): |
|
|
|
return var_dtype, accum_dtype, linear_dtype |
|
|
|
|
|
|
|
|
|
|
|
class SparseApplyFtrlV2(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
Update relevant entries according to the FTRL-proximal scheme. |
|
|
|
|
|
|
|
Args: |
|
|
|
lr (float): The learning rate value, must be positive. |
|
|
|
l1 (float): L1 regularization strength, must be greater than or equal to zero.
|
|
|
l2 (float): L2 regularization strength, must be greater than or equal to zero.
|
|
|
l2_shrinkage (float): L2 shrinkage regularization. |
|
|
|
lr_power (float): Learning rate power, which controls how the learning rate decreases during training;



must be less than or equal to zero. A fixed learning rate is used if `lr_power` is zero.
|
|
|
use_locking (bool): If `True`, updating of the `var` and `accum` tensors will be protected by a lock. Default: False.
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **var** (Parameter) - The variable to be updated. The data type must be float32. |
|
|
|
- **accum** (Parameter) - The accum to be updated, must be the same type and shape as `var`.
|
|
|
- **linear** (Parameter) - The linear to be updated, must be the same type and shape as `var`.
|
|
|
- **grad** (Tensor) - A tensor of the same type as `var`, containing the gradient.
|
|
|
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. |
|
|
|
The shape of `indices` must be the same as `grad` in the first dimension. The type must be int32.
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tuple of 3 Tensors, the updated parameters.
|
|
|
|
|
|
|
- **var** (Tensor) - The same shape and data type as `var`.



- **accum** (Tensor) - The same shape and data type as `accum`.



- **linear** (Tensor) - The same shape and data type as `linear`.
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> import mindspore |
|
|
|
>>> import mindspore.nn as nn |
|
|
|
>>> import numpy as np |
|
|
|
>>> from mindspore import Parameter |
|
|
|
>>> from mindspore import Tensor |
|
|
|
>>> from mindspore.ops import operations as P |
|
|
|
>>> class SparseApplyFtrlV2Net(nn.Cell): |
|
|
|
>>> def __init__(self): |
|
|
|
>>> super(SparseApplyFtrlV2Net, self).__init__() |
|
|
|
>>> self.sparse_apply_ftrl_v2 = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0, |
|
|
|
>>>                                                        l2_shrinkage=0.0, lr_power=-0.5)
|
|
|
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") |
|
|
|
>>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") |
|
|
|
>>> self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear") |
|
|
|
>>> |
|
|
|
>>> def construct(self, grad, indices): |
|
|
|
>>> out = self.sparse_apply_ftrl_v2(self.var, self.accum, self.linear, grad, indices) |
|
|
|
>>> return out |
|
|
|
>>> |
|
|
|
>>> net = SparseApplyFtrlV2Net() |
|
|
|
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32)) |
|
|
|
>>> indices = Tensor(np.ones([3]), mindspore.int32) |
|
|
|
>>> output = net(grad, indices) |
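>>> # Construction-time constraints, summarized from Args above: `lr` must be positive,
>>> # `l1` and `l2` non-negative, and `lr_power` non-positive. For example, passing
>>> # lr_power=0.0 requests a fixed learning rate (the variable name below is illustrative):
>>> fixed_lr_op = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0, l2_shrinkage=0.0, lr_power=0.0)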
|
|
|
""" |
|
|
|
|
|
|
|
__mindspore_signature__ = ( |
|
|
|
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) |
|
|
|
) |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, lr, l1, l2, l2_shrinkage, lr_power, use_locking=False): |
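# Validate attribute types and ranges: lr > 0, l1 >= 0, l2 >= 0, lr_power <= 0;
# l2_shrinkage must be a float and use_locking a bool.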
|
|
|
validator.check_value_type("lr", lr, [float], self.name) |
|
|
|
validator.check_value_type("l1", l1, [float], self.name) |
|
|
|
validator.check_value_type("l2", l2, [float], self.name) |
|
|
|
validator.check_value_type("lr_power", lr_power, [float], self.name) |
|
|
|
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) |
|
|
|
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) |
|
|
|
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) |
|
|
|
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) |
|
|
|
self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name) |
|
|
|
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) |
|
|
|
|
|
|
|
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape): |
|
|
|
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) |
|
|
|
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) |
|
|
|
if len(var_shape) > 1: |
|
|
|
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) |
|
|
|
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) |
|
|
|
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) |
|
|
|
return var_shape, accum_shape, linear_shape |
|
|
|
|
|
|
|
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): |
|
|
|
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, |
|
|
|
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype} |
|
|
|
validator.check_tensor_type_same(args, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) |
|
|
|
return var_dtype, accum_dtype, linear_dtype |
|
|
|
|
|
|
|
|
|
|
|
class ConfusionMulGrad(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
`output0` is the result of multiplying `input0` and `input1` element-wise.
|
|
|
|