
Support updating parameters for vm

tags/v0.2.0-alpha
wangnan39@huawei.com 5 years ago
commit b812b18c02
15 changed files with 34 additions and 203 deletions
  1. mindspore/common/parameter.py (+8 -8)
  2. mindspore/nn/optim/adam.py (+0 -12)
  3. mindspore/nn/optim/ftrl.py (+1 -1)
  4. mindspore/nn/optim/lars.py (+0 -17)
  5. mindspore/nn/optim/momentum.py (+3 -17)
  6. mindspore/nn/optim/optimizer.py (+2 -0)
  7. mindspore/nn/optim/rmsprop.py (+2 -19)
  8. mindspore/nn/optim/sgd.py (+3 -17)
  9. mindspore/nn/wrap/cell_wrapper.py (+3 -27)
  10. mindspore/train/serialization.py (+8 -3)
  11. tests/ut/python/nn/test_cell_wrapper.py (+0 -4)
  12. tests/ut/python/nn/test_parameter.py (+0 -74)
  13. tests/ut/python/ops/test_momentum.py (+1 -1)
  14. tests/ut/python/pynative_mode/test_cell_bprop.py (+1 -1)
  15. tests/vm_impl/nn_ops_vm_impl.py (+2 -2)

mindspore/common/parameter.py (+8 -8)

@@ -15,7 +15,6 @@
 """Parameter for cell."""
 from copy import copy, deepcopy
-import numpy as np
 from .initializer import initializer
 from .tensor import Tensor
 from .._checkparam import _check_str_by_regular
@@ -176,14 +175,15 @@ class Parameter:
         return res

     def set_parameter_data(self, data):
-        if isinstance(data, (Tensor, list, int, float,
-                             np.float16, np.float32, np.int32, np.int16, np.ndarray)) and not isinstance(data, bool):
-            if isinstance(data, Tensor):
-                # make a copy of Tensor to init the parameter
-                data = Tensor(data.asnumpy().copy())
-            self.default_input = data
+        """Set `default_input` of current `Parameter`."""
+        if isinstance(data, bool):
+            raise ValueError('Parameter data can not be `bool`')
+        if isinstance(data, Tensor):
+            # make a copy of Tensor to init the parameter
+            data = Tensor(data.asnumpy().copy())
         else:
-            raise ValueError("Parameter data must be tensor or number.")
+            data = Tensor(data)
+        self.default_input = data


 class ParameterTuple(tuple):
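The rewritten `set_parameter_data` no longer whitelists Python and numpy scalar types: anything except `bool` is handed to `Tensor`, and existing tensors are defensively copied. A minimal standalone mirror of that coercion rule, with `numpy.ndarray` standing in for MindSpore's `Tensor` (the helper name is illustrative, not part of the commit):

import numpy as np

def coerce_parameter_data(data):
    """Mirror of the new rule: reject bool, copy array-likes, wrap everything else."""
    if isinstance(data, bool):
        raise ValueError('Parameter data can not be `bool`')
    if isinstance(data, np.ndarray):
        # stands in for Tensor(data.asnumpy().copy()): never alias the caller's buffer
        return np.array(data, copy=True)
    # stands in for Tensor(data): scalars and lists become tensor data
    return np.array(data)

assert coerce_parameter_data(0.9).shape == ()         # scalar -> 0-d array
assert coerce_parameter_data([1, 2, 3]).shape == (3,)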


mindspore/nn/optim/adam.py (+0 -12)

@@ -101,17 +101,6 @@ def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, ep
     return success


-@adam_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
-                   "Tensor")
-def _run_opt_with_two_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1,
-                             moment2):
-    """Apply adam optimizer to the weight parameter using Tensor."""
-    success = True
-    success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                    eps, gradient))
-    return success
-
-
 class Adam(Optimizer):
     r"""
     Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
@@ -183,7 +172,6 @@ class Adam(Optimizer):
         self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
         self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')

-        self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.Adam(use_locking, use_nesterov)
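The removed block follows the same pattern repeated across the optimizer files below: the `MultitypeFuncGraph` overload registered for a `"Number"` learning rate becomes dead code once the base `Optimizer` always hands the overloads a `Tensor` (see optimizer.py). A toy dispatcher sketch of that registration-by-type-signature idea in plain Python (illustrative only, not MindSpore's implementation):

_overloads = {}

def register(*type_tags):
    """Register an implementation under a tuple of argument type tags."""
    def deco(fn):
        _overloads[type_tags] = fn
        return fn
    return deco

@register("Function", "Tensor")
def _run_opt_with_tensor_lr(opt, lr):
    return opt(lr)

def dispatch(tags, *args):
    return _overloads[tags](*args)

# Only the "Tensor" signature is ever selected once the learning rate is a Tensor,
# so a parallel ("Function", "Number") registration would never fire.
print(dispatch(("Function", "Tensor"), lambda lr: "step with lr={}".format(lr), 0.001))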




mindspore/nn/optim/ftrl.py (+1 -1)

@@ -23,7 +23,7 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer, apply_decay, grad_scale

 ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
-@ftrl_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
+@ftrl_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
     """Apply ftrl optimizer to the weight parameter."""
     success = True


mindspore/nn/optim/lars.py (+0 -17)

@@ -43,23 +43,6 @@ def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_f
     return gradient


-@lars_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Bool", "Bool")
-def _tensor_run_opt_v2(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
-    """Apply lars optimizer to the weight parameter."""
-    if lars_flag:
-        op_reduce = P.ReduceSum()
-        w_square_sum = op_reduce(F.square(weight))
-        grad_square_sum = op_reduce(F.square(gradient))
-        if decay_flag:
-            grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate)
-        else:
-            num_zero = 0.0
-            grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, num_zero, learning_rate)
-        return grad_t
-
-    return gradient
-
-
 class LARS(Optimizer):
     """
     Implements the LARS algorithm with LARSUpdate Operator.


mindspore/nn/optim/momentum.py (+3 -17)

@@ -15,19 +15,13 @@
 """momentum"""
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+import mindspore.common.dtype as mstype
 from .optimizer import Optimizer

 momentum_opt = C.MultitypeFuncGraph("momentum_opt")


-@momentum_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt(opt, learning_rate, momentum, gradient, weight, moment):
-    """Apply momentum optimizer to the weight parameter."""
-    success = True
-    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
-    return success
-
-
 @momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
     """Apply momentum optimizer to the weight parameter using Tensor."""
@@ -36,14 +30,6 @@ def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
     return success


-@momentum_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt_dyn(opt, learning_rate, momentum, gradient, weight, moment):
-    """Apply momentum optimizer to the weight parameter using dynamic learning rate."""
-    success = True
-    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
-    return success
-
-
 class Momentum(Optimizer):
     """
     Implements the Momentum algorithm.
@@ -86,7 +72,7 @@ class Momentum(Optimizer):
         super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
-        self.momentum = Parameter(momentum, name="momentum")
+        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
         self.params = self.parameters
         self.moments = self.params.clone(prefix="moments", init='zeros')
         self.hyper_map = C.HyperMap()
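A hedged usage sketch of the result: after this change `Momentum` stores its `momentum` hyper-parameter as a float32 `Tensor` wrapped in a `Parameter`, rather than a raw Python float. Import paths follow the 0.2-era layout shown in the diff and may differ in other releases; `nn.Dense` is just a convenient stand-in network:

import mindspore.nn as nn

net = nn.Dense(3, 2)  # any small cell with trainable parameters
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# momentum is now Parameter(Tensor(0.9, mstype.float32), name="momentum"),
# so its default_input is a Tensor rather than the float 0.9
print(type(opt.momentum.default_input))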


mindspore/nn/optim/optimizer.py (+2 -0)

@@ -22,6 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.nn.cell import Cell
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.initializer import initializer
+import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.common.tensor import Tensor
@@ -64,6 +65,7 @@ class Optimizer(Cell):
             self.assignadd = None
             self.global_step = None
             validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+            learning_rate = Tensor(learning_rate, mstype.float32)
         else:
             self.dynamic_lr = True
             self.gather = P.GatherV2()
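This is the change the optimizer-specific deletions above depend on: a static (scalar) learning rate is validated and then promoted to a float32 `Tensor`, while a per-step schedule still takes the dynamic-lr branch. A plain-Python mirror of that branch, with numpy standing in for `Tensor` (names here are illustrative):

import numpy as np

def normalize_learning_rate(learning_rate):
    """Return (lr, dynamic_lr) the way the patched Optimizer branch does, numpy in place of Tensor."""
    if isinstance(learning_rate, float):
        if not 0.0 <= learning_rate < float("inf"):
            raise ValueError("learning rate should be in the range [0, inf)")
        return np.float32(learning_rate), False      # static lr, now a float32 "tensor"
    # anything else (e.g. a list or 1-D array of per-step rates) is a dynamic schedule
    return np.asarray(learning_rate, dtype=np.float32), True

print(normalize_learning_rate(0.01))
print(normalize_learning_rate([0.1, 0.01, 0.001]))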


mindspore/nn/optim/rmsprop.py (+2 -19)

@@ -21,34 +21,17 @@ rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")


-@rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
-def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
-    """Apply rmsprop optimizer to the weight parameter."""
-    success = True
-    success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
-    return success
-
-
 @rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
-def _rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
+def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
     """Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
     success = True
     success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
     return success


-@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
-                               "Tensor", "Tensor")
-def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
-    """Apply centered rmsprop optimizer to the weight parameter."""
-    success = True
-    success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
-    return success
-
-
 @centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
                                "Tensor", "Tensor")
-def _centered_rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
+def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
     """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate."""
     success = True
     success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))


mindspore/nn/optim/sgd.py (+3 -17)

@@ -15,20 +15,14 @@
 """sgd"""
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from .optimizer import Optimizer

 sgd_opt = C.MultitypeFuncGraph("sgd_opt")


-@sgd_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt(opt, learning_rate, momentum, gradient, weight, accum, stat):
-    """Apply sgd optimizer to the weight parameter."""
-    success = True
-    success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
-    return success
-
-
 @sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, stat):
     """Apply sgd optimizer to the weight parameter using Tensor."""
@@ -37,14 +31,6 @@ def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, s
     return success


-@sgd_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt_dyn(opt, learning_rate, momentum, gradient, weight, accum, stat):
-    """Apply sgd optimizer to the weight parameter using dynamic learning rate."""
-    success = True
-    success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
-    return success
-
-
 class SGD(Optimizer):
     """
     Implements stochastic gradient descent (optionally with momentum).
@@ -105,7 +91,7 @@ class SGD(Optimizer):

         self.opt = P.SGD(dampening, weight_decay, nesterov)

-        self.momentum = Parameter(momentum, name="momentum")
+        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
         self.accum = self.parameters.clone(prefix="accum", init='zeros')
         self.stat = self.parameters.clone(prefix="stat", init='ones')
         self.hyper_map = C.HyperMap()


mindspore/nn/wrap/cell_wrapper.py (+3 -27)

@@ -13,17 +13,10 @@
 # limitations under the License.
 # ============================================================================
 """Cell_wrapper."""
-import copy
-
-import numpy as np
-
 from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
                                        _get_parallel_mode)
 from mindspore.train.parallel_utils import ParallelMode
-
-from ...common import Tensor
 from ...common import dtype as mstype
-from ...common.initializer import initializer
 from ...common.parameter import Parameter, ParameterTuple
 from ...ops import composite as C
 from ...ops import functional as F
@@ -348,25 +341,8 @@ class ParameterUpdate(Cell):
         super(ParameterUpdate, self).__init__(auto_prefix=False)
         if not isinstance(param, Parameter):
             raise TypeError("`param` must be `Parameter`, but got {}".format(param))
-
-        default_input = param.default_input
-        if isinstance(default_input, Tensor):
-            shape = default_input.shape()
-            zero_dtype = default_input.dtype()
-        elif isinstance(default_input, float):
-            shape = [1]
-            zero_dtype = mstype.float32
-        elif isinstance(default_input, int):
-            shape = [1]
-            zero_dtype = mstype.int32
-        else:
-            raise TypeError("`default_input` in `param` must be Tensor, float or int, but got {}".format(default_input))
-
-        self._param = Parameter(initializer(copy.deepcopy(default_input), shape), param.name)
-        self._param.is_init = True
-        self._zero = Tensor(np.zeros(shape), zero_dtype)
+        self._param = param

     def construct(self, x):
-        zero = self._param + self._zero
-        F.control_depend(zero, F.assign(self._param, x))
-        return zero
+        self._param = x
+        return x
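With parameter data now always backed by a `Tensor`, `ParameterUpdate` no longer has to rebuild a shadow parameter plus a zero tensor just to force an assignment; it keeps the original `Parameter` and writes the new value in `construct`. A hedged usage sketch (import path taken from the file in this diff; the exact API may differ between releases):

import numpy as np
from mindspore import Tensor, Parameter
from mindspore.nn.wrap.cell_wrapper import ParameterUpdate

weight = Parameter(Tensor(np.array([1.0, 2.0], dtype=np.float32)), name="weight")
update = ParameterUpdate(weight)

# assigns the new value to `weight` inside construct and returns it
update(Tensor(np.array([3.0, 4.0], dtype=np.float32)))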

mindspore/train/serialization.py (+8 -3)

@@ -36,7 +36,6 @@ tensor_to_ms_type = {"Int8": mstype.int8, "Int16": mstype.int16, "Int32": mstype
 tensor_to_np_type = {"Int8": np.int8, "Int16": np.int16, "Int32": np.int32, "Int64": np.int64,
                      "Float16": np.float16, "Float32": np.float32, "Float64": np.float64}

-
 def _special_process_par(par, new_par):
     """
     Processes the special condition.
@@ -182,8 +181,14 @@ def load_checkpoint(ckpoint_file_name, net=None):
             param_data = np.fromstring(data, np_type)
             dims = element.tensor.dims

-            if dims in [[0], [1]]:
-                parameter_dict[element.tag] = Parameter(param_data[0], name=element.tag)
+            if dims == [0]:
+                if 'Float' in data_type:
+                    param_data = float(param_data[0])
+                elif 'Int' in data_type:
+                    param_data = int(param_data[0])
+                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
+            elif dims == [1]:
+                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
             else:
                 param_dim = []
                 for dim in dims:
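`load_checkpoint` now distinguishes a true scalar entry (`dims == [0]`) from a one-element vector (`dims == [1]`) and wraps both in a `Tensor` instead of returning a raw numpy scalar. A plain-numpy mirror of that branch (numpy stands in for `Tensor`; the final reshape is a simplification of the `param_dim` loop in the real code):

import numpy as np

def rebuild_param_data(param_data, dims, data_type):
    """Mirror of the new scalar handling: dims==[0] -> 0-d value, dims==[1] -> 1-element vector."""
    if dims == [0]:
        scalar = float(param_data[0]) if 'Float' in data_type else int(param_data[0])
        return np.array(scalar)                 # shape ()
    if dims == [1]:
        return np.array(param_data)             # shape (1,)
    return param_data.reshape(dims)             # ordinary N-d parameter (simplified)

print(rebuild_param_data(np.array([0.5], dtype=np.float32), [0], "Float32").shape)  # ()
print(rebuild_param_data(np.array([0.5], dtype=np.float32), [1], "Float32").shape)  # (1,)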


tests/ut/python/nn/test_cell_wrapper.py (+0 -4)

@@ -94,10 +94,6 @@ def test_parameter_update_float32():
 def test_parameter_update_error():
     """ test_parameter_update """
     input_np = np.array([1])
-    input_parameter = Parameter(np.array([1]), 'input_parameter')

     with pytest.raises(TypeError):
         ParameterUpdate(input_np)
-
-    with pytest.raises(TypeError):
-        ParameterUpdate(input_parameter)

tests/ut/python/nn/test_parameter.py (+0 -74)

@@ -52,86 +52,12 @@ def test_parameter_tuple_illegal():


 def test_parameter_init_illegal():
-    import numpy as np
-    dat = np.array([[1, 2, 3], [2, 3, 4]])
-    tensor = Tensor(dat)
-    data_none = None
     data_bool = True
     data_str = "nicai"
-    data_int = 3
-    data_list = [1, "2", True]
-    data_tuple = (1, 2, 3)
-    np_arr_int16 = np.ones([1,1], dtype=np.int16)
-    np_arr_int32 = np.ones([1,1], dtype=np.int32)
-    np_arr_float16 = np.ones([1,1], dtype=np.float16)
-    np_arr_float32 = np.ones([1,1], dtype=np.float32)
-
-    # with pytest.raises(ValueError):
-    #     Parameter(np_arr_int16[0][0], name=data_str)
-    Parameter(np_arr_int32[0], name=data_str)
-    Parameter(np_arr_float16[0], name=data_str)
-    Parameter(np_arr_float32[0], name=data_str)
-    Parameter(np_arr_float32, name=data_str)
-
-    Parameter(tensor, name=data_str)
-    Parameter(data_int, name=data_str)
-    Parameter(dat, name=data_str)
-    with pytest.raises(ValueError):
-        Parameter(data_none, name=data_str)
     with pytest.raises(ValueError):
         Parameter(data_bool, name=data_str)
-    with pytest.raises(ValueError):
-        Parameter(data_str, name=data_str)
-    Parameter(data_list, name=data_str)
-    with pytest.raises(ValueError):
-        Parameter(data_tuple, name=data_str)
-
-    Parameter(tensor, name=data_str)
-    Parameter(tensor, name=data_none)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=dat)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=tensor)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=data_bool)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=data_int)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=data_list)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=data_tuple)
-
-    Parameter(tensor, name=data_str, requires_grad=data_bool)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_none)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=dat)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=tensor)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_str)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_int)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_list)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_tuple)
-
-    Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_bool)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=dat)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=tensor)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_none)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_str)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_int)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_list)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_tuple)


 def test_check_str_by_regular():


tests/ut/python/ops/test_momentum.py (+1 -1)

@@ -31,7 +31,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
 run_opt = C.MultitypeFuncGraph("run_opt")


-@run_opt.register("Function", "Int", "Number", "Number",
+@run_opt.register("Function", "Tensor", "Tensor", "Tensor",
                   "Tensor", "Tensor", "Tensor", "Tensor",
                   "Tensor")
 def tensor_run_opt(opt, iters, learning_rate, momentum,


tests/ut/python/pynative_mode/test_cell_bprop.py (+1 -1)

@@ -51,7 +51,7 @@ class InlineMulADD(nn.Cell):
     def __init__(self):
         super(InlineMulADD, self).__init__()
         self.mul_add = MulAdd()
-        self.param = Parameter(2, 'param')
+        self.param = 2

     def construct(self, x, y):
         return self.mul_add(x, y) + x + self.param * y


tests/vm_impl/nn_ops_vm_impl.py (+2 -2)

@@ -377,8 +377,8 @@ def vm_impl_momentum(self):
         accumulation = accumulation.asnumpy()
         variable = variable.asnumpy()
         shape = accumulation.shape
-        learning_rate = np.full(shape, learning_rate)
-        momentum = np.full(shape, momentum)
+        learning_rate = np.full(shape, learning_rate.asnumpy())
+        momentum = np.full(shape, momentum.asnumpy())
         accumulation = accumulation * momentum + gradient
         if use_nesterov is True:
             variable -= gradient * learning_rate + accumulation * momentum * learning_rate
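Because the optimizers now pass `learning_rate` and `momentum` as Tensors, the vm implementation first unwraps them with `asnumpy()`; the arithmetic itself is the usual momentum update. A runnable numpy mirror of that step (the non-Nesterov branch is not visible in this hunk, so the standard `variable -= accumulation * lr` rule is assumed here):

import numpy as np

def momentum_step(variable, accumulation, gradient, learning_rate, momentum, use_nesterov=False):
    """Numpy mirror of the vm momentum update, after the Tensor hyper-parameters are unwrapped."""
    lr = np.full(variable.shape, learning_rate)   # learning_rate.asnumpy() in the real vm impl
    mom = np.full(variable.shape, momentum)       # momentum.asnumpy() in the real vm impl
    accumulation = accumulation * mom + gradient
    if use_nesterov:
        variable = variable - (gradient * lr + accumulation * mom * lr)
    else:
        variable = variable - accumulation * lr   # assumed non-Nesterov rule, not shown in the hunk
    return variable, accumulation

w, acc, g = np.ones(3), np.zeros(3), np.full(3, 0.5)
print(momentum_step(w, acc, g, learning_rate=0.1, momentum=0.9))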

