diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index 794833086..56e7516a3 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -6,6 +6,7 @@ * [BinaryOp](#binaryop) * [BNLL](#bnll) * [Cast](#cast) +* [CELU](#celu) * [Clip](#clip) * [Concat](#concat) * [Convolution](#convolution) @@ -197,6 +198,19 @@ Element type: - 3 = int8 - 4 = bfloat16 +# CELU +``` +if x < 0 y = (exp(x / alpha) - 1.f) * alpha +else y = x +``` + +* one_blob_only +* support_inplace + +| param id | name | type | default | description | +| --------- | ------------- | ----- | --------- | ----------------- | +| 0 | alpha | float | 1.f | | + # Clip ``` y = clamp(x, min, max) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 255f20b72..4a4ea24e6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -164,6 +164,7 @@ ncnn_add_layer(CumulativeSum) ncnn_add_layer(CopyTo) ncnn_add_layer(Erf) ncnn_add_layer(Diag) +ncnn_add_layer(CELU) if(NCNN_VULKAN) ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp) diff --git a/src/layer/celu.cpp b/src/layer/celu.cpp new file mode 100644 index 000000000..4bddfc368 --- /dev/null +++ b/src/layer/celu.cpp @@ -0,0 +1,56 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "celu.h" + +#include + +namespace ncnn { + +CELU::CELU() +{ + one_blob_only = true; + support_inplace = true; +} + +int CELU::load_param(const ParamDict& pd) +{ + alpha = pd.get(0, 1.f); + + return 0; +} + +int CELU::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int channels = bottom_top_blob.c; + int size = w * h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + if (ptr[i] < 0.f) + ptr[i] = (expf(ptr[i] / alpha) - 1.f) * alpha; + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/celu.h b/src/layer/celu.h new file mode 100644 index 000000000..5e4b9c87f --- /dev/null +++ b/src/layer/celu.h @@ -0,0 +1,37 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_CELU_H +#define LAYER_CELU_H + +#include "layer.h" + +namespace ncnn { + +class CELU : public Layer +{ +public: + CELU(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +public: + float alpha; +}; + +} // namespace ncnn + +#endif // LAYER_CELU_H diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 897e78865..21de08c6f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -72,6 +72,7 @@ ncnn_add_layer_test(Bias) ncnn_add_layer_test(BinaryOp) ncnn_add_layer_test(BNLL) ncnn_add_layer_test(Cast) +ncnn_add_layer_test(CELU) ncnn_add_layer_test(Clip) ncnn_add_layer_test(Concat) ncnn_add_layer_test(Convolution) diff --git a/tests/test_celu.cpp b/tests/test_celu.cpp new file mode 100644 index 000000000..79b0bbe18 --- /dev/null +++ b/tests/test_celu.cpp @@ -0,0 +1,66 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "layer/celu.h" +#include "testutil.h" + +static int test_celu(const ncnn::Mat& a, float alpha) +{ + ncnn::ParamDict pd; + pd.set(0, alpha); + + std::vector weights(0); + + int ret = test_layer("CELU", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_celu failed a.dims=%d a=(%d %d %d) alpha=%f\n", a.dims, a.w, a.h, a.c, alpha); + } + + return ret; +} + +static int test_celu_0() +{ + return 0 + || test_celu(RandomMat(5, 7, 24), 1.f) + || test_celu(RandomMat(7, 9, 12), 0.5f) + || test_celu(RandomMat(3, 5, 13), 0.2f); +} + +static int test_celu_1() +{ + return 0 + || test_celu(RandomMat(15, 24), 1.f) + || test_celu(RandomMat(17, 12), 0.5f) + || test_celu(RandomMat(19, 15), 0.2f); +} + +static int test_celu_2() +{ + return 0 + || test_celu(RandomMat(128), 1.f) + || test_celu(RandomMat(124), 0.5f) + || test_celu(RandomMat(127), 0.2f); +} + +int main() +{ + SRAND(7767517); + + return 0 + || test_celu_0() + || test_celu_1() + || test_celu_2(); +} diff --git a/tools/pnnx/README.md b/tools/pnnx/README.md index 93eca067d..a9789f166 100644 --- a/tools/pnnx/README.md +++ b/tools/pnnx/README.md @@ -471,7 +471,7 @@ TORCH_LIBRARY(upfirdn2d_op, m) { |nn.BatchNorm2d | :heavy_check_mark: | :heavy_check_mark: | |nn.BatchNorm3d | :heavy_check_mark: | :heavy_check_mark: | |nn.Bilinear | | -|nn.CELU | :heavy_check_mark: | +|nn.CELU | :heavy_check_mark: | :heavy_check_mark: | |nn.ChannelShuffle | :heavy_check_mark: | :heavy_check_mark: | |nn.ConstantPad1d | :heavy_check_mark: | :heavy_check_mark: | |nn.ConstantPad2d | :heavy_check_mark: | :heavy_check_mark: | diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 7c1ac15c8..1ad17cf3c 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -412,6 +412,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/F_avg_pool2d.cpp pass_ncnn/F_avg_pool3d.cpp pass_ncnn/F_batch_norm.cpp + pass_ncnn/F_celu.cpp pass_ncnn/F_conv_transpose1d.cpp pass_ncnn/F_conv_transpose2d.cpp pass_ncnn/F_conv_transpose3d.cpp @@ -468,6 +469,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/nn_BatchNorm1d.cpp pass_ncnn/nn_BatchNorm2d.cpp pass_ncnn/nn_BatchNorm3d.cpp + pass_ncnn/nn_CELU.cpp pass_ncnn/nn_ChannelShuffle.cpp pass_ncnn/nn_ConstantPad1d.cpp pass_ncnn/nn_ConstantPad2d.cpp diff --git a/tools/pnnx/src/pass_ncnn/F_celu.cpp b/tools/pnnx/src/pass_ncnn/F_celu.cpp new file mode 100644 index 000000000..dd89d8897 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_celu.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_celu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.celu op_0 1 1 input out alpha=%alpha +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "CELU"; + } + + const char* name_str() const + { + return "celu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_celu, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_CELU.cpp b/tools/pnnx/src/pass_ncnn/nn_CELU.cpp new file mode 100644 index 000000000..1bdf28467 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_CELU.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_CELU : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.CELU op_0 1 1 input out alpha=%alpha +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "CELU"; + } + + const char* name_str() const + { + return "celu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_CELU, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/ncnn/test_F_celu.py b/tools/pnnx/tests/ncnn/test_F_celu.py new file mode 100644 index 000000000..04ecc37ba --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_celu.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.celu(x) + y = F.celu(y, 0.8) + z = F.celu(z, 0.5) + w = F.celu(w, 2) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(16) + y = torch.rand(2, 16) + z = torch.rand(3, 12, 16) + w = torch.rand(5, 7, 9, 11) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_celu.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_celu.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]") + + # ncnn inference + import test_F_celu_ncnn + b = test_F_celu_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_CELU.py b/tools/pnnx/tests/ncnn/test_nn_CELU.py new file mode 100644 index 000000000..097cc22f7 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_CELU.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.CELU() + self.act_1 = nn.CELU(alpha=2.0) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(12) + y = torch.rand(12, 64) + z = torch.rand(12, 24, 64) + w = torch.rand(12, 24, 32, 64) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_CELU.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_CELU.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]") + + # ncnn inference + import test_nn_CELU_ncnn + b = test_nn_CELU_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_celu.py b/tools/pnnx/tests/test_F_celu.py index 49d9e7a8d..43b25f854 100644 --- a/tools/pnnx/tests/test_F_celu.py +++ b/tools/pnnx/tests/test_F_celu.py @@ -23,7 +23,7 @@ class Model(nn.Module): def forward(self, x, y, z, w): x = F.celu(x) y = F.celu(y, 0.8) - z = F.celu(z, -0.5) + z = F.celu(z, 0.5) w = F.celu(w, 2) return x, y, z, w