diff --git a/tools/pnnx/README.md b/tools/pnnx/README.md index dbf8d052d..93eca067d 100644 --- a/tools/pnnx/README.md +++ b/tools/pnnx/README.md @@ -520,8 +520,8 @@ TORCH_LIBRARY(upfirdn2d_op, m) { |nn.LeakyReLU | :heavy_check_mark: | :heavy_check_mark: | |nn.Linear | :heavy_check_mark: | :heavy_check_mark: | |nn.LocalResponseNorm | :heavy_check_mark: | :heavy_check_mark: | -|nn.LogSigmoid | :heavy_check_mark: | -|nn.LogSoftmax | :heavy_check_mark: | +|nn.LogSigmoid | :heavy_check_mark: | :heavy_check_mark: | +|nn.LogSoftmax | :heavy_check_mark: | :heavy_check_mark: | |nn.LPPool1d | :heavy_check_mark: | |nn.LPPool2d | :heavy_check_mark: | |nn.LSTM | :heavy_check_mark: | :heavy_check_mark: | @@ -626,8 +626,8 @@ TORCH_LIBRARY(upfirdn2d_op, m) { |F.leaky_relu_ | :heavy_check_mark: | :heavy_check_mark: | |F.linear | :heavy_check_mark: | :heavy_check_mark:* | |F.local_response_norm | :heavy_check_mark: | :heavy_check_mark: | -|F.logsigmoid | :heavy_check_mark: | -|F.log_softmax | :heavy_check_mark: | +|F.logsigmoid | :heavy_check_mark: | :heavy_check_mark: | +|F.log_softmax | :heavy_check_mark: | :heavy_check_mark: | |F.lp_pool1d | :heavy_check_mark: | |F.lp_pool2d | :heavy_check_mark: | |F.max_pool1d | :heavy_check_mark: | :heavy_check_mark: | diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index bfbc2fa66..24845603c 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -428,6 +428,8 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/F_leaky_relu.cpp pass_ncnn/F_linear.cpp pass_ncnn/F_local_response_norm.cpp + pass_ncnn/F_log_softmax.cpp + pass_ncnn/F_logsigmoid.cpp pass_ncnn/F_max_pool1d.cpp pass_ncnn/F_max_pool2d.cpp pass_ncnn/F_max_pool3d.cpp @@ -485,6 +487,8 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/nn_LeakyReLU.cpp pass_ncnn/nn_Linear.cpp pass_ncnn/nn_LocalResponseNorm.cpp + pass_ncnn/nn_LogSigmoid.cpp + pass_ncnn/nn_LogSoftmax.cpp pass_ncnn/nn_LSTM.cpp pass_ncnn/nn_MaxPool1d.cpp pass_ncnn/nn_MaxPool2d.cpp diff --git a/tools/pnnx/src/pass_ncnn/F_log_softmax.cpp b/tools/pnnx/src/pass_ncnn/F_log_softmax.cpp new file mode 100644 index 000000000..046546f51 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_log_softmax.cpp @@ -0,0 +1,67 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_log_softmax : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.log_softmax op 1 1 input out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +F.softmax softmax 1 1 input softmax +UnaryOp log 1 1 softmax out 0=8 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F_log_softmax"; + } + + const char* name_str() const + { + return "f_logsoftmax"; + } + + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const + { + GraphRewriterPass::write(ops, captured_params, captured_attrs); + + ops.at("softmax")->params["dim"] = captured_params.at("dim"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_log_softmax, 19) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_logsigmoid.cpp b/tools/pnnx/src/pass_ncnn/F_logsigmoid.cpp new file mode 100644 index 000000000..a894689ad --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_logsigmoid.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_logsigmoid : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.logsigmoid op 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +F.sigmoid sigmoid 1 1 input sigmoid +UnaryOp log 1 1 sigmoid out 0=8 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F_logsigmoid"; + } + + const char* name_str() const + { + return "f_logsigmoid"; + } + + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const + { + GraphRewriterPass::write(ops, captured_params, captured_attrs); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_logsigmoid, 19) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_LogSigmoid.cpp b/tools/pnnx/src/pass_ncnn/nn_LogSigmoid.cpp new file mode 100644 index 000000000..6e1b10ff5 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_LogSigmoid.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_LogSigmoid : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.LogSigmoid op 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.Sigmoid sigmoid 1 1 input sigmoid +UnaryOp log 1 1 sigmoid out 0=8 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "LogSigmoid"; + } + + const char* name_str() const + { + return "logsigmoid"; + } + + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const + { + GraphRewriterPass::write(ops, captured_params, captured_attrs); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LogSigmoid, 19) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_LogSoftmax.cpp b/tools/pnnx/src/pass_ncnn/nn_LogSoftmax.cpp new file mode 100644 index 000000000..c942728c9 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_LogSoftmax.cpp @@ -0,0 +1,67 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_LogSoftmax : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.LogSoftmax op 1 1 input out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.Softmax softmax 1 1 input softmax +UnaryOp log 1 1 softmax out 0=8 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "LogSoftmax"; + } + + const char* name_str() const + { + return "logsoftmax"; + } + + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const + { + GraphRewriterPass::write(ops, captured_params, captured_attrs); + + ops.at("softmax")->params["dim"] = captured_params.at("dim"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LogSoftmax, 19) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index 208c6f18b..677e37714 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -40,6 +40,8 @@ pnnx_ncnn_add_test(F_interpolate) pnnx_ncnn_add_test(F_layer_norm) pnnx_ncnn_add_test(F_leaky_relu) pnnx_ncnn_add_test(F_local_response_norm) +pnnx_ncnn_add_test(F_logsigmoid) +pnnx_ncnn_add_test(F_log_softmax) pnnx_ncnn_add_test(F_max_pool1d) pnnx_ncnn_add_test(F_max_pool2d) pnnx_ncnn_add_test(F_max_pool3d) @@ -100,6 +102,8 @@ pnnx_ncnn_add_test(nn_LayerNorm) pnnx_ncnn_add_test(nn_LeakyReLU) pnnx_ncnn_add_test(nn_Linear) pnnx_ncnn_add_test(nn_LocalResponseNorm) +pnnx_ncnn_add_test(nn_LogSigmoid) +pnnx_ncnn_add_test(nn_LogSoftmax) pnnx_ncnn_add_test(nn_LSTM) pnnx_ncnn_add_test(nn_MaxPool1d) pnnx_ncnn_add_test(nn_MaxPool2d) diff --git a/tools/pnnx/tests/ncnn/test_F_log_softmax.py b/tools/pnnx/tests/ncnn/test_F_log_softmax.py new file mode 100644 index 000000000..3f53229b2 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_log_softmax.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.log_softmax(x, 0) + y = F.log_softmax(y, 1) + z = F.log_softmax(z, 2) + z2 = F.log_softmax(z, -1) + return x, y, z, z2 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(16) + y = torch.rand(2, 16) + z = torch.rand(3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_log_softmax.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_log_softmax.pt inputshape=[16],[2,16],[3,12,16]") + + # ncnn inference + import test_F_log_softmax_ncnn + b = test_F_log_softmax_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_logsigmoid.py b/tools/pnnx/tests/ncnn/test_F_logsigmoid.py new file mode 100644 index 000000000..023652f7a --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_logsigmoid.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.logsigmoid(x) + y = F.logsigmoid(y) + z = F.logsigmoid(z) + w = F.logsigmoid(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(16) + y = torch.rand(2, 16) + z = torch.rand(3, 12, 16) + w = torch.rand(5, 7, 9, 11) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_logsigmoid.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_logsigmoid.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]") + + # ncnn inference + import test_F_logsigmoid_ncnn + b = test_F_logsigmoid_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_LogSigmoid.py b/tools/pnnx/tests/ncnn/test_nn_LogSigmoid.py new file mode 100644 index 000000000..26f676c94 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_LogSigmoid.py @@ -0,0 +1,65 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.LogSigmoid() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(12) + y = torch.rand(12, 64) + z = torch.rand(12, 24, 64) + w = torch.rand(12, 24, 32, 64) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_LogSigmoid.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_LogSigmoid.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]") + + # ncnn inference + import test_nn_LogSigmoid_ncnn + b = test_nn_LogSigmoid_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_LogSoftmax.py b/tools/pnnx/tests/ncnn/test_nn_LogSoftmax.py new file mode 100644 index 000000000..a02c050e6 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_LogSoftmax.py @@ -0,0 +1,67 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.LogSoftmax(dim=0) + self.act_1 = nn.LogSoftmax(dim=1) + self.act_2 = nn.LogSoftmax(dim=2) + self.act_3 = nn.LogSoftmax(dim=-1) + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_1(y) + z = self.act_2(z) + z2 = self.act_3(z) + return x, y, z, z2 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(12) + y = torch.rand(12, 64) + z = torch.rand(12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_LogSoftmax.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_LogSoftmax.pt inputshape=[12],[12,64],[12,24,64]") + + # ncnn inference + import test_nn_LogSoftmax_ncnn + b = test_nn_LogSoftmax_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)