fix bug in sparse proximal ada grad

tags/v0.6.0-beta
wangnan39@huawei.com 5 years ago
parent commit 19762375a5
2 changed files with 13 additions and 1 deletion
  1. mindspore/nn/learning_rate_schedule.py  +12 -0
  2. mindspore/nn/optim/proximal_ada_grad.py  +1 -1

mindspore/nn/learning_rate_schedule.py (+12 -0)

@@ -24,10 +24,22 @@ from .._checkparam import Rel




 class LearningRateSchedule(Cell):
     """Basic class of learning rate schedule."""
     def __init__(self):
         super(LearningRateSchedule, self).__init__()

     def construct(self, global_step):
+        """
+        Defines the computation to get the current learning rate.
+
+        This method should be overridden by all subclasses.
+
+        Note:
+            The output should be a Tensor of scalar.
+
+        Inputs:
+            Tensor. The current step number.
+        """
         raise NotImplementedError
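
A usage note for the docstring added above (not part of this commit): the sketch below shows a minimal custom schedule written against the public mindspore.nn / mindspore.ops API. The class name WarmupLR and its hyper-parameters are illustrative only; the point is that construct receives the current global step as a Tensor and returns a scalar Tensor.

# Illustrative sketch, not MindSpore source: a linear warm-up schedule
# built on LearningRateSchedule.
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops


class WarmupLR(nn.LearningRateSchedule):
    """Scales the learning rate linearly from 0 to base_lr over warmup_steps."""

    def __init__(self, base_lr, warmup_steps):
        super(WarmupLR, self).__init__()
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.cast = ops.Cast()
        self.minimum = ops.Minimum()

    def construct(self, global_step):
        # global_step arrives as a Tensor; the result must be a scalar Tensor.
        step = self.cast(global_step, mindspore.float32)
        ratio = self.minimum(step / self.warmup_steps, 1.0)
        return self.base_lr * ratio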






mindspore/nn/optim/proximal_ada_grad.py (+1 -1)

@@ -24,7 +24,7 @@ _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
 @_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
                                  "Tensor")
-def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
+def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
     """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
     success = True
     success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
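
Why the reordering matters (context, not part of the commit): a function registered on a MultitypeFuncGraph is matched to the caller's arguments purely by position, so the sparse signature has to list l1, l2 and learning_rate in the same order the optimizer passes them; presumably the mismatched order is how the sparse path fed the wrong values into sparse_opt before this fix. Below is a hedged usage sketch, assuming the mindspore.nn API in which an optimizer's learning_rate may be a LearningRateSchedule instance such as the built-in ExponentialDecayLR. The sparse branch fixed above only runs for parameters whose gradient is an IndexedSlices (e.g. from a sparse embedding lookup); the dense layer here is just a placeholder to show the API shape.

# Hedged usage sketch, not from this commit; network and hyper-parameter
# values are placeholders.
import mindspore.nn as nn

net = nn.Dense(16, 4)
schedule = nn.ExponentialDecayLR(learning_rate=0.1, decay_rate=0.9, decay_steps=100)
optimizer = nn.ProximalAdagrad(net.trainable_params(),
                               learning_rate=schedule,
                               l1=1e-4, l2=1e-4)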