
Add L2Loss op for VM

tags/v0.2.0-alpha
liuxiao committed 5 years ago, commit c874e2d484
6 changed files with 101 additions and 1 deletion
  1. +11 -0 mindspore/ops/_grad/grad_nn_ops.py
  2. +1 -0 mindspore/ops/_op_impl/tbe/__init__.py
  3. +44 -0 mindspore/ops/_op_impl/tbe/l2_loss.py
  4. +2 -1 mindspore/ops/operations/__init__.py
  5. +35 -0 mindspore/ops/operations/nn_ops.py
  6. +8 -0 tests/ut/python/ops/test_ops.py

+11 -0 mindspore/ops/_grad/grad_nn_ops.py

@@ -456,6 +456,17 @@ def get_bprop_smooth_l1_loss(self):
    return bprop




@bprop_getters.register(P.L2Loss)
def get_bprop_l2_loss(self):
    """Grad definition for `L2Loss` operation."""

    def bprop(x, out, dout):
        # d/dx (sum(x ** 2) / 2) = x, scaled by the incoming gradient dout
        dx = x * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.PReLU)
def get_bprop_prelu(self):
    """Grad definition for `PReLU` operation."""


+1 -0 mindspore/ops/_op_impl/tbe/__init__.py

@@ -117,6 +117,7 @@ from .layer_norm_beta_gamma_backprop import _layer_norm_beta_gamma_backprop_tbe
from .layer_norm import _layer_norm_tbe
from .layer_norm_grad import _layer_norm_grad_tbe
from .layer_norm_x_backprop import _layer_norm_x_backprop_tbe
from .l2_loss import _l2_loss_tbe
from .square_sum_v1 import _square_sum_v1_tbe
from .square_sum_v2 import _square_sum_v2_tbe
from .confusion_transpose_d import _confusion_transpose_d_tbe


+44 -0 mindspore/ops/_op_impl/tbe/l2_loss.py

@@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""L2Loss op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# Each dtype_format pairs a supported input dtype/layout with the
# matching default-format scalar output of the same dtype.
l2_loss_op_info = TBERegOp("L2Loss") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("l2_loss.so") \
    .compute_cost(10) \
    .kernel_name("l2_loss") \
    .partial_flag(True) \
    .input(0, "x", None, "required", None) \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_Default) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_Default) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_Default) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_Default) \
    .get_op_info()


@op_info_register(l2_loss_op_info)
def _l2_loss_tbe():
    """L2Loss TBE register"""
    return
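Every input layout in the table above (Default, FracZ, FracNZ, 5HD, C1HWNCoC0) maps to a default-format scalar output of the same dtype: because L2Loss reduces over every element, the physical layout of the input cannot change the result. A minimal NumPy illustration of that layout independence (not the TBE kernel itself):

import numpy as np

x = np.arange(8, dtype=np.float32).reshape(2, 4)
print((x * x).sum() / 2)      # 70.0
print((x.T * x.T).sum() / 2)  # 70.0, same value under a transposed layout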

+2 -1 mindspore/ops/operations/__init__.py

@@ -55,7 +55,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                     DropoutDoMask,
                     DropoutGenMask, Flatten, FusedBatchNorm,
                     Gelu, Elu,
                     GetNext, L2Normalize, LayerNorm,
                     GetNext, L2Normalize, LayerNorm, L2Loss,
                     LogSoftmax,
                     MaxPool, ExtractImagePatches,
                     AvgPool, Conv2DBackpropInput,
@@ -167,6 +167,7 @@ __all__ = [
    'FloatStatus',
    'Reciprocal',
    'SmoothL1Loss',
    'L2Loss',
    'ReduceAll',
    'ScalarToArray',
    'ScalarToTensor',


+35 -0 mindspore/ops/operations/nn_ops.py

@@ -1332,6 +1332,41 @@ class SmoothL1Loss(PrimitiveWithInfer):
        return prediction




class L2Loss(PrimitiveWithInfer):
    """
    Calculates half of the L2 norm of a tensor without using the `sqrt`.

    Set `input_x` as x and output as loss.

    .. math::
        loss = sum(x ** 2) / 2

    Inputs:
        - **input_x** (Tensor) - An input Tensor.

    Outputs:
        Tensor. A scalar tensor holding the value of the loss, with the same dtype as `input_x`.

    Examples:
        >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16)
        >>> l2_loss = P.L2Loss()
        >>> l2_loss(input_x)
        7.0
    """

    @prim_attr_register
    def __init__(self):
        """init L2Loss"""

    def infer_shape(self, input_x):
        # the loss is a scalar, so the output shape is empty
        loss_shape = []
        return loss_shape

    def infer_dtype(self, x_type):
        validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
        validator.check_tensor_type_same({'x_type': x_type}, [mstype.double, mstype.float_, mstype.float16], self.name)
        return x_type


class SGD(PrimitiveWithInfer):
    """
    Computes stochastic gradient descent (optionally with momentum).
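The docstring example checks out: for input [1, 2, 3], sum(x ** 2) / 2 = (1 + 4 + 9) / 2 = 7.0, which is exactly representable in float16. A standalone NumPy sketch of the same computation (illustrative only):

import numpy as np

x = np.array([1, 2, 3], dtype=np.float16)
print((x * x).sum() / 2)  # 7.0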


+8 -0 tests/ut/python/ops/test_ops.py

@@ -871,6 +871,14 @@ test_case_nn_ops = [
        'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
        'desc_bprop': [3, 3],
        'skip': ['backward']}),
    ('L2Loss_1', {
        'block': P.L2Loss(),
        'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float16)],
        'desc_bprop': []}),
    ('L2Loss_2', {
        'block': P.L2Loss(),
        'desc_inputs': [Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)],
        'desc_bprop': []}),
]


test_case_array_ops = [
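For reference, the expected forward values for the two new cases, computed with NumPy (not part of the test file): L2Loss halves the sum of squared elements.

import numpy as np

a = np.array([1, 2, 3, 4], dtype=np.float16)
b = np.array([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=np.float16)
print((a * a).sum() / 2)  # 15.0 for L2Loss_1
print((b * b).sum() / 2)  # 30.0 for L2Loss_2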

