|
|
|
@@ -2141,6 +2141,80 @@ class SparseApplyAdagrad(PrimitiveWithInfer): |
|
|
|
return var_type |
|
|
|
|
|
|
|
|
|
|
|
class SparseApplyFtrlD(PrimitiveWithInfer): |
|
|
|
r""" |
|
|
|
Updates relevant entries of `var`, `accum` and `linear` according to the FTRL-proximal optimization algorithm.
|
|
|
|
|
|
|
.. math::
    \text{accum}_{new} = \text{accum} + \text{grad} * \text{grad}

.. math::
    \text{linear} += \text{grad} - \frac{\text{accum}_{new}^{-\text{lr\_power}} -
    \text{accum}^{-\text{lr\_power}}}{\text{lr}} * \text{var}

.. math::
    \text{quadratic} = \frac{1.0}{\text{accum}_{new}^{\text{lr\_power}} * \text{lr}} + 2 * \text{l2}

.. math::
    \text{var} = \frac{\text{sign}(\text{linear}) * \text{l1} - \text{linear}}{\text{quadratic}}
    \text{ if } \vert \text{linear} \vert > \text{l1} \text{ else } 0.0
|
|
|
|
|
|
|
Args: |
|
|
|
lr (float): Learning rate. |
|
|
|
l1 (float): L1 regularization strength, must be greater than or equal to zero.
|
|
|
l2 (float): L2 regularization strength, must be greater than or equal to zero.
|
|
|
lr_power (float): Learning rate power, controls how the learning rate decreases during training; must be less than or equal to zero.
|
|
|
use_locking (bool): If True, updates of the `var`, `accum` and `linear` tensors are protected by a lock. Default: False.
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **var** (Tensor) - Variable to be updated. The type must be float32.
|
|
|
- **accum** (Tensor) - Accumulator to be updated. The shape must be the same as `var`'s shape,
the type must be float32.
|
|
|
- **linear** (Tensor) - The linear coefficient to be updated. The shape must be the same as `var`'s shape,
the type must be float32.
|
|
|
- **grad** (Tensor) - Gradient. Except for the first dimension, the shape must be
the same as `var`'s shape, the type must be float32.
|
|
|
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`,
the shape of `indices` must match `grad`'s first dimension, and the type must be int32.
|
|
|
|
|
|
|
Outputs:

Tensor, has the same shape and type as `var`.
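
Examples:
    A minimal usage sketch; the tensor values, the (3, 3) shape and the
    operator alias `P` are illustrative assumptions, not a verified test case.

    >>> import numpy as np
    >>> from mindspore import Tensor
    >>> from mindspore.ops import operations as P
    >>> var = Tensor(np.random.rand(3, 3).astype(np.float32))
    >>> accum = Tensor(np.random.rand(3, 3).astype(np.float32))
    >>> linear = Tensor(np.random.rand(3, 3).astype(np.float32))
    >>> grad = Tensor(np.random.rand(2, 3).astype(np.float32))
    >>> indices = Tensor(np.array([0, 2]).astype(np.int32))
    >>> sparse_apply_ftrl = P.SparseApplyFtrlD(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
    >>> output = sparse_apply_ftrl(var, accum, linear, grad, indices)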
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, lr, l1, l2, lr_power, use_locking=False): |
|
|
|
"""init SparseApplyFtrlD""" |
|
|
|
self.lr = validator.check_type("lr", lr, [float]) |
|
|
|
self.l1 = validator.check_type("l1", l1, [float]) |
|
|
|
self.l2 = validator.check_type("l2", l2, [float]) |
|
|
|
self.lr_power = validator.check_type("lr_power", lr_power, [float]) |
|
|
|
self.use_locking = validator.check_type("use_locking", use_locking, [bool])
|
|
|
|
|
|
|
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape): |
|
|
|
validator.check_param_equal('var shape', var_shape, 'accum shape', accum_shape) |
|
|
|
validator.check_param_equal('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape)) |
|
|
|
validator.check_param_equal('len of var shape', len(var_shape), 'len of linear shape', len(linear_shape)) |
|
|
|
if len(var_shape) > 1: |
|
|
|
validator.check_param_equal('var_shape', var_shape[1:], 'grad_shape', grad_shape[1:]) |
|
|
|
validator.check_param_equal('var_shape', var_shape[1:], 'linear_shape', linear_shape[1:]) |
|
|
|
validator.check_integer("len of indices shape", len(indices_shape), 1, Rel.EQ) |
|
|
|
validator.check('the first dimension of grad', grad_shape[0], |
|
|
|
'the shape of indices', indices_shape[0], Rel.EQ) |
|
|
|
|
|
|
|
return var_shape |
|
|
|
|
|
|
|
def infer_dtype(self, var_type, accum_type, linear_type, grad_type, indices_type): |
|
|
|
validator.check_subclass("var_type", var_type, mstype.tensor) |
|
|
|
validator.check_subclass("accum_type", accum_type, mstype.tensor) |
|
|
|
validator.check_subclass("linear_type", linear_type, mstype.tensor) |
|
|
|
validator.check_subclass("grad_type", grad_type, mstype.tensor) |
|
|
|
validator.check_subclass("indices_type", indices_type, mstype.tensor) |
|
|
|
validator.check_typename('indices_type', indices_type, [mstype.int32])
|
|
|
|
|
|
|
return var_type |
|
|
|
|
|
|
|
|
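# --- Illustrative reference (not part of the MindSpore API) -----------------
# A minimal pure-NumPy sketch of the sparse FTRL-proximal step documented in
# SparseApplyFtrlD above. The helper name `_sparse_ftrl_reference` and the
# sequential handling of duplicate indices are assumptions for illustration.
import numpy as np


def _sparse_ftrl_reference(var, accum, linear, grad, indices,
                           lr, l1, l2, lr_power):
    """Apply the FTRL-proximal update to the rows of `var` picked by `indices`."""
    for g_row, idx in zip(grad, indices):
        # accum_new = accum + grad * grad
        accum_new = accum[idx] + g_row * g_row
        # linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
        sigma = (accum_new ** -lr_power - accum[idx] ** -lr_power) / lr
        linear[idx] += g_row - sigma * var[idx]
        # quadratic = 1 / (accum_new^lr_power * lr) + 2 * l2
        quadratic = 1.0 / (accum_new ** lr_power * lr) + 2.0 * l2
        # Rows with |linear| <= l1 are shrunk exactly to zero, which is what
        # produces sparsity in FTRL-proximal.
        var[idx] = np.where(np.abs(linear[idx]) > l1,
                            (np.sign(linear[idx]) * l1 - linear[idx]) / quadratic,
                            0.0)
        accum[idx] = accum_new
    return var, accum, linear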
|
|
|
class LARSUpdate(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
Conducts the LARS (layer-wise adaptive rate scaling) update on the square sum of the gradient.
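
The scaling follows the standard LARS rule (a sketch of the usual formulation;
:math:`\theta` is the LARS trust coefficient and :math:`\delta` the damping
term, named here by assumption):

.. math::
    \lambda = \frac{\theta * \Vert \omega \Vert}{\Vert g \Vert + \delta * \Vert \omega \Vert}

.. math::
    g = \lambda * (g + \delta * \omega)

where :math:`\omega` is the weight and :math:`g` its gradient.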
|
|
|
@@ -2244,4 +2318,4 @@ class ApplyFtrl(PrimitiveWithInfer): |
|
|
|
validator.check_typename("l1", l1_type, [mstype.float16, mstype.float32])
|
|
|
validator.check_typename("l2", l2_type, [mstype.float16, mstype.float32])
|
|
|
validator.check_typename("lr_power", lr_power_type, [mstype.float16, mstype.float32])
|
|
|
return var_type |
|
|
|