| @@ -6,6 +6,7 @@ | |||
| * [BinaryOp](#binaryop) | |||
| * [BNLL](#bnll) | |||
| * [Cast](#cast) | |||
| * [CELU](#celu) | |||
| * [Clip](#clip) | |||
| * [Concat](#concat) | |||
| * [Convolution](#convolution) | |||
| @@ -197,6 +198,19 @@ Element type: | |||
| - 3 = int8 | |||
| - 4 = bfloat16 | |||
| # CELU | |||
| ``` | |||
| if x < 0 y = (exp(x / alpha) - 1.f) * alpha | |||
| else y = x | |||
| ``` | |||
| * one_blob_only | |||
| * support_inplace | |||
| | param id | name | type | default | description | | |||
| | --------- | ------------- | ----- | --------- | ----------------- | | |||
| | 0 | alpha | float | 1.f | | | |||
| # Clip | |||
| ``` | |||
| y = clamp(x, min, max) | |||
| @@ -164,6 +164,7 @@ ncnn_add_layer(CumulativeSum) | |||
| ncnn_add_layer(CopyTo) | |||
| ncnn_add_layer(Erf) | |||
| ncnn_add_layer(Diag) | |||
| ncnn_add_layer(CELU) | |||
| if(NCNN_VULKAN) | |||
| ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp) | |||
| @@ -0,0 +1,56 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "celu.h" | |||
| #include <math.h> | |||
| namespace ncnn { | |||
| CELU::CELU() | |||
| { | |||
| one_blob_only = true; | |||
| support_inplace = true; | |||
| } | |||
| int CELU::load_param(const ParamDict& pd) | |||
| { | |||
| alpha = pd.get(0, 1.f); | |||
| return 0; | |||
| } | |||
| int CELU::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| { | |||
| int w = bottom_top_blob.w; | |||
| int h = bottom_top_blob.h; | |||
| int channels = bottom_top_blob.c; | |||
| int size = w * h; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = 0; q < channels; q++) | |||
| { | |||
| float* ptr = bottom_top_blob.channel(q); | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| if (ptr[i] < 0.f) | |||
| ptr[i] = (expf(ptr[i] / alpha) - 1.f) * alpha; | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| } // namespace ncnn | |||
| @@ -0,0 +1,37 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #ifndef LAYER_CELU_H | |||
| #define LAYER_CELU_H | |||
| #include "layer.h" | |||
| namespace ncnn { | |||
| class CELU : public Layer | |||
| { | |||
| public: | |||
| CELU(); | |||
| virtual int load_param(const ParamDict& pd); | |||
| virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; | |||
| public: | |||
| float alpha; | |||
| }; | |||
| } // namespace ncnn | |||
| #endif // LAYER_CELU_H | |||
| @@ -72,6 +72,7 @@ ncnn_add_layer_test(Bias) | |||
| ncnn_add_layer_test(BinaryOp) | |||
| ncnn_add_layer_test(BNLL) | |||
| ncnn_add_layer_test(Cast) | |||
| ncnn_add_layer_test(CELU) | |||
| ncnn_add_layer_test(Clip) | |||
| ncnn_add_layer_test(Concat) | |||
| ncnn_add_layer_test(Convolution) | |||
| @@ -0,0 +1,66 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "layer/celu.h" | |||
| #include "testutil.h" | |||
| static int test_celu(const ncnn::Mat& a, float alpha) | |||
| { | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, alpha); | |||
| std::vector<ncnn::Mat> weights(0); | |||
| int ret = test_layer<ncnn::CELU>("CELU", pd, weights, a); | |||
| if (ret != 0) | |||
| { | |||
| fprintf(stderr, "test_celu failed a.dims=%d a=(%d %d %d) alpha=%f\n", a.dims, a.w, a.h, a.c, alpha); | |||
| } | |||
| return ret; | |||
| } | |||
| static int test_celu_0() | |||
| { | |||
| return 0 | |||
| || test_celu(RandomMat(5, 7, 24), 1.f) | |||
| || test_celu(RandomMat(7, 9, 12), 0.5f) | |||
| || test_celu(RandomMat(3, 5, 13), 0.2f); | |||
| } | |||
| static int test_celu_1() | |||
| { | |||
| return 0 | |||
| || test_celu(RandomMat(15, 24), 1.f) | |||
| || test_celu(RandomMat(17, 12), 0.5f) | |||
| || test_celu(RandomMat(19, 15), 0.2f); | |||
| } | |||
| static int test_celu_2() | |||
| { | |||
| return 0 | |||
| || test_celu(RandomMat(128), 1.f) | |||
| || test_celu(RandomMat(124), 0.5f) | |||
| || test_celu(RandomMat(127), 0.2f); | |||
| } | |||
| int main() | |||
| { | |||
| SRAND(7767517); | |||
| return 0 | |||
| || test_celu_0() | |||
| || test_celu_1() | |||
| || test_celu_2(); | |||
| } | |||
| @@ -471,7 +471,7 @@ TORCH_LIBRARY(upfirdn2d_op, m) { | |||
| |nn.BatchNorm2d | :heavy_check_mark: | :heavy_check_mark: | | |||
| |nn.BatchNorm3d | :heavy_check_mark: | :heavy_check_mark: | | |||
| |nn.Bilinear | | | |||
| |nn.CELU | :heavy_check_mark: | | |||
| |nn.CELU | :heavy_check_mark: | :heavy_check_mark: | | |||
| |nn.ChannelShuffle | :heavy_check_mark: | :heavy_check_mark: | | |||
| |nn.ConstantPad1d | :heavy_check_mark: | :heavy_check_mark: | | |||
| |nn.ConstantPad2d | :heavy_check_mark: | :heavy_check_mark: | | |||
| @@ -412,6 +412,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/F_avg_pool2d.cpp | |||
| pass_ncnn/F_avg_pool3d.cpp | |||
| pass_ncnn/F_batch_norm.cpp | |||
| pass_ncnn/F_celu.cpp | |||
| pass_ncnn/F_conv_transpose1d.cpp | |||
| pass_ncnn/F_conv_transpose2d.cpp | |||
| pass_ncnn/F_conv_transpose3d.cpp | |||
| @@ -468,6 +469,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/nn_BatchNorm1d.cpp | |||
| pass_ncnn/nn_BatchNorm2d.cpp | |||
| pass_ncnn/nn_BatchNorm3d.cpp | |||
| pass_ncnn/nn_CELU.cpp | |||
| pass_ncnn/nn_ChannelShuffle.cpp | |||
| pass_ncnn/nn_ConstantPad1d.cpp | |||
| pass_ncnn/nn_ConstantPad2d.cpp | |||
| @@ -0,0 +1,49 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class F_celu : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| F.celu op_0 1 1 input out alpha=%alpha | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "CELU"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "celu"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_celu, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,49 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class nn_CELU : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.CELU op_0 1 1 input out alpha=%alpha | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "CELU"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "celu"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_CELU, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,63 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Apply F.celu to four inputs: default alpha, then alpha 0.8, 0.5 and 2."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z, w):
        """Return the four activated tensors in input order."""
        out_x = F.celu(x)
        out_y = F.celu(y, 0.8)
        out_z = F.celu(z, 0.5)
        out_w = F.celu(w, 2)
        return out_x, out_y, out_z, out_w
def test():
    """Convert the F.celu model with pnnx and compare ncnn output to PyTorch.

    Returns True when every ncnn output matches the PyTorch result within
    1e-4 (rtol and atol), False on a mismatch or when the pnnx conversion
    command reports failure.
    """
    net = Model()
    net.eval()

    # NOTE(review): torch.rand draws from [0, 1), where CELU is the identity,
    # so the negative branch is not exercised here. The seed and draw order
    # are presumably shared with the generated test_F_celu_ncnn script -
    # verify before changing how the inputs are built.
    torch.manual_seed(0)
    x = torch.rand(16)
    y = torch.rand(2, 16)
    z = torch.rand(3, 12, 16)
    w = torch.rand(5, 7, 9, 11)

    a = net(x, y, z, w)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w))
    mod.save("test_F_celu.pt")

    # torchscript to pnnx
    import os
    status = os.system("../../src/pnnx test_F_celu.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]")
    if status != 0:
        # do not fall through to importing a stale script left by an earlier run
        return False

    # ncnn inference
    import test_F_celu_ncnn
    b = test_F_celu_ncnn.test_inference()

    return all(torch.allclose(a0, b0, 1e-4, 1e-4) for a0, b0 in zip(a, b))
# Script entry point: exit status 0 when the comparison passes, 1 otherwise.
if __name__ == "__main__":
    exit(0 if test() else 1)
| @@ -0,0 +1,66 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Two nn.CELU modules (default alpha and alpha=2.0), each applied to two inputs."""

    def __init__(self):
        super().__init__()

        self.act_0 = nn.CELU()
        self.act_1 = nn.CELU(alpha=2.0)

    def forward(self, x, y, z, w):
        """Return act_0(x), act_0(y), act_1(z), act_1(w)."""
        out_x = self.act_0(x)
        out_y = self.act_0(y)
        out_z = self.act_1(z)
        out_w = self.act_1(w)
        return out_x, out_y, out_z, out_w
def test():
    """Convert the nn.CELU model with pnnx and compare ncnn output to PyTorch.

    Returns True when every ncnn output matches the PyTorch result within
    1e-4 (rtol and atol), False on a mismatch or when the pnnx conversion
    command reports failure.
    """
    net = Model()
    net.eval()

    # NOTE(review): torch.rand draws from [0, 1), where CELU is the identity,
    # so the negative branch is not exercised here. The seed and draw order
    # are presumably shared with the generated test_nn_CELU_ncnn script -
    # verify before changing how the inputs are built.
    torch.manual_seed(0)
    x = torch.rand(12)
    y = torch.rand(12, 64)
    z = torch.rand(12, 24, 64)
    w = torch.rand(12, 24, 32, 64)

    a = net(x, y, z, w)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w))
    mod.save("test_nn_CELU.pt")

    # torchscript to pnnx
    import os
    status = os.system("../../src/pnnx test_nn_CELU.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]")
    if status != 0:
        # do not fall through to importing a stale script left by an earlier run
        return False

    # ncnn inference
    import test_nn_CELU_ncnn
    b = test_nn_CELU_ncnn.test_inference()

    return all(torch.allclose(a0, b0, 1e-4, 1e-4) for a0, b0 in zip(a, b))
# Script entry point: exit status 0 when the comparison passes, 1 otherwise.
if __name__ == "__main__":
    exit(0 if test() else 1)
| @@ -23,7 +23,7 @@ class Model(nn.Module): | |||
def forward(self, x, y, z, w):
    """Apply F.celu to each input: default alpha, then alpha 0.8, 0.5 and 2."""
    x = F.celu(x)
    y = F.celu(y, 0.8)
    # Only the positive-alpha call is kept; the superseded alpha=-0.5 line
    # was diff residue that made z pass through CELU twice.
    z = F.celu(z, 0.5)
    w = F.celu(w, 2)
    return x, y, z, w