diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index b7bad4fff8..deb858ff39 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -107,6 +107,8 @@ static std::map tbe_func_adapter_map = { {"r_oi_align_grad", "roi_align_grad"}, {"i_ou", "iou"}, {"s_gd", "sgd"}, + {"l_rn", "lrn"}, + {"l_rn_grad", "lrn_grad"}, {"l_ars_update", "lars_v2_update"}, {"n_ms_with_mask", "nms_with_mask"}, {"square_sum_all", "square_sum_all"}, diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 13fb89b23f..00b9e3051b 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -721,3 +721,15 @@ def get_bprop_basic_lstm_cell(self): dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) return dxt, dht, dct_1, dw, db return bprop + + +@bprop_getters.register(P.LRN) +def get_bprop_lrn(self): + """Grad definition for `LRN` operation.""" + grad = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta) + + def bprop(x, out, dout): + dx = grad(dout, x, out) + return (dx,) + + return bprop diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 631ec1bf44..7207e5ee69 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -267,3 +267,5 @@ from .lin_space import _lin_space_tbe from .matrix_diag import _matrix_diag_tbe from .matrix_diag_part import _matrix_diag_part_tbe from .matrix_set_diag import _matrix_set_diag_tbe +from .lrn import _lrn_tbe +from .lrn_grad import _lrn_grad_tbe diff --git a/mindspore/ops/_op_impl/tbe/lrn.py b/mindspore/ops/_op_impl/tbe/lrn.py new file mode 100644 index 0000000000..2f22684f09 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/lrn.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""LRN op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +lrn_op_info = TBERegOp("LRN") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("lrn.so") \ + .compute_cost(10) \ + .kernel_name("lrn") \ + .partial_flag(True) \ + .attr("depth_radius", "optional", "int", "all", "5") \ + .attr("bias", "optional", "float", "all", "1.0") \ + .attr("alpha", "optional", "float", "all", "1.0") \ + .attr("beta", "optional", "float", "all", "0.5") \ + .attr("norm_region", "optional", "str", "all", "ACROSS_CHANNELS") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \ + .get_op_info() + + +@op_info_register(lrn_op_info) +def _lrn_tbe(): + """LRN TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/lrn_grad.py b/mindspore/ops/_op_impl/tbe/lrn_grad.py new file mode 100644 index 0000000000..4d37cf741b --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/lrn_grad.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""LRNGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +lrn_grad_op_info = TBERegOp("LRNGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("lrn_grad.so") \ + .compute_cost(10) \ + .kernel_name("lrn_grad") \ + .partial_flag(True) \ + .attr("depth_radius", "optional", "int", "all") \ + .attr("bias", "optional", "float", "all") \ + .attr("alpha", "optional", "float", "all") \ + .attr("beta", "optional", "float", "all") \ + .input(0, "grads", False, "required", "all") \ + .input(1, "x", False, "required", "all") \ + .input(2, "y", False, "required", "all") \ + .output(0, "z", False, "required", "all") \ + .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW) \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ + .get_op_info() + + +@op_info_register(lrn_grad_op_info) +def _lrn_grad_tbe(): + """LRNGrad TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 6193292316..f9a6ee39a9 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -68,7 +68,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, ResizeBilinear, Sigmoid, SigmoidCrossEntropyWithLogits, - SmoothL1Loss, Softmax, Softplus, + SmoothL1Loss, Softmax, Softplus, LRN, SoftmaxCrossEntropyWithLogits, ROIAlign, SparseSoftmaxCrossEntropyWithLogits, Tanh, TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, @@ -316,7 +316,8 @@ __all__ = [ "DataFormatDimMap", "ApproximateEqual", "InplaceUpdate", - "InTopK" + "InTopK", + "LRN" ] __all__.sort() diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index c3f97b9f33..94b514cd0f 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -1364,3 +1364,22 @@ class InvGrad(PrimitiveWithInfer): validator.check_type_name("dgate", x, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) validator.check_type_name("grad", grad, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) return x + + +class LRNGrad(PrimitiveWithInfer): + """Computes gradients for LRN operation.""" + @prim_attr_register + def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5): + self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z']) + validator.check_value_type("depth_radius", depth_radius, [int], self.name) + validator.check_value_type("bias", bias, [float], self.name) + validator.check_value_type("alpha", alpha, [float], self.name) + validator.check_value_type("beta", beta, [float], self.name) + + def infer_dtype(self, grads, x, y): + args = {"grads": grads, "x": x, "y": y} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name) + return x + + def infer_shape(self, grads, x, y): + return x diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 521607ffb9..28944f8b4e 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -4252,3 +4252,44 @@ class InTopK(PrimitiveWithInfer): validator.check("x2", len(x2_shape), "", 1, Rel.EQ, self.name) validator.check("size of x2", x2_shape[0], "x1's first dimension", x1_shape[0], Rel.EQ, self.name) return x2_shape + + +class LRN(PrimitiveWithInfer): + r""" + Local Response Normalization + + Args: + depth_radius (int): Half-width of the 1-D normalization window. Shape of 0-D. + bias (float): An offset (usually positive to avoid dividing by 0). + alpha (float): A scale factor, usually positive. + beta (float): An exponent. + norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS", "WITHIN_CHANNEL". + Default: "ACROSS_CHANNELS". + + Inputs: + - **x** (Tensor) - A 4D Tensor with float16 or float32 data type. + + Outputs: + Tensor, With shape and data type same as the input tensor. + + Examples: + >>> x = Tensor(np.random.rand(1, 10, 4, 4)), mindspore.float32) + >>> lrn = P.LRN() + >>> lrn(x) + """ + @prim_attr_register + def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"): + """Init LRN""" + self.init_prim_io_names(inputs=['x'], outputs=['y']) + validator.check_value_type("depth_radius", depth_radius, [int], self.name) + validator.check_value_type("bias", bias, [float], self.name) + validator.check_value_type("alpha", alpha, [float], self.name) + validator.check_value_type("beta", beta, [float], self.name) + validator.check_value_type("norm_region", norm_region, [str], self.name) + + def infer_dtype(self, x_dtype): + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name) + return x_dtype + + def infer_shape(self, x_shape): + return x_shape diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index e950707234..ed7a8e695e 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -482,6 +482,29 @@ class PReLUGradNet(nn.Cell): def construct(self, dout, x, w): return self.prelu_grad(dout, x, w) + +class LRNNet(nn.Cell): + """ LRNNet definition """ + + def __init__(self): + super(LRNNet, self).__init__() + self.lrn = P.LRN() + + def construct(self, x): + return self.lrn(x) + + +class LRNGradNet(nn.Cell): + """ LRNGradNet definition """ + + def __init__(self): + super(LRNGradNet, self).__init__() + self.lrn_grad = G.LRNGrad() + + def construct(self, dout, x, out): + return self.lrn_grad(dout, x, out) + + test_cases = [ ('SoftMaxGrad', { 'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())), @@ -593,6 +616,16 @@ test_cases = [ Tensor(np.array([1, 2]).astype(np.float32))], 'skip': ['backward'] }), + ('LRNNet', { + 'block': LRNNet(), + 'desc_inputs': [Tensor(np.ones([1, 5, 4, 4], np.float32))], + }), + ('LRNGradNet', { + 'block': LRNGradNet(), + 'desc_inputs': [Tensor(np.ones([1, 5, 4, 4], np.float32)), + Tensor(np.ones([1, 5, 4, 4], np.float32)), + Tensor(np.ones([1, 5, 4, 4], np.float32))], + }), ] test_cases_for_verify_exception = [