diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index ae730d78a7..887c2a7528 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -456,6 +456,17 @@ def get_bprop_smooth_l1_loss(self):
     return bprop
 
 
+@bprop_getters.register(P.L2Loss)
+def get_bprop_l2_loss(self):
+    """Grad definition for `L2Loss` operation."""
+
+    def bprop(x, out, dout):
+        dx = x * dout
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(P.PReLU)
 def get_bprop_prelu(self):
     """Grad definition for `PReLU` operation."""
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 2cffc37491..37da184869 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -117,6 +117,7 @@ from .layer_norm_beta_gamma_backprop import _layer_norm_beta_gamma_backprop_tbe
 from .layer_norm import _layer_norm_tbe
 from .layer_norm_grad import _layer_norm_grad_tbe
 from .layer_norm_x_backprop import _layer_norm_x_backprop_tbe
+from .l2_loss import _l2_loss_tbe
 from .square_sum_v1 import _square_sum_v1_tbe
 from .square_sum_v2 import _square_sum_v2_tbe
 from .confusion_transpose_d import _confusion_transpose_d_tbe
diff --git a/mindspore/ops/_op_impl/tbe/l2_loss.py b/mindspore/ops/_op_impl/tbe/l2_loss.py
new file mode 100644
index 0000000000..7d1394ad64
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/l2_loss.py
@@ -0,0 +1,44 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""L2Loss op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+l2_loss_op_info = TBERegOp("L2Loss") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("l2_loss.so") \
+    .compute_cost(10) \
+    .kernel_name("l2_loss") \
+    .partial_flag(True) \
+    .input(0, "x", None, "required", None) \
+    .output(0, "y", True, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_Default) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_Default) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_Default) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(l2_loss_op_info)
+def _l2_loss_tbe():
+    """L2Loss TBE register"""
+    return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 1f0ee8a04d..2860690b91 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -55,7 +55,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      DropoutDoMask, DropoutGenMask, Flatten, FusedBatchNorm,
                      Gelu, Elu,
-                     GetNext, L2Normalize, LayerNorm,
+                     GetNext, L2Normalize, LayerNorm, L2Loss,
                      LogSoftmax,
                      MaxPool, ExtractImagePatches,
                      AvgPool, Conv2DBackpropInput,
@@ -167,6 +167,7 @@ __all__ = [
     'FloatStatus',
     'Reciprocal',
     'SmoothL1Loss',
+    'L2Loss',
     'ReduceAll',
     'ScalarToArray',
     'ScalarToTensor',
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index ed9f0742e8..6f39fdd2ae 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1332,6 +1332,41 @@ class SmoothL1Loss(PrimitiveWithInfer):
         return prediction
 
 
+class L2Loss(PrimitiveWithInfer):
+    """
+    Calculates half of the L2 norm of a tensor without taking the square root.
+
+    Let `input_x` be x and the output be loss.
+
+    .. math::
+        loss = sum(x ** 2) / 2
+
+    Inputs:
+        - **input_x** (Tensor) - An input Tensor.
+
+    Outputs:
+        Tensor. Has the same dtype as `input_x`. The output is a scalar tensor holding the loss value.
+
+    Examples:
+        >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16)
+        >>> l2_loss = P.L2Loss()
+        >>> l2_loss(input_x)
+        7.0
+    """
+    @prim_attr_register
+    def __init__(self):
+        """init L2Loss"""
+
+    def infer_shape(self, input_x):
+        loss_shape = []
+        return loss_shape
+
+    def infer_dtype(self, x_type):
+        validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
+        validator.check_tensor_type_same({'x_type': x_type}, [mstype.double, mstype.float_, mstype.float16], self.name)
+        return x_type
+
+
 class SGD(PrimitiveWithInfer):
     """
     Computes stochastic gradient descent (optionally with momentum).
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 1a79935467..1bd3a2e438 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -871,6 +871,14 @@ test_case_nn_ops = [
         'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
         'desc_bprop': [3, 3],
         'skip': ['backward']}),
+    ('L2Loss_1', {
+        'block': P.L2Loss(),
+        'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float16)],
+        'desc_bprop': []}),
+    ('L2Loss_2', {
+        'block': P.L2Loss(),
+        'desc_inputs': [Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)],
+        'desc_bprop': []}),
 ]
 
 test_case_array_ops = [
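
For reference, a minimal sketch (not part of the patch) of what the new primitive and its registered bprop compute; it assumes an Ascend build that includes this change, running in PyNative mode:

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

x_np = np.array([1, 2, 3], dtype=np.float16)

# Forward: loss = sum(x ** 2) / 2 = (1 + 4 + 9) / 2 = 7.0
l2_loss = P.L2Loss()
print(l2_loss(Tensor(x_np)))  # expected: 7.0

# Backward: the bprop registered above returns dx = x * dout, so with an
# incoming gradient dout = 1.0 the gradient of the loss w.r.t. x is x itself.
dout = np.float16(1.0)
print(x_np * dout)  # expected: [1. 2. 3.]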