
!4058 modify parameter input

Merge pull request !4058 from lijiaqi/cell_inputs
Tag: v0.7.0-beta
Committed by: mindspore-ci-bot (Gitee), 5 years ago
Parent commit: cfae4096d2
4 changed files with 33 additions and 31 deletions
  1. +8  -4   mindspore/nn/cell.py
  2. +7  -7   mindspore/nn/optim/momentum.py
  3. +8  -8   mindspore/nn/optim/sgd.py
  4. +10 -12  mindspore/nn/wrap/cell_wrapper.py

+8 -4  mindspore/nn/cell.py

@@ -383,9 +383,13 @@ class Cell:
             inputs (Function or Cell): inputs of construct method.
         """
         parallel_inputs_run = []
-        if len(inputs) > self._construct_inputs_num:
-            raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'.
-                             format(len(inputs), self._construct_inputs_num))
+        # judge if *args exists in input
+        if self.argspec[1] is not None:
+            prefix = self.argspec[1]
+            for i in range(len(inputs)):
+                key = prefix + str(i)
+                self._construct_inputs_names = self._construct_inputs_names + (key,)
+                self._construct_inputs_num = self._construct_inputs_num + 1
         for i, tensor in enumerate(inputs):
             key = self._construct_inputs_names[i]
             # if input is not used, self.parameter_layout_dict may not contain the key
@@ -412,7 +416,7 @@ class Cell:
         from mindspore._extends.parse.parser import get_parse_method_of_class

         fn = get_parse_method_of_class(self)
-        inspect.getfullargspec(fn)
+        self.argspec = inspect.getfullargspec(fn)
         self._construct_inputs_num = fn.__code__.co_argcount
         self._construct_inputs_names = fn.__code__.co_varnames
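
The second hunk stores the full argspec on the Cell (`self.argspec = inspect.getfullargspec(fn)`), and the first hunk uses `argspec[1]` (the name of the `*args` parameter, or None) to synthesize names for extra positional inputs instead of rejecting them with a ValueError. A minimal standalone sketch of that pattern, using only the standard library and a toy `Net` class (illustrative, not MindSpore's `Cell`):

```python
import inspect

class Net:
    # illustrative construct with a *args parameter, mirroring Cell.construct
    def construct(self, x, *args):
        return x, args

net = Net()
argspec = inspect.getfullargspec(net.construct)   # argspec.varargs == argspec[1] == 'args'

names = net.construct.__code__.co_varnames        # ('self', 'x', 'args')
num = net.construct.__code__.co_argcount          # 2: the *args parameter is not counted

inputs = ("tensor0", "tensor1", "tensor2")        # stand-ins for real Tensors
if argspec.varargs is not None:                   # the hunk's `self.argspec[1] is not None`
    prefix = argspec.varargs
    for i in range(len(inputs)):
        names = names + (prefix + str(i),)        # 'args0', 'args1', 'args2'
        num += 1

print(names, num)   # extra inputs get generated names instead of raising ValueError
```

The trade-off is that the hard arity check is gone: extra positional inputs are now named `<varargs>0`, `<varargs>1`, ... and counted into `_construct_inputs_num`.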




+7 -7  mindspore/nn/optim/momentum.py

@@ -41,7 +41,7 @@ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment,


 class Momentum(Optimizer):
-    """
+    r"""
     Implements the Momentum algorithm.

     Refer to the paper on the importance of initialization and momentum in deep learning for more details.
@@ -56,13 +56,13 @@ class Momentum(Optimizer):
     .. math::
             v_{t} = v_{t-1} \ast u + gradients

-    If use_nesterov is True:
-    .. math::
-            p_{t} = p_{t-1} - (grad \ast lr + v_{t} \ast u \ast lr)
+    If use_nesterov is True:
+        .. math::
+                p_{t} = p_{t-1} - (grad \ast lr + v_{t} \ast u \ast lr)

-    If use_nesterov is Flase:
-    .. math::
-            p_{t} = p_{t-1} - lr \ast v_{t}
+    If use_nesterov is False:
+        .. math::
+                p_{t} = p_{t-1} - lr \ast v_{t}

     Here: where grad, lr, p, v and u denote the gradients, learning_rate, params, moments, and momentum respectively.
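
The docstring formulas above can be sanity-checked with a short NumPy sketch; this is only an illustration of the math, not the fused device kernel the optimizer actually dispatches to, and the variable names simply follow the docstring:

```python
import numpy as np

def momentum_step(p, v, grad, lr, u, use_nesterov):
    """One parameter update following the docstring's formulas."""
    v = v * u + grad                        # v_t = v_{t-1} * u + gradients
    if use_nesterov:
        p = p - (grad * lr + v * u * lr)    # p_t = p_{t-1} - (grad*lr + v_t*u*lr)
    else:
        p = p - lr * v                      # p_t = p_{t-1} - lr * v_t
    return p, v

p, v = np.ones(3), np.zeros(3)
p, v = momentum_step(p, v, grad=np.full(3, 0.1), lr=0.01, u=0.9, use_nesterov=False)
print(p, v)   # v == grad after the first step, p moved by lr * v
```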




+8 -8  mindspore/nn/optim/sgd.py

@@ -32,7 +32,7 @@ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, s


 class SGD(Optimizer):
-    """
+    r"""
     Implements stochastic gradient descent (optionally with momentum).

     Introduction to SGD can be found at https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
@@ -47,15 +47,15 @@ class SGD(Optimizer):
     To improve parameter groups performance, the customized order of parameters can be supported.

     .. math::
-            v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
+        v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)

-    If nesterov is True:
-    .. math::
-            p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
+    If nesterov is True:
+        .. math::
+            p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})

-    If nesterov is Flase:
-    .. math::
-            p_{t+1} = p_{t} - lr \ast v_{t+1}
+    If nesterov is False:
+        .. math::
+            p_{t+1} = p_{t} - lr \ast v_{t+1}

     To be noticed, for the first step, v_{t+1} = gradient
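
Likewise, the SGD docstring formulas, including the special first step where v_{t+1} = gradient, can be illustrated with a small NumPy sketch (an illustration of the math only; names follow the docstring):

```python
import numpy as np

def sgd_step(p, v, grad, lr, u, dampening, nesterov, first_step):
    """One SGD-with-momentum update following the docstring's formulas."""
    if first_step:
        v = grad                              # first step: v_{t+1} = gradient
    else:
        v = u * v + grad * (1 - dampening)    # v_{t+1} = u*v_t + gradient*(1 - dampening)
    if nesterov:
        p = p - lr * (grad + u * v)           # p_{t+1} = p_t - lr*(gradient + u*v_{t+1})
    else:
        p = p - lr * v                        # p_{t+1} = p_t - lr*v_{t+1}
    return p, v

p, v = np.ones(3), np.zeros(3)
p, v = sgd_step(p, v, grad=np.full(3, 0.1), lr=0.01, u=0.9,
                dampening=0.0, nesterov=False, first_step=True)
print(p, v)
```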




+10 -12  mindspore/nn/wrap/cell_wrapper.py

@@ -82,7 +82,7 @@ class WithGradCell(Cell):

     Wraps the network with backward cell to compute gradients. A network with a loss function is necessary
     as argument. If loss function in None, the network must be a wrapper of network and loss function. This
-    Cell accepts data and label as inputs and returns gradients for each trainable parameter.
+    Cell accepts *inputs as inputs and returns gradients for each trainable parameter.

     Note:
         Run in PyNative mode.
@@ -95,8 +95,7 @@ class WithGradCell(Cell):
             output value. Default: None.

     Inputs:
-        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

     Outputs:
         list, a list of Tensors with identical shapes as trainable weights.
@@ -126,12 +125,12 @@ class WithGradCell(Cell):
             self.network_with_loss = WithLossCell(self.network, self.loss_fn)
         self.network_with_loss.set_train()

-    def construct(self, data, label):
+    def construct(self, *inputs):
         weights = self.weights
         if self.sens is None:
-            grads = self.grad(self.network_with_loss, weights)(data, label)
+            grads = self.grad(self.network_with_loss, weights)(*inputs)
         else:
-            grads = self.grad(self.network_with_loss, weights)(data, label, self.sens)
+            grads = self.grad(self.network_with_loss, weights)(*inputs, self.sens)
         return grads
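
Since `construct` now takes `*inputs`, the wrapper simply forwards however many positional tensors the wrapped network-with-loss expects. A hedged usage sketch in PyNative mode; the `Net` definition, loss choice, and shapes are illustrative only:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)   # WithGradCell is documented for PyNative mode

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(4, 1)

    def construct(self, x):
        return self.fc(x)

grad_cell = nn.WithGradCell(Net(), nn.MSELoss())

data = Tensor(np.random.rand(2, 4).astype(np.float32))
label = Tensor(np.random.rand(2, 1).astype(np.float32))
# The two positional arguments are gathered into *inputs and forwarded
# unchanged, so existing (data, label) call sites keep working.
grads = grad_cell(data, label)
```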




@@ -139,7 +138,7 @@ class TrainOneStepCell(Cell):
     r"""
     Network training package class.

-    Wraps the network with an optimizer. The resulting Cell be trained with input data and label.
+    Wraps the network with an optimizer. The resulting Cell can be trained with input *inputs.
     Backward graph will be created in the construct function to do parameter updating. Different
     parallel modes are available to run the training.


@@ -149,8 +148,7 @@ class TrainOneStepCell(Cell):
         sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

     Inputs:
-        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

     Outputs:
         Tensor, a scalar Tensor with shape :math:`()`.
@@ -181,11 +179,11 @@ class TrainOneStepCell(Cell):
             degree = _get_device_num()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

-    def construct(self, data, label):
+    def construct(self, *inputs):
         weights = self.weights
-        loss = self.network(data, label)
+        loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(data, label, sens)
+        grads = self.grad(self.network, weights)(*inputs, sens)
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
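
With `construct(self, *inputs)` the one-step training wrapper is no longer limited to a (data, label) pair. A hedged sketch with a hypothetical three-input network-with-loss; all names and shapes are illustrative:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

class ThreeInputLossNet(nn.Cell):
    """Hypothetical network that already returns a scalar loss from three inputs."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(4, 1)
        self.loss = nn.MSELoss()

    def construct(self, x1, x2, label):
        return self.loss(self.fc(x1 + x2), label)

net_with_loss = ThreeInputLossNet()
opt = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.01, momentum=0.9)
train_step = nn.TrainOneStepCell(net_with_loss, opt)
train_step.set_train()

x1 = Tensor(np.random.rand(2, 4).astype(np.float32))
x2 = Tensor(np.random.rand(2, 4).astype(np.float32))
label = Tensor(np.random.rand(2, 1).astype(np.float32))
loss = train_step(x1, x2, label)   # three tensors forwarded via *inputs
```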

