|
|
|
@@ -18,7 +18,6 @@ from types import FunctionType, MethodType |
|
|
|
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean, |
|
|
|
_get_parallel_mode) |
|
|
|
from mindspore.context import ParallelMode |
|
|
|
from ...common.tensor import Tensor |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ...common.parameter import Parameter, ParameterTuple |
|
|
|
from ...ops import composite as C |
|
|
|
@@ -197,15 +196,16 @@ class ForwardValueAndGrad(Cell): |
|
|
|
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically. |
|
|
|
Default: False. |
|
|
|
If sens_param is True, a sensitivity (gradient with respect to output) needs to be passed through |
|
|
|
a positional argument or a keyword argument. If the value is passed through a keyword |
|
|
|
parameter, the key must be sens. |
|
|
|
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. |
|
|
|
the input parameter. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. |
|
|
|
- **(\*sens)** - A sensitivity (gradient with respect to output) as the input of backpropagation. |
|
|
|
If network has single output, the sens is a tensor. |
|
|
|
If network has multiple outputs, the sens is the tuple(tensor). |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **forward value** (a scalar Tensor with shape :math:`()`) - The result of network forward running. |
|
|
|
- **forward value** - The result of network forward running. |
|
|
|
- **gradients** (tuple(tensor)) - The gradients of network parameters and inputs. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
@@ -219,8 +219,8 @@ class ForwardValueAndGrad(Cell): |
|
|
|
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() |
|
|
|
>>> #1) Using the existing WithLossCell provided by nn |
|
|
|
>>> loss_net = nn.WithLossCell(net, loss_fn) |
|
|
|
>>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weight, get_by_list=True, sens_param=True) |
|
|
|
>>> loss, grads = forward_value_and_grad(inputs, labels, 1.0) |
|
|
|
>>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True) |
|
|
|
>>> loss, grads = forward_value_and_grad(inputs, labels) |
|
|
|
>>> |
|
|
|
>>> #2) Using user-defined WithLossCell |
|
|
|
>>> class MyWithLossCell(Cell): |
|
|
|
@@ -238,40 +238,40 @@ class ForwardValueAndGrad(Cell): |
|
|
|
... return self._backbone |
|
|
|
... |
|
|
|
>>> loss_net = MyWithLossCell(net, loss_fn) |
|
|
|
>>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weight, get_by_list=True, sens_param=True) |
|
|
|
>>> loss, grads = forward_value_and_grad(inputs, labels, 1.0) |
|
|
|
>>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True) |
|
|
|
>>> loss, grads = forward_value_and_grad(inputs, labels) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False, sens=1.0): |
|
|
|
def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False): |
|
|
|
super(ForwardValueAndGrad, self).__init__(auto_prefix=False) |
|
|
|
if not isinstance(network, (Cell, FunctionType, MethodType)): |
|
|
|
raise TypeError(f"The type of training network should be cell, function type or method type, " |
|
|
|
f"but got '{type(network)}'") |
|
|
|
if not isinstance(get_all, bool): |
|
|
|
raise TypeError(f"The type of get_all should be bool, but got '{type(get_all)}'") |
|
|
|
if not isinstance(get_by_list, bool): |
|
|
|
raise TypeError(f"The type of get_by_list should be bool, but got '{type(get_by_list)}'") |
|
|
|
if get_by_list and not isinstance(weights, ParameterTuple): |
|
|
|
raise TypeError(f"When get_by_list is set to True, the parameters of training network should be " |
|
|
|
f"ParameterTuple type, but got '{type(weights)}'") |
|
|
|
if get_by_list is not True and weights is not None: |
|
|
|
raise TypeError(f"When get_by_list is set to False, the parameters of training network should be " |
|
|
|
f"NoneType, but got '{type(weights)}'") |
|
|
|
self.network = network |
|
|
|
self.network.set_grad() |
|
|
|
if isinstance(network, Cell): |
|
|
|
self.network.set_grad() |
|
|
|
self.weights = weights |
|
|
|
self.get_all = get_all |
|
|
|
self.get_by_list = get_by_list |
|
|
|
self.sens_param = sens_param |
|
|
|
self.sens = sens |
|
|
|
self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param) |
|
|
|
|
|
|
|
def construct(self, *inputs): |
|
|
|
weights = self.weights |
|
|
|
loss = self.network(*inputs) |
|
|
|
grad_inputs = inputs |
|
|
|
if self.sens_param: |
|
|
|
sens = self.sens |
|
|
|
if not isinstance(self.sens, Tensor): |
|
|
|
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) |
|
|
|
grads = self.grad(self.network, weights)(*inputs, sens) |
|
|
|
inputs = inputs[:-1] |
|
|
|
loss = self.network(*inputs) |
|
|
|
if self.get_by_list: |
|
|
|
grads = self.grad(self.network, self.weights)(*grad_inputs) |
|
|
|
else: |
|
|
|
grads = self.grad(self.network, weights)(*inputs) |
|
|
|
grads = self.grad(self.network)(*grad_inputs) |
|
|
|
return loss, grads |
|
|
|
|
|
|
|
|
|
|
|
|