Browse Source

Add logxxx to log comp xxx rewriter where xxx = sigmoid or softmax (#4925)

* Add logxxx to log comp xxx rewriter

* Use pattern matching for LogSigmoid and LogSoftmax

* Add conversion passes for functional counterparts

* Update documentation
tags/20230816
lrw04 GitHub 2 years ago
parent
commit
fed3b43c73
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 533 additions and 4 deletions
  1. +4
    -4
      tools/pnnx/README.md
  2. +4
    -0
      tools/pnnx/src/CMakeLists.txt
  3. +67
    -0
      tools/pnnx/src/pass_ncnn/F_log_softmax.cpp
  4. +65
    -0
      tools/pnnx/src/pass_ncnn/F_logsigmoid.cpp
  5. +65
    -0
      tools/pnnx/src/pass_ncnn/nn_LogSigmoid.cpp
  6. +67
    -0
      tools/pnnx/src/pass_ncnn/nn_LogSoftmax.cpp
  7. +4
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  8. +62
    -0
      tools/pnnx/tests/ncnn/test_F_log_softmax.py
  9. +63
    -0
      tools/pnnx/tests/ncnn/test_F_logsigmoid.py
  10. +65
    -0
      tools/pnnx/tests/ncnn/test_nn_LogSigmoid.py
  11. +67
    -0
      tools/pnnx/tests/ncnn/test_nn_LogSoftmax.py

+ 4
- 4
tools/pnnx/README.md View File

@@ -520,8 +520,8 @@ TORCH_LIBRARY(upfirdn2d_op, m) {
|nn.LeakyReLU | :heavy_check_mark: | :heavy_check_mark: |
|nn.Linear | :heavy_check_mark: | :heavy_check_mark: |
|nn.LocalResponseNorm | :heavy_check_mark: | :heavy_check_mark: |
|nn.LogSigmoid | :heavy_check_mark: |
|nn.LogSoftmax | :heavy_check_mark: |
|nn.LogSigmoid | :heavy_check_mark: | :heavy_check_mark: |
|nn.LogSoftmax | :heavy_check_mark: | :heavy_check_mark: |
|nn.LPPool1d | :heavy_check_mark: |
|nn.LPPool2d | :heavy_check_mark: |
|nn.LSTM | :heavy_check_mark: | :heavy_check_mark: |
@@ -626,8 +626,8 @@ TORCH_LIBRARY(upfirdn2d_op, m) {
|F.leaky_relu_ | :heavy_check_mark: | :heavy_check_mark: |
|F.linear | :heavy_check_mark: | :heavy_check_mark:* |
|F.local_response_norm | :heavy_check_mark: | :heavy_check_mark: |
|F.logsigmoid | :heavy_check_mark: |
|F.log_softmax | :heavy_check_mark: |
|F.logsigmoid | :heavy_check_mark: | :heavy_check_mark: |
|F.log_softmax | :heavy_check_mark: | :heavy_check_mark: |
|F.lp_pool1d | :heavy_check_mark: |
|F.lp_pool2d | :heavy_check_mark: |
|F.max_pool1d | :heavy_check_mark: | :heavy_check_mark: |


+ 4
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -428,6 +428,8 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/F_leaky_relu.cpp
pass_ncnn/F_linear.cpp
pass_ncnn/F_local_response_norm.cpp
pass_ncnn/F_log_softmax.cpp
pass_ncnn/F_logsigmoid.cpp
pass_ncnn/F_max_pool1d.cpp
pass_ncnn/F_max_pool2d.cpp
pass_ncnn/F_max_pool3d.cpp
@@ -485,6 +487,8 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/nn_LeakyReLU.cpp
pass_ncnn/nn_Linear.cpp
pass_ncnn/nn_LocalResponseNorm.cpp
pass_ncnn/nn_LogSigmoid.cpp
pass_ncnn/nn_LogSoftmax.cpp
pass_ncnn/nn_LSTM.cpp
pass_ncnn/nn_MaxPool1d.cpp
pass_ncnn/nn_MaxPool2d.cpp


+ 67
- 0
tools/pnnx/src/pass_ncnn/F_log_softmax.cpp View File

@@ -0,0 +1,67 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class F_log_softmax : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
F.log_softmax op 1 1 input out dim=%dim
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
F.softmax softmax 1 1 input softmax
UnaryOp log 1 1 softmax out 0=8
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F_log_softmax";
}

const char* name_str() const
{
return "f_logsoftmax";
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
GraphRewriterPass::write(ops, captured_params, captured_attrs);

ops.at("softmax")->params["dim"] = captured_params.at("dim");
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_log_softmax, 19)

} // namespace ncnn

} // namespace pnnx

+ 65
- 0
tools/pnnx/src/pass_ncnn/F_logsigmoid.cpp View File

@@ -0,0 +1,65 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class F_logsigmoid : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
F.logsigmoid op 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
F.sigmoid sigmoid 1 1 input sigmoid
UnaryOp log 1 1 sigmoid out 0=8
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F_logsigmoid";
}

const char* name_str() const
{
return "f_logsigmoid";
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
GraphRewriterPass::write(ops, captured_params, captured_attrs);
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_logsigmoid, 19)

} // namespace ncnn

} // namespace pnnx

+ 65
- 0
tools/pnnx/src/pass_ncnn/nn_LogSigmoid.cpp View File

@@ -0,0 +1,65 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class nn_LogSigmoid : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.LogSigmoid op 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.Sigmoid sigmoid 1 1 input sigmoid
UnaryOp log 1 1 sigmoid out 0=8
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "LogSigmoid";
}

const char* name_str() const
{
return "logsigmoid";
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
GraphRewriterPass::write(ops, captured_params, captured_attrs);
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LogSigmoid, 19)

} // namespace ncnn

} // namespace pnnx

+ 67
- 0
tools/pnnx/src/pass_ncnn/nn_LogSoftmax.cpp View File

@@ -0,0 +1,67 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class nn_LogSoftmax : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.LogSoftmax op 1 1 input out dim=%dim
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.Softmax softmax 1 1 input softmax
UnaryOp log 1 1 softmax out 0=8
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "LogSoftmax";
}

const char* name_str() const
{
return "logsoftmax";
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
GraphRewriterPass::write(ops, captured_params, captured_attrs);

ops.at("softmax")->params["dim"] = captured_params.at("dim");
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LogSoftmax, 19)

} // namespace ncnn

} // namespace pnnx

+ 4
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -40,6 +40,8 @@ pnnx_ncnn_add_test(F_interpolate)
pnnx_ncnn_add_test(F_layer_norm)
pnnx_ncnn_add_test(F_leaky_relu)
pnnx_ncnn_add_test(F_local_response_norm)
pnnx_ncnn_add_test(F_logsigmoid)
pnnx_ncnn_add_test(F_log_softmax)
pnnx_ncnn_add_test(F_max_pool1d)
pnnx_ncnn_add_test(F_max_pool2d)
pnnx_ncnn_add_test(F_max_pool3d)
@@ -100,6 +102,8 @@ pnnx_ncnn_add_test(nn_LayerNorm)
pnnx_ncnn_add_test(nn_LeakyReLU)
pnnx_ncnn_add_test(nn_Linear)
pnnx_ncnn_add_test(nn_LocalResponseNorm)
pnnx_ncnn_add_test(nn_LogSigmoid)
pnnx_ncnn_add_test(nn_LogSoftmax)
pnnx_ncnn_add_test(nn_LSTM)
pnnx_ncnn_add_test(nn_MaxPool1d)
pnnx_ncnn_add_test(nn_MaxPool2d)


+ 62
- 0
tools/pnnx/tests/ncnn/test_F_log_softmax.py View File

@@ -0,0 +1,62 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x = F.log_softmax(x, 0)
y = F.log_softmax(y, 1)
z = F.log_softmax(z, 2)
z2 = F.log_softmax(z, -1)
return x, y, z, z2

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(16)
y = torch.rand(2, 16)
z = torch.rand(3, 12, 16)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_F_log_softmax.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_F_log_softmax.pt inputshape=[16],[2,16],[3,12,16]")

# ncnn inference
import test_F_log_softmax_ncnn
b = test_F_log_softmax_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 63
- 0
tools/pnnx/tests/ncnn/test_F_logsigmoid.py View File

@@ -0,0 +1,63 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
x = F.logsigmoid(x)
y = F.logsigmoid(y)
z = F.logsigmoid(z)
w = F.logsigmoid(w)
return x, y, z, w

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(16)
y = torch.rand(2, 16)
z = torch.rand(3, 12, 16)
w = torch.rand(5, 7, 9, 11)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_F_logsigmoid.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_F_logsigmoid.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]")

# ncnn inference
import test_F_logsigmoid_ncnn
b = test_F_logsigmoid_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 65
- 0
tools/pnnx/tests/ncnn/test_nn_LogSigmoid.py View File

@@ -0,0 +1,65 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.act_0 = nn.LogSigmoid()

def forward(self, x, y, z, w):
x = self.act_0(x)
y = self.act_0(y)
z = self.act_0(z)
w = self.act_0(w)
return x, y, z, w

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(12)
y = torch.rand(12, 64)
z = torch.rand(12, 24, 64)
w = torch.rand(12, 24, 32, 64)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_nn_LogSigmoid.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_nn_LogSigmoid.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]")

# ncnn inference
import test_nn_LogSigmoid_ncnn
b = test_nn_LogSigmoid_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 67
- 0
tools/pnnx/tests/ncnn/test_nn_LogSoftmax.py View File

@@ -0,0 +1,67 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.act_0 = nn.LogSoftmax(dim=0)
self.act_1 = nn.LogSoftmax(dim=1)
self.act_2 = nn.LogSoftmax(dim=2)
self.act_3 = nn.LogSoftmax(dim=-1)

def forward(self, x, y, z):
x = self.act_0(x)
y = self.act_1(y)
z = self.act_2(z)
z2 = self.act_3(z)
return x, y, z, z2

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(12)
y = torch.rand(12, 64)
z = torch.rand(12, 24, 64)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_nn_LogSoftmax.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_nn_LogSoftmax.pt inputshape=[12],[12,64],[12,24,64]")

# ncnn inference
import test_nn_LogSoftmax_ncnn
b = test_nn_LogSoftmax_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save