@@ -27,25 +27,40 @@ from .optimizer import Optimizer
 
 _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
 
 
-@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor",
-                         "RowTensor", "Tensor", "Tensor", "Tensor")
-def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
-                         moment1, moment2):
+@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor",
+                         "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
+def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps,
+                         lr, gradient, params, moment1, moment2, ps_parameter):
     """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
     success = True
-    success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                           eps, gradient.values, gradient.indices))
+    indices = gradient.indices
+    values = gradient.values
+    if ps_parameter:
+        op_shape = P.Shape()
+        shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
+                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
+                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
+        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
+                                               eps, values, indices), shapes), params))
+    else:
+        success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
+                                               eps, values, indices))
     return success
 
 
-@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
-                         "Tensor", "Tensor", "Tensor")
-def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
-                             moment1, moment2):
-    """Apply adam optimizer to the weight parameter using Tensor."""
+@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor",
+                         "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
+def _run_opt_with_one_number(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps,
+                             lr, gradient, params, moment1, moment2, ps_parameter):
+    """Apply lazy adam optimizer to the weight parameter using Tensor."""
     success = True
-    success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                    eps, gradient))
+    if ps_parameter:
+        op_shape = P.Shape()
+        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
+                                              (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
+    else:
+        success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
+                                        eps, gradient))
     return success
 
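A note on the sparse branch reconstructed above: it is a "lazy" update, meaning only the rows named by gradient.indices have their moments and weights touched. Below is a minimal NumPy sketch of what P.FusedSparseLazyAdam computes per touched row. It is illustrative only and not part of the patch; the function name is made up here, unique indices are assumed, and locking and kernel-fusion details are omitted.

    import numpy as np

    def lazy_adam_sparse_row_update(param, m, v, indices, values,
                                    lr, beta1, beta2, eps, beta1_power, beta2_power):
        """Illustrative sketch: update only the rows that appear in `indices`."""
        # Bias-corrected step size, using beta1^t and beta2^t passed in as *_power.
        step = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
        for i, g in zip(indices, values):
            m[i] = beta1 * m[i] + (1.0 - beta1) * g          # first moment, this row only
            v[i] = beta2 * v[i] + (1.0 - beta2) * g * g      # second moment, this row only
            param[i] -= step * m[i] / (np.sqrt(v[i]) + eps)  # weight update, this row only
        return param, m, v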
@@ -173,7 +188,7 @@ class LazyAdam(Optimizer):
         self.beta2 = Tensor(beta2, mstype.float32)
         self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
         self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
-        self.eps = eps
+        self.eps = Tensor(eps, mstype.float32)
         self.use_nesterov = use_nesterov
         self.use_locking = use_locking
 
@@ -184,6 +199,10 @@ class LazyAdam(Optimizer):
         self.opt = P.Adam(use_locking, use_nesterov)
         self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
 
+        self._ps_pull = P.Pull()
+        self._ps_push = P.Push("Adam", [0, 1, 2])
+        self._ps_push.add_prim_attr("use_nesterov", use_nesterov)
+
     def construct(self, gradients):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
@@ -193,11 +212,11 @@ class LazyAdam(Optimizer):
         self.beta2_power = self.beta2_power * self.beta2
 
         if self.is_group_lr:
-            success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
-                                          self.beta2_power, self.beta1, self.beta2, self.eps),
-                                lr, gradients, self.parameters, self.moment1, self.moment2)
+            success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
+                                          self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
+                                lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
         else:
-            success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
-                                          self.beta2_power, self.beta1, self.beta2, self.eps, lr),
-                                gradients, self.parameters, self.moment1, self.moment2)
+            success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
+                                          self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr),
+                                gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
         return success
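For context, a minimal usage sketch of the optimizer this patch touches. It is not part of the patch; it assumes the usual MindSpore 1.x training API, and the network, learning rate, and eps below are placeholders chosen only to keep the example self-contained.

    import mindspore.nn as nn
    from mindspore import Model

    net = nn.Dense(20, 10)   # any Cell works; a single Dense layer keeps the sketch small
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = nn.LazyAdam(net.trainable_params(), learning_rate=1e-3, eps=1e-8)

    # The push/pull branch added by this patch is taken only for parameters that
    # have been flagged for the parameter server (PS mode also needs the PS
    # context and cluster environment set up); otherwise the fused local kernels run.
    model = Model(net, loss_fn=loss, optimizer=opt)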