|
|
|
@@ -2478,6 +2478,27 @@ class LARSUpdate(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, representing the new gradient. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore import Tensor |
|
|
|
>>> from mindspore.ops import operations as P |
|
|
|
>>> from mindspore.ops import functional as F |
|
|
|
>>> import mindspore.nn as nn |
|
|
|
>>> import numpy as np |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
...     def __init__(self): |
|
|
|
...         super(Net, self).__init__() |
|
|
|
...         self.lars = P.LARSUpdate() |
|
|
|
...         self.reduce = P.ReduceSum() |
|
|
|
...     def construct(self, weight, gradient): |
|
|
|
...         w_square_sum = self.reduce(F.square(weight)) |
|
|
|
...         grad_square_sum = self.reduce(F.square(gradient)) |
|
|
|
...         grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0) |
|
|
|
...         return grad_t |
|
|
|
>>> weight = np.random.random(size=(2, 3)).astype(np.float32) |
|
|
|
>>> gradient = np.random.random(size=(2, 3)).astype(np.float32) |
|
|
|
>>> net = Net() |
|
|
|
>>> ms_output = net(Tensor(weight), Tensor(gradient)) |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
|