| @@ -456,6 +456,17 @@ def get_bprop_smooth_l1_loss(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.L2Loss) | |||||
| def get_bprop_l2_loss(self): | |||||
| """Grad definition for `L2Loss` operation.""" | |||||
| def bprop(x, out, dout): | |||||
| dx = x * dout | |||||
| return (dx,) | |||||
| return bprop | |||||
| @bprop_getters.register(P.PReLU) | @bprop_getters.register(P.PReLU) | ||||
| def get_bprop_prelu(self): | def get_bprop_prelu(self): | ||||
| """Grad definition for `PReLU` operation.""" | """Grad definition for `PReLU` operation.""" | ||||
| @@ -117,6 +117,7 @@ from .layer_norm_beta_gamma_backprop import _layer_norm_beta_gamma_backprop_tbe | |||||
| from .layer_norm import _layer_norm_tbe | from .layer_norm import _layer_norm_tbe | ||||
| from .layer_norm_grad import _layer_norm_grad_tbe | from .layer_norm_grad import _layer_norm_grad_tbe | ||||
| from .layer_norm_x_backprop import _layer_norm_x_backprop_tbe | from .layer_norm_x_backprop import _layer_norm_x_backprop_tbe | ||||
| from .l2_loss import _l2_loss_tbe | |||||
| from .square_sum_v1 import _square_sum_v1_tbe | from .square_sum_v1 import _square_sum_v1_tbe | ||||
| from .square_sum_v2 import _square_sum_v2_tbe | from .square_sum_v2 import _square_sum_v2_tbe | ||||
| from .confusion_transpose_d import _confusion_transpose_d_tbe | from .confusion_transpose_d import _confusion_transpose_d_tbe | ||||
| @@ -0,0 +1,44 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """L2Loss op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| l2_loss_op_info = TBERegOp("L2Loss") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("l2_loss.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("l2_loss") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "x", None, "required", None) \ | |||||
| .output(0, "y", True, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_FracZ, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(l2_loss_op_info) | |||||
| def _l2_loss_tbe(): | |||||
| """L2Loss TBE register""" | |||||
| return | |||||
| @@ -55,7 +55,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, | |||||
| DropoutDoMask, | DropoutDoMask, | ||||
| DropoutGenMask, Flatten, FusedBatchNorm, | DropoutGenMask, Flatten, FusedBatchNorm, | ||||
| Gelu, Elu, | Gelu, Elu, | ||||
| GetNext, L2Normalize, LayerNorm, | |||||
| GetNext, L2Normalize, LayerNorm, L2Loss, | |||||
| LogSoftmax, | LogSoftmax, | ||||
| MaxPool, ExtractImagePatches, | MaxPool, ExtractImagePatches, | ||||
| AvgPool, Conv2DBackpropInput, | AvgPool, Conv2DBackpropInput, | ||||
| @@ -167,6 +167,7 @@ __all__ = [ | |||||
| 'FloatStatus', | 'FloatStatus', | ||||
| 'Reciprocal', | 'Reciprocal', | ||||
| 'SmoothL1Loss', | 'SmoothL1Loss', | ||||
| 'L2Loss', | |||||
| 'ReduceAll', | 'ReduceAll', | ||||
| 'ScalarToArray', | 'ScalarToArray', | ||||
| 'ScalarToTensor', | 'ScalarToTensor', | ||||
| @@ -1332,6 +1332,41 @@ class SmoothL1Loss(PrimitiveWithInfer): | |||||
| return prediction | return prediction | ||||
| class L2Loss(PrimitiveWithInfer): | |||||
| """ | |||||
| Calculates half of the L2 norm of a tensor without using the `sqrt`. | |||||
| Set `input_x` as x and output as loss. | |||||
| .. math:: | |||||
| loss = sum(x ** 2) / 2 | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - A input Tensor. | |||||
| Outputs: | |||||
| Tensor. Has the same dtype as `input_x`. The output tensor is the value of loss which is a scalar tensor. | |||||
| Examples | |||||
| >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16) | |||||
| >>> l2_loss = P.L2Loss() | |||||
| >>> l2_loss(input_x) | |||||
| 7.0 | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init L2Loss""" | |||||
| def infer_shape(self, input_x): | |||||
| loss_shape = [] | |||||
| return loss_shape | |||||
| def infer_dtype(self, x_type): | |||||
| validator.check_subclass("x_type", x_type, mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({'x_type': x_type}, [mstype.double, mstype.float_, mstype.float16], self.name) | |||||
| return x_type | |||||
| class SGD(PrimitiveWithInfer): | class SGD(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Computes stochastic gradient descent (optionally with momentum). | Computes stochastic gradient descent (optionally with momentum). | ||||
| @@ -871,6 +871,14 @@ test_case_nn_ops = [ | |||||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], | 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], | ||||
| 'desc_bprop': [3, 3], | 'desc_bprop': [3, 3], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('L2Loss_1', { | |||||
| 'block': P.L2Loss(), | |||||
| 'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float16)], | |||||
| 'desc_bprop': []}), | |||||
| ('L2Loss_2', { | |||||
| 'block': P.L2Loss(), | |||||
| 'desc_inputs': [Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)], | |||||
| 'desc_bprop': []}), | |||||
| ] | ] | ||||
| test_case_array_ops = [ | test_case_array_ops = [ | ||||