GitOrigin-RevId: 939a4d26dd
tags/v1.7.0
@@ -69,6 +69,7 @@ __all__ = [
    "leaky_relu",
    "linear",
    "local_conv2d",
    "local_response_norm",
    "logsigmoid",
    "logsumexp",
    "logsoftmax",
@@ -1746,6 +1747,53 @@ def pad(
    return output


def local_response_norm(
    inp: Tensor,
    kernel_size: int = 5,
    k: float = 2.0,
    alpha: float = 1e-4,
    beta: float = 0.75,
) -> Tensor:
| r""" | |||
| Apply local response normalization to the input tensor. | |||
| Args: | |||
| kernel_size: the size of the kernel to apply LRN on. | |||
| k: hyperparameter k. The default vaule is 2.0. | |||
| alpha: hyperparameter alpha. The default value is 1e-4. | |||
| beta: hyperparameter beta. The default value is 0.75. | |||
    Example:

    .. testcode::

        from megengine import tensor
        import megengine.functional as f
        import numpy as np

        inp = tensor(np.arange(25, dtype=np.float32).reshape(1, 1, 5, 5))
        GT = np.array([[[[ 0.,         0.999925,   1.9994003,  2.9979765,  3.9952066],
                         [ 4.9906454,  5.983851,   6.974385,   7.961814,   8.945709 ],
                         [ 9.925651,  10.90122,   11.872011,  12.837625,  13.7976675],
                         [14.751757,  15.699524,  16.640602,  17.574642,  18.501305 ],
                         [19.420258,  20.331186,  21.233786,  22.127764,  23.012836 ]]]])

        out = f.local_response_norm(inp, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)
        np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6)
        print('pass')

    Outputs:

    .. testoutput::

        pass

    """
    op = builtin.LRN(n=kernel_size, k=k, alpha=alpha, beta=beta)
    (output,) = apply(op, inp)
    return output


@lru_cache(maxsize=None)
def _get_layerPixelShuffle(device, dtype, dim_order):
    @subgraph("LayerPixelShuffle", dtype, device, 3)
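For reference, the GT values in the docstring example follow the across-channel LRN convention in which the squared sum is not scaled by the window size. Below is a minimal NumPy sketch of that formula; it is an illustrative check under that assumption (boundary clipping at the channel edges is likewise inferred from the example), not part of the patch:

    import numpy as np

    def lrn_reference(inp, kernel_size=5, k=2.0, alpha=1e-4, beta=0.75):
        # Each element is divided by (k + alpha * sum of squares over a window
        # of `kernel_size` neighboring channels) ** beta.
        n, c, h, w = inp.shape
        half = kernel_size // 2
        out = np.empty_like(inp)
        for ch in range(c):
            lo, hi = max(0, ch - half), min(c, ch + half + 1)
            sq_sum = (inp[:, lo:hi] ** 2).sum(axis=1)
            out[:, ch] = inp[:, ch] / (k + alpha * sq_sum) ** beta
        return out

    x = np.arange(25, dtype=np.float32).reshape(1, 1, 5, 5)
    print(lrn_reference(x, kernel_size=3, k=1.0))  # reproduces the GT array above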
@@ -29,6 +29,7 @@ from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .lrn import LocalResponseNorm
from .module import Module
from .normalization import GroupNorm, InstanceNorm, LayerNorm
from .padding import Pad
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..functional import local_response_norm
from .module import Module


class LocalResponseNorm(Module):
    r"""
    Apply local response normalization to the input tensor.
    Args:
        kernel_size: the size of the normalization window applied along the channel dimension. The default value is 5.
        k: hyperparameter k. The default value is 2.0.
        alpha: hyperparameter alpha. The default value is 1e-4.
        beta: hyperparameter beta. The default value is 0.75.
    Example:

    .. testcode::

        from megengine import tensor
        import megengine.module as M
        import numpy as np

        inp = tensor(np.arange(25, dtype=np.float32).reshape(1, 1, 5, 5))
        GT = np.array([[[[ 0.,         0.999925,   1.9994003,  2.9979765,  3.9952066],
                         [ 4.9906454,  5.983851,   6.974385,   7.961814,   8.945709 ],
                         [ 9.925651,  10.90122,   11.872011,  12.837625,  13.7976675],
                         [14.751757,  15.699524,  16.640602,  17.574642,  18.501305 ],
                         [19.420258,  20.331186,  21.233786,  22.127764,  23.012836 ]]]])

        op = M.LocalResponseNorm(kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)
        out = op(inp)
        np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6)
        print('pass')

    Outputs:

    .. testoutput::

        pass

    """

    def __init__(
        self,
        kernel_size: int = 5,
        k: float = 2.0,
        alpha: float = 1e-4,
        beta: float = 0.75,
        **kwargs
    ):
        super(LocalResponseNorm, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.k = k
        self.alpha = alpha
        self.beta = beta

    def forward(self, inp):
        return local_response_norm(inp, self.kernel_size, self.k, self.alpha, self.beta)
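Since `forward` simply forwards the stored hyperparameters to the functional API, the module and functional paths should agree exactly. A small illustrative check (assuming a build with this patch applied; not part of the patch itself):

    import numpy as np
    import megengine.functional as F
    import megengine.module as M
    from megengine import tensor

    x = tensor(np.random.rand(2, 8, 4, 4).astype(np.float32))
    m = M.LocalResponseNorm(kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)

    # Both calls construct the same builtin.LRN op, so the results should match.
    np.testing.assert_allclose(
        m(x).numpy(),
        F.local_response_norm(x, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75).numpy(),
        rtol=1e-6,
        atol=1e-6,
    )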
@@ -21,6 +21,7 @@
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
@@ -654,4 +655,13 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
}
OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace padding

namespace lrn {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LRN&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::LRN::make(inputs[0], op.param());
}
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace lrn

}  // namespace mgb::imperative
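The `apply_on_var_node` registration above is what lets the imperative `LRN` op def lower to the graph-level `opr::LRN` when a function is traced symbolically. A hedged sketch of exercising that path from Python (assuming the usual `megengine.jit.trace` API and a build with this patch; not part of the patch):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor
    from megengine.jit import trace

    @trace(symbolic=True)
    def fn(x):
        # Under symbolic tracing the LRN op def is dispatched through the
        # registered apply_on_var_node and built as opr::LRN in the graph.
        return F.local_response_norm(x, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)

    out = fn(tensor(np.arange(25, dtype=np.float32).reshape(1, 1, 5, 5)))
    print(out.numpy().shape)  # (1, 1, 5, 5)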
@@ -422,4 +422,6 @@ def Split: MgbHashableOp<"Split", [EmptyParam]> {
def Padding: MgbHashableOp<"Padding", [PaddingParam]>;

def LRN: MgbHashableOp<"LRN", [LRNParam]>;

#endif // MGB_OPS