@@ -18,6 +18,8 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.nn.cell import Cell
+from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
 from ... import context
@@ -215,6 +217,8 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
         sparse (bool): Specifies whether labels use sparse format or not. Default: False.
         reduction (Union[str, None]): Type of reduction to apply to loss. Support 'sum' or 'mean' If None,
             do not reduction. Default: None.
+        smooth_factor (float): Label smoothing factor. It is an optional input. Default: 0.
+        num_classes (int): The number of classes in the task. It is an optional input. Default: 2.
     Inputs:
         - **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
@@ -235,14 +239,20 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
     def __init__(self,
                  is_grad=True,
                  sparse=False,
-                 reduction=None):
+                 reduction=None,
+                 smooth_factor=0,
+                 num_classes=2):
         super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
         self.is_grad = is_grad
         self.sparse = sparse
+        validator.check_integer("num_classes", num_classes, 1, Rel.GT, self.cls_name)
+        validator.check_number_range("smooth_factor", smooth_factor, 0, 1, Rel.INC_BOTH, self.cls_name)
+        self.smooth_factor = smooth_factor
+        self.num_classes = num_classes
         self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
         self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.on_value = Tensor(1.0 - self.smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * self.smooth_factor / (self.num_classes - 1), mstype.float32)
         self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"]
         if self.is_cpugpu:
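For context on the `on_value`/`off_value` change above: with label smoothing, the one-hot target keeps `1 - smooth_factor` on the true class and spreads `smooth_factor` evenly over the remaining `num_classes - 1` classes, so each target row still sums to 1. Below is a minimal NumPy sketch of that construction (`smoothed_one_hot` is a hypothetical helper for illustration, not the fused MindSpore one-hot kernel):

```python
import numpy as np

def smoothed_one_hot(labels, num_classes, smooth_factor=0.0):
    """Build label-smoothed one-hot targets.

    The true class gets 1 - smooth_factor (on_value) and every other class
    gets smooth_factor / (num_classes - 1) (off_value), so each row sums to 1.
    """
    on_value = 1.0 - smooth_factor
    off_value = smooth_factor / (num_classes - 1)
    targets = np.full((len(labels), num_classes), off_value, dtype=np.float32)
    targets[np.arange(len(labels)), labels] = on_value
    return targets

# e.g. label 0 with smooth_factor=0.1, num_classes=4 -> [0.9, 0.0333, 0.0333, 0.0333]
print(smoothed_one_hot(np.array([0, 2]), num_classes=4, smooth_factor=0.1))
```

With `smooth_factor=0` this reduces to the previous behaviour (`on_value=1.0`, `off_value=0.0`).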
@@ -17,6 +17,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
+from mindspore._checkparam import check_bool
 from .optimizer import Optimizer

 momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -67,6 +68,7 @@ class Momentum(Optimizer):
         momentum (float): Hyperparameter of type float, means momentum for the moving average.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. Default: 1.0.
+        use_nesterov (bool): Enable Nesterov momentum. Default: False.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -95,15 +97,16 @@ class Momentum(Optimizer):
         >>> loss = nn.SoftmaxCrossEntropyWithLogits()
         >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
     """
-    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0):
+    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
         super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
         self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
         self.params = self.parameters
+        self.use_nesterov = check_bool(use_nesterov)
         self.moments = self.params.clone(prefix="moments", init='zeros')
         self.hyper_map = C.HyperMap()
-        self.opt = P.ApplyMomentum()
+        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)

     def construct(self, gradients):
         params = self.params
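For reference, `use_nesterov=True` switches the fused `ApplyMomentum` kernel to a Nesterov-style update. A rough, framework-agnostic sketch of the conventional update rule is shown below; the actual fused kernel may differ in details such as learning-rate scaling:

```python
def momentum_step(param, grad, moment, lr, momentum, use_nesterov=False):
    """One SGD-with-momentum step; a sketch, not the fused ApplyMomentum kernel."""
    moment = momentum * moment + grad  # running accumulation of gradients
    if use_nesterov:
        # Nesterov: step with the current gradient plus the look-ahead momentum term
        param = param - lr * (grad + momentum * moment)
    else:
        # Plain momentum: step with the accumulated moment only
        param = param - lr * moment
    return param, moment
```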
@@ -1757,8 +1757,8 @@ class LayerNorm(Primitive):
         - **output_x** (Tensor) - The normalized input, has the same type and shape as the `input_x`.
           The shape is :math:`(N, C)`.
-        - **updated_gamma** (Tensor) - Tensor of shape :math:`(C,)`.
-        - **updated_beta** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.

     Examples:
         >>> input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
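As the renamed outputs above indicate, `P.LayerNorm` returns a 3-tuple whose last two elements are the normalization statistics rather than updated affine parameters. A minimal unpacking sketch modelled on the docstring's own example (argument shapes and dtypes as documented there):

```python
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
gamma = Tensor(np.ones([3]), mindspore.float32)
beta = Tensor(np.zeros([3]), mindspore.float32)
layer_norm = P.LayerNorm()
# Per the updated docstring, the second and third outputs are the
# normalization statistics (mean, variance), not updated gamma/beta.
output_x, mean, variance = layer_norm(input_x, gamma, beta)
```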