
optimize SoftmaxCrossEntropyWithLogits and momentum

tags/v0.3.0-alpha
zhaojichen 5 years ago
commit ff710dde5d
3 changed files with 20 additions and 7 deletions
  1. +13 -3 mindspore/nn/loss/loss.py
  2. +5 -2 mindspore/nn/optim/momentum.py
  3. +2 -2 mindspore/ops/operations/nn_ops.py

+13 -3 mindspore/nn/loss/loss.py

@@ -18,6 +18,8 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.nn.cell import Cell
+from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
 from ... import context

@@ -215,6 +217,8 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
         sparse (bool): Specifies whether labels use sparse format or not. Default: False.
         reduction (Union[str, None]): Type of reduction to apply to loss. Support 'sum' or 'mean'. If None,
             do not apply reduction. Default: None.
+        smooth_factor (float): Label smoothing factor. It is an optional input. Default: 0.
+        num_classes (int): The number of classes in the task. It is an optional input. Default: 2.

     Inputs:
         - **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
@@ -235,14 +239,20 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
     def __init__(self,
                  is_grad=True,
                  sparse=False,
-                 reduction=None):
+                 reduction=None,
+                 smooth_factor=0,
+                 num_classes=2):
         super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
         self.is_grad = is_grad
         self.sparse = sparse
+        validator.check_integer("num_classes", num_classes, 1, Rel.GT, self.cls_name)
+        validator.check_number_range("smooth_factor", smooth_factor, 0, 1, Rel.INC_BOTH, self.cls_name)
+        self.smooth_factor = smooth_factor
+        self.num_classes = num_classes
         self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
         self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.on_value = Tensor(1.0 - self.smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * self.smooth_factor / (self.num_classes - 1), mstype.float32)
         self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"]

         if self.is_cpugpu:
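In effect, the loss.py change adds label smoothing: with smoothing factor s and C classes, the target for the true class becomes 1 - s and each of the other C - 1 classes receives s / (C - 1), so every target row still sums to 1. A minimal NumPy sketch of the encoding that the new on_value/off_value produce (a standalone illustration; the helper name is hypothetical):

    import numpy as np

    def smoothed_one_hot(labels, num_classes, smooth_factor=0.0):
        # Mirrors on_value = 1 - smooth_factor and
        # off_value = smooth_factor / (num_classes - 1) from the diff above.
        on_value = 1.0 - smooth_factor
        off_value = smooth_factor / (num_classes - 1)
        targets = np.full((len(labels), num_classes), off_value, dtype=np.float32)
        targets[np.arange(len(labels)), labels] = on_value
        return targets

    # smooth_factor=0.1, num_classes=3: label 0 maps to [0.9, 0.05, 0.05].
    print(smoothed_one_hot(np.array([0, 2]), num_classes=3, smooth_factor=0.1))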


+5 -2 mindspore/nn/optim/momentum.py

@@ -17,6 +17,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
+from mindspore._checkparam import check_bool
 from .optimizer import Optimizer

 momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -67,6 +68,7 @@ class Momentum(Optimizer):
         momentum (float): Hyperparameter of type float, means momentum for the moving average.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. Default: 1.0.
+        use_nesterov (bool): Enable Nesterov momentum. Default: False.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -95,15 +97,16 @@ class Momentum(Optimizer):
         >>> loss = nn.SoftmaxCrossEntropyWithLogits()
         >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
     """
-    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0):
+    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
         super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
         self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
         self.params = self.parameters
+        self.use_nesterov = check_bool(use_nesterov)
         self.moments = self.params.clone(prefix="moments", init='zeros')
         self.hyper_map = C.HyperMap()
-        self.opt = P.ApplyMomentum()
+        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)

     def construct(self, gradients):
         params = self.params
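For context, ApplyMomentum keeps one accumulator per parameter, and the new use_nesterov flag switches the update to Nesterov's look-ahead form. A rough NumPy sketch of the conventional semantics of this kind of fused momentum op (an assumption about the kernel's behavior, not MindSpore's actual implementation):

    import numpy as np

    def apply_momentum(var, accum, grad, lr, momentum, use_nesterov=False):
        # Accumulate the momentum-weighted running gradient.
        accum = momentum * accum + grad
        if use_nesterov:
            # Nesterov: step along the current gradient plus the
            # look-ahead momentum term.
            var = var - lr * (grad + momentum * accum)
        else:
            # Classic heavy-ball update.
            var = var - lr * accum
        return var, accum

    var, accum = np.zeros(3), np.zeros(3)
    var, accum = apply_momentum(var, accum, np.ones(3), lr=0.1, momentum=0.9,
                                use_nesterov=True)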


+2 -2 mindspore/ops/operations/nn_ops.py

@@ -1757,8 +1757,8 @@ class LayerNorm(Primitive):

         - **output_x** (Tensor) - The normalized input, has the same type and shape as the `input_x`.
           The shape is :math:`(N, C)`.
-        - **updated_gamma** (Tensor) - Tensor of shape :math:`(C,)`.
-        - **updated_beta** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.

     Examples:
         >>> input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
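This doc fix reflects that LayerNorm's second and third outputs are the statistics computed during normalization, not updated parameters. A small NumPy illustration of how output_x relates to mean and variance (the epsilon value here is illustrative, not the operator's default):

    import numpy as np

    x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
    mean = x.mean(axis=-1, keepdims=True)      # per-sample mean over the normalized axis
    variance = x.var(axis=-1, keepdims=True)   # per-sample variance
    output_x = (x - mean) / np.sqrt(variance + 1e-7)  # same shape as x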

