| @@ -107,6 +107,8 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"r_oi_align_grad", "roi_align_grad"}, | |||
| {"i_ou", "iou"}, | |||
| {"s_gd", "sgd"}, | |||
| {"l_rn", "lrn"}, | |||
| {"l_rn_grad", "lrn_grad"}, | |||
| {"l_ars_update", "lars_v2_update"}, | |||
| {"n_ms_with_mask", "nms_with_mask"}, | |||
| {"square_sum_all", "square_sum_all"}, | |||
| @@ -721,3 +721,15 @@ def get_bprop_basic_lstm_cell(self): | |||
| dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) | |||
| return dxt, dht, dct_1, dw, db | |||
| return bprop | |||
| @bprop_getters.register(P.LRN) | |||
| def get_bprop_lrn(self): | |||
| """Grad definition for `LRN` operation.""" | |||
| grad = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta) | |||
| def bprop(x, out, dout): | |||
| dx = grad(dout, x, out) | |||
| return (dx,) | |||
| return bprop | |||
| @@ -267,3 +267,5 @@ from .lin_space import _lin_space_tbe | |||
| from .matrix_diag import _matrix_diag_tbe | |||
| from .matrix_diag_part import _matrix_diag_part_tbe | |||
| from .matrix_set_diag import _matrix_set_diag_tbe | |||
| from .lrn import _lrn_tbe | |||
| from .lrn_grad import _lrn_grad_tbe | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LRN op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| lrn_op_info = TBERegOp("LRN") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("lrn.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("lrn") \ | |||
| .partial_flag(True) \ | |||
| .attr("depth_radius", "optional", "int", "all", "5") \ | |||
| .attr("bias", "optional", "float", "all", "1.0") \ | |||
| .attr("alpha", "optional", "float", "all", "1.0") \ | |||
| .attr("beta", "optional", "float", "all", "0.5") \ | |||
| .attr("norm_region", "optional", "str", "all", "ACROSS_CHANNELS") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \ | |||
| .get_op_info() | |||
| @op_info_register(lrn_op_info) | |||
| def _lrn_tbe(): | |||
| """LRN TBE register""" | |||
| return | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LRNGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| lrn_grad_op_info = TBERegOp("LRNGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("lrn_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("lrn_grad") \ | |||
| .partial_flag(True) \ | |||
| .attr("depth_radius", "optional", "int", "all") \ | |||
| .attr("bias", "optional", "float", "all") \ | |||
| .attr("alpha", "optional", "float", "all") \ | |||
| .attr("beta", "optional", "float", "all") \ | |||
| .input(0, "grads", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .input(2, "y", False, "required", "all") \ | |||
| .output(0, "z", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW) \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ | |||
| .get_op_info() | |||
| @op_info_register(lrn_grad_op_info) | |||
| def _lrn_grad_tbe(): | |||
| """LRNGrad TBE register""" | |||
| return | |||
| @@ -68,7 +68,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl | |||
| MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, | |||
| ResizeBilinear, Sigmoid, | |||
| SigmoidCrossEntropyWithLogits, | |||
| SmoothL1Loss, Softmax, Softplus, | |||
| SmoothL1Loss, Softmax, Softplus, LRN, | |||
| SoftmaxCrossEntropyWithLogits, ROIAlign, | |||
| SparseSoftmaxCrossEntropyWithLogits, Tanh, | |||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, | |||
| @@ -316,7 +316,8 @@ __all__ = [ | |||
| "DataFormatDimMap", | |||
| "ApproximateEqual", | |||
| "InplaceUpdate", | |||
| "InTopK" | |||
| "InTopK", | |||
| "LRN" | |||
| ] | |||
| __all__.sort() | |||
| @@ -1364,3 +1364,22 @@ class InvGrad(PrimitiveWithInfer): | |||
| validator.check_type_name("dgate", x, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) | |||
| validator.check_type_name("grad", grad, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) | |||
| return x | |||
| class LRNGrad(PrimitiveWithInfer): | |||
| """Computes gradients for LRN operation.""" | |||
| @prim_attr_register | |||
| def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5): | |||
| self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z']) | |||
| validator.check_value_type("depth_radius", depth_radius, [int], self.name) | |||
| validator.check_value_type("bias", bias, [float], self.name) | |||
| validator.check_value_type("alpha", alpha, [float], self.name) | |||
| validator.check_value_type("beta", beta, [float], self.name) | |||
| def infer_dtype(self, grads, x, y): | |||
| args = {"grads": grads, "x": x, "y": y} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name) | |||
| return x | |||
| def infer_shape(self, grads, x, y): | |||
| return x | |||
| @@ -4252,3 +4252,44 @@ class InTopK(PrimitiveWithInfer): | |||
| validator.check("x2", len(x2_shape), "", 1, Rel.EQ, self.name) | |||
| validator.check("size of x2", x2_shape[0], "x1's first dimension", x1_shape[0], Rel.EQ, self.name) | |||
| return x2_shape | |||
| class LRN(PrimitiveWithInfer): | |||
| r""" | |||
| Local Response Normalization | |||
| Args: | |||
| depth_radius (int): Half-width of the 1-D normalization window. Shape of 0-D. | |||
| bias (float): An offset (usually positive to avoid dividing by 0). | |||
| alpha (float): A scale factor, usually positive. | |||
| beta (float): An exponent. | |||
| norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS", "WITHIN_CHANNEL". | |||
| Default: "ACROSS_CHANNELS". | |||
| Inputs: | |||
| - **x** (Tensor) - A 4D Tensor with float16 or float32 data type. | |||
| Outputs: | |||
| Tensor, With shape and data type same as the input tensor. | |||
| Examples: | |||
| >>> x = Tensor(np.random.rand(1, 10, 4, 4)), mindspore.float32) | |||
| >>> lrn = P.LRN() | |||
| >>> lrn(x) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"): | |||
| """Init LRN""" | |||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | |||
| validator.check_value_type("depth_radius", depth_radius, [int], self.name) | |||
| validator.check_value_type("bias", bias, [float], self.name) | |||
| validator.check_value_type("alpha", alpha, [float], self.name) | |||
| validator.check_value_type("beta", beta, [float], self.name) | |||
| validator.check_value_type("norm_region", norm_region, [str], self.name) | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| @@ -482,6 +482,29 @@ class PReLUGradNet(nn.Cell): | |||
| def construct(self, dout, x, w): | |||
| return self.prelu_grad(dout, x, w) | |||
| class LRNNet(nn.Cell): | |||
| """ LRNNet definition """ | |||
| def __init__(self): | |||
| super(LRNNet, self).__init__() | |||
| self.lrn = P.LRN() | |||
| def construct(self, x): | |||
| return self.lrn(x) | |||
| class LRNGradNet(nn.Cell): | |||
| """ LRNGradNet definition """ | |||
| def __init__(self): | |||
| super(LRNGradNet, self).__init__() | |||
| self.lrn_grad = G.LRNGrad() | |||
| def construct(self, dout, x, out): | |||
| return self.lrn_grad(dout, x, out) | |||
| test_cases = [ | |||
| ('SoftMaxGrad', { | |||
| 'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())), | |||
| @@ -593,6 +616,16 @@ test_cases = [ | |||
| Tensor(np.array([1, 2]).astype(np.float32))], | |||
| 'skip': ['backward'] | |||
| }), | |||
| ('LRNNet', { | |||
| 'block': LRNNet(), | |||
| 'desc_inputs': [Tensor(np.ones([1, 5, 4, 4], np.float32))], | |||
| }), | |||
| ('LRNGradNet', { | |||
| 'block': LRNGradNet(), | |||
| 'desc_inputs': [Tensor(np.ones([1, 5, 4, 4], np.float32)), | |||
| Tensor(np.ones([1, 5, 4, 4], np.float32)), | |||
| Tensor(np.ones([1, 5, 4, 4], np.float32))], | |||
| }), | |||
| ] | |||
| test_cases_for_verify_exception = [ | |||