
vm for LRN and LRNGrad

jiangjinsheng committed 5 years ago · commit a1e148cb4d · tags/v0.6.0-beta
9 changed files with 195 additions and 2 deletions
  1. mindspore/ccsrc/kernel/tbe/tbe_adapter.cc (+2, -0)
  2. mindspore/ops/_grad/grad_nn_ops.py (+12, -0)
  3. mindspore/ops/_op_impl/tbe/__init__.py (+2, -0)
  4. mindspore/ops/_op_impl/tbe/lrn.py (+41, -0)
  5. mindspore/ops/_op_impl/tbe/lrn_grad.py (+42, -0)
  6. mindspore/ops/operations/__init__.py (+3, -2)
  7. mindspore/ops/operations/_grad_ops.py (+19, -0)
  8. mindspore/ops/operations/nn_ops.py (+41, -0)
  9. tests/ut/python/ops/test_nn_ops.py (+33, -0)

mindspore/ccsrc/kernel/tbe/tbe_adapter.cc (+2, -0)

@@ -107,6 +107,8 @@ static std::map<string, string> tbe_func_adapter_map = {
{"r_oi_align_grad", "roi_align_grad"},
{"i_ou", "iou"},
{"s_gd", "sgd"},
{"l_rn", "lrn"},
{"l_rn_grad", "lrn_grad"},
{"l_ars_update", "lars_v2_update"},
{"n_ms_with_mask", "nms_with_mask"},
{"square_sum_all", "square_sum_all"},


mindspore/ops/_grad/grad_nn_ops.py (+12, -0)

@@ -721,3 +721,15 @@ def get_bprop_basic_lstm_cell(self):
        dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
        return dxt, dht, dct_1, dw, db
    return bprop


@bprop_getters.register(P.LRN)
def get_bprop_lrn(self):
    """Grad definition for `LRN` operation."""
    grad = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta)

    def bprop(x, out, dout):
        dx = grad(dout, x, out)
        return (dx,)

    return bprop
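With this bprop registered, the gradient of LRN with respect to its input can be obtained through the standard composite grad machinery. A minimal sketch, assuming the 0.6-era GradOperation constructor signature (it changed in later releases); this snippet is an illustration, not part of the commit.

```python
# Sketch: exercise the LRN bprop registered above via GradOperation.
# Assumes the 0.6-era API, where GradOperation takes a name as its first argument.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P


class LRNNet(nn.Cell):
    def __init__(self):
        super(LRNNet, self).__init__()
        self.lrn = P.LRN()

    def construct(self, x):
        return self.lrn(x)


class GradWrap(nn.Cell):
    """Returns the gradients of `network` with respect to all of its inputs."""
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        self.grad = C.GradOperation('get_all', get_all=True)

    def construct(self, x):
        return self.grad(self.network)(x)


x = Tensor(np.ones([1, 5, 4, 4], np.float32))
dx = GradWrap(LRNNet())(x)  # tuple with one entry, same shape as x
```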

mindspore/ops/_op_impl/tbe/__init__.py (+2, -0)

@@ -267,3 +267,5 @@ from .lin_space import _lin_space_tbe
from .matrix_diag import _matrix_diag_tbe
from .matrix_diag_part import _matrix_diag_part_tbe
from .matrix_set_diag import _matrix_set_diag_tbe
from .lrn import _lrn_tbe
from .lrn_grad import _lrn_grad_tbe

mindspore/ops/_op_impl/tbe/lrn.py (+41, -0)

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""LRN op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

lrn_op_info = TBERegOp("LRN") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("lrn.so") \
    .compute_cost(10) \
    .kernel_name("lrn") \
    .partial_flag(True) \
    .attr("depth_radius", "optional", "int", "all", "5") \
    .attr("bias", "optional", "float", "all", "1.0") \
    .attr("alpha", "optional", "float", "all", "1.0") \
    .attr("beta", "optional", "float", "all", "0.5") \
    .attr("norm_region", "optional", "str", "all", "ACROSS_CHANNELS") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
    .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
    .get_op_info()


@op_info_register(lrn_op_info)
def _lrn_tbe():
    """LRN TBE register"""
    return

mindspore/ops/_op_impl/tbe/lrn_grad.py (+42, -0)

@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""LRNGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

lrn_grad_op_info = TBERegOp("LRNGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("lrn_grad.so") \
    .compute_cost(10) \
    .kernel_name("lrn_grad") \
    .partial_flag(True) \
    .attr("depth_radius", "optional", "int", "all") \
    .attr("bias", "optional", "float", "all") \
    .attr("alpha", "optional", "float", "all") \
    .attr("beta", "optional", "float", "all") \
    .input(0, "grads", False, "required", "all") \
    .input(1, "x", False, "required", "all") \
    .input(2, "y", False, "required", "all") \
    .output(0, "z", False, "required", "all") \
    .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW) \
    .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
    .get_op_info()


@op_info_register(lrn_grad_op_info)
def _lrn_grad_tbe():
    """LRNGrad TBE register"""
    return

mindspore/ops/operations/__init__.py (+3, -2)

@@ -68,7 +68,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
                     ResizeBilinear, Sigmoid,
                     SigmoidCrossEntropyWithLogits,
-                    SmoothL1Loss, Softmax, Softplus,
+                    SmoothL1Loss, Softmax, Softplus, LRN,
                     SoftmaxCrossEntropyWithLogits, ROIAlign,
                     SparseSoftmaxCrossEntropyWithLogits, Tanh,
                     TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
@@ -316,7 +316,8 @@ __all__ = [
    "DataFormatDimMap",
    "ApproximateEqual",
    "InplaceUpdate",
-   "InTopK"
+   "InTopK",
+   "LRN"
]

__all__.sort()

mindspore/ops/operations/_grad_ops.py (+19, -0)

@@ -1364,3 +1364,22 @@ class InvGrad(PrimitiveWithInfer):
        validator.check_type_name("dgate", x, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
        validator.check_type_name("grad", grad, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
        return x


class LRNGrad(PrimitiveWithInfer):
    """Computes gradients for LRN operation."""
    @prim_attr_register
    def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
        self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z'])
        validator.check_value_type("depth_radius", depth_radius, [int], self.name)
        validator.check_value_type("bias", bias, [float], self.name)
        validator.check_value_type("alpha", alpha, [float], self.name)
        validator.check_value_type("beta", beta, [float], self.name)

    def infer_dtype(self, grads, x, y):
        args = {"grads": grads, "x": x, "y": y}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name)
        return x

    def infer_shape(self, grads, x, y):
        return x
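LRNGrad is normally reached only through the bprop registered in grad_nn_ops.py above, which passes (dout, x, out). For completeness, a hedged sketch of calling the primitive directly, using the module path the unit tests below rely on; attribute values here are arbitrary examples.

```python
# Sketch, not part of this commit: direct invocation of LRNGrad with the
# (grads, x, y) input order declared via init_prim_io_names above.
import numpy as np
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

lrn_grad = G.LRNGrad(depth_radius=2, bias=1.0, alpha=1e-4, beta=0.75)
dout = Tensor(np.ones([1, 5, 4, 4], np.float32))  # gradient w.r.t. LRN's output
x = Tensor(np.ones([1, 5, 4, 4], np.float32))     # original LRN input
y = Tensor(np.ones([1, 5, 4, 4], np.float32))     # original LRN output
dx = lrn_grad(dout, x, y)                         # same shape and dtype as x
```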

mindspore/ops/operations/nn_ops.py (+41, -0)

@@ -4252,3 +4252,44 @@ class InTopK(PrimitiveWithInfer):
validator.check("x2", len(x2_shape), "", 1, Rel.EQ, self.name)
validator.check("size of x2", x2_shape[0], "x1's first dimension", x1_shape[0], Rel.EQ, self.name)
return x2_shape


class LRN(PrimitiveWithInfer):
r"""
Local Response Normalization

Args:
depth_radius (int): Half-width of the 1-D normalization window. Shape of 0-D.
bias (float): An offset (usually positive to avoid dividing by 0).
alpha (float): A scale factor, usually positive.
beta (float): An exponent.
norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS", "WITHIN_CHANNEL".
Default: "ACROSS_CHANNELS".

Inputs:
- **x** (Tensor) - A 4D Tensor with float16 or float32 data type.

Outputs:
Tensor, With shape and data type same as the input tensor.

Examples:
>>> x = Tensor(np.random.rand(1, 10, 4, 4)), mindspore.float32)
>>> lrn = P.LRN()
>>> lrn(x)
"""
@prim_attr_register
def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"):
"""Init LRN"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
validator.check_value_type("depth_radius", depth_radius, [int], self.name)
validator.check_value_type("bias", bias, [float], self.name)
validator.check_value_type("alpha", alpha, [float], self.name)
validator.check_value_type("beta", beta, [float], self.name)
validator.check_value_type("norm_region", norm_region, [str], self.name)

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name)
return x_dtype

def infer_shape(self, x_shape):
return x_shape
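The attributes mirror TensorFlow-style LRN. Assuming the conventional ACROSS_CHANNELS semantics (an assumption for illustration, not the kernel's actual implementation), a naive NumPy reference of what the op computes on an NCHW input looks like this:

```python
# Naive NumPy reference for ACROSS_CHANNELS LRN on an NCHW tensor.
# Assumed semantics, mirroring tf.nn.lrn-style attributes; not part of this commit.
import numpy as np

def lrn_reference(x, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
    """x: NCHW float array; normalize each position across neighbouring channels."""
    n, c, h, w = x.shape
    out = np.empty_like(x)
    for ch in range(c):
        lo, hi = max(0, ch - depth_radius), min(c, ch + depth_radius + 1)
        sqr_sum = np.sum(np.square(x[:, lo:hi, :, :]), axis=1)
        out[:, ch, :, :] = x[:, ch, :, :] / np.power(bias + alpha * sqr_sum, beta)
    return out

x = np.random.rand(1, 10, 4, 4).astype(np.float32)
print(lrn_reference(x).shape)  # (1, 10, 4, 4)
```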

tests/ut/python/ops/test_nn_ops.py (+33, -0)

@@ -482,6 +482,29 @@ class PReLUGradNet(nn.Cell):
    def construct(self, dout, x, w):
        return self.prelu_grad(dout, x, w)


class LRNNet(nn.Cell):
    """ LRNNet definition """

    def __init__(self):
        super(LRNNet, self).__init__()
        self.lrn = P.LRN()

    def construct(self, x):
        return self.lrn(x)


class LRNGradNet(nn.Cell):
    """ LRNGradNet definition """

    def __init__(self):
        super(LRNGradNet, self).__init__()
        self.lrn_grad = G.LRNGrad()

    def construct(self, dout, x, out):
        return self.lrn_grad(dout, x, out)

test_cases = [
    ('SoftMaxGrad', {
        'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())),
@@ -593,6 +616,16 @@ test_cases = [
                        Tensor(np.array([1, 2]).astype(np.float32))],
        'skip': ['backward']
    }),
    ('LRNNet', {
        'block': LRNNet(),
        'desc_inputs': [Tensor(np.ones([1, 5, 4, 4], np.float32))],
    }),
    ('LRNGradNet', {
        'block': LRNGradNet(),
        'desc_inputs': [Tensor(np.ones([1, 5, 4, 4], np.float32)),
                        Tensor(np.ones([1, 5, 4, 4], np.float32)),
                        Tensor(np.ones([1, 5, 4, 4], np.float32))],
    }),
]

test_cases_for_verify_exception = [

