From d42e048b56a4a99fb8447ad2a18678a9a1ae2def Mon Sep 17 00:00:00 2001
From: nihui
Date: Fri, 18 Mar 2022 12:24:05 +0800
Subject: [PATCH] pnnx convert torch.addmm (#3634)

---
 src/layer/gemm.cpp                         |   5 +
 tests/test_gemm.cpp                        |  16 +-
 tools/pnnx/src/CMakeLists.txt              |   2 +
 tools/pnnx/src/pass_level2/torch_addmm.cpp |  44 +++++
 tools/pnnx/src/pass_ncnn/torch_addmm.cpp   | 196 +++++++++++++++++++++
 tools/pnnx/tests/CMakeLists.txt            |   1 +
 tools/pnnx/tests/ncnn/CMakeLists.txt       |   1 +
 tools/pnnx/tests/ncnn/test_torch_addmm.py  |  72 ++++++++
 tools/pnnx/tests/test_torch_addmm.py       |  68 +++++++
 9 files changed, 404 insertions(+), 1 deletion(-)
 create mode 100644 tools/pnnx/src/pass_level2/torch_addmm.cpp
 create mode 100644 tools/pnnx/src/pass_ncnn/torch_addmm.cpp
 create mode 100644 tools/pnnx/tests/ncnn/test_torch_addmm.py
 create mode 100644 tools/pnnx/tests/test_torch_addmm.py

diff --git a/src/layer/gemm.cpp b/src/layer/gemm.cpp
index e44075f6f..fc312e1bb 100644
--- a/src/layer/gemm.cpp
+++ b/src/layer/gemm.cpp
@@ -104,6 +104,11 @@ int Gemm::forward(const std::vector& bottom_blobs, std::vector& top_bl
             // auto broadcast from h to w is the ncnn-style convention
             broadcast_type_C = 1;
         }
+        if (C.dims == 1 && C.w == N)
+        {
+            // N: C is 1-D and its width equals N; this case gets its own broadcast type (4)
+            broadcast_type_C = 4;
+        }
         if (C.dims == 2 && C.w == 1 && C.h == M)
         {
             // Mx1
diff --git a/tests/test_gemm.cpp b/tests/test_gemm.cpp
index 455e5f142..ab01102d7 100644
--- a/tests/test_gemm.cpp
+++ b/tests/test_gemm.cpp
@@ -146,6 +146,19 @@ static int test_gemm_5()
            || test_gemm_bias(16, 24, 15, RandomMat(1, 14), 1.7f, 1.3f, 1, 1);
 }
 
+static int test_gemm_6() // calls test_gemm_bias with RandomMat(14) (presumably a 1-D, 14-element C operand - confirm) for the size triples (13,14,15) and (16,24,15), four scale pairs and every 0/1 combination of the two trailing flags
+{
+    return 0 // the || chain short-circuits, so the first non-zero result is what gets returned
+           || test_gemm_bias(13, 14, 15, RandomMat(14), 0.1f, 0.4f, 0, 0)
+           || test_gemm_bias(13, 14, 15, RandomMat(14), 0.4f, -1.f, 1, 0)
+           || test_gemm_bias(13, 14, 15, RandomMat(14), -0.3f, -0.21f, 0, 1)
+           || test_gemm_bias(13, 14, 15, RandomMat(14), 1.7f, 1.3f, 1, 1)
+           || test_gemm_bias(16, 24, 15, RandomMat(14), 0.1f, 0.4f, 0, 0)
+           || test_gemm_bias(16, 24, 15, RandomMat(14), 0.4f, -1.f, 1, 0)
+           || test_gemm_bias(16, 24, 15, RandomMat(14), -0.3f, -0.21f, 0, 1)
+           || test_gemm_bias(16, 24, 15, RandomMat(14), 1.7f, 1.3f, 1, 1);
+}
+
 int main()
 {
     SRAND(7767517);
@@ -156,5 +169,6 @@ int main()
            || test_gemm_2()
            || test_gemm_3()
            || test_gemm_4()
-           || test_gemm_5();
+           || test_gemm_5()
+           || test_gemm_6();
 }
diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt
index 4786ed9d4..75cfd281e 100644
--- a/tools/pnnx/src/CMakeLists.txt
+++ b/tools/pnnx/src/CMakeLists.txt
@@ -175,6 +175,7 @@ set(pnnx_pass_level2_SRCS
     pass_level2/Tensor_select.cpp
     pass_level2/Tensor_slice.cpp
     pass_level2/Tensor_view.cpp
+    pass_level2/torch_addmm.cpp
     pass_level2/torch_amax.cpp
     pass_level2/torch_amin.cpp
     pass_level2/torch_arange.cpp
@@ -402,6 +403,7 @@ set(pnnx_pass_ncnn_SRCS
     pass_ncnn/Tensor_repeat.cpp
     pass_ncnn/Tensor_slice.cpp
     pass_ncnn/Tensor_view.cpp
+    pass_ncnn/torch_addmm.cpp
     pass_ncnn/torch_amax.cpp
     pass_ncnn/torch_amin.cpp
     pass_ncnn/torch_clamp.cpp
diff --git a/tools/pnnx/src/pass_level2/torch_addmm.cpp b/tools/pnnx/src/pass_level2/torch_addmm.cpp
new file mode 100644
index 000000000..c8e14a713
--- /dev/null
+++ b/tools/pnnx/src/pass_level2/torch_addmm.cpp
@@ -0,0 +1,44 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_level2.h"
+
+namespace pnnx {
+
+class torch_addmm : public GraphRewriterPass // level-2 pass: matches one aten::addmm node fed by five graph inputs and reports torch.addmm as its operator type
+{
+public:
+    const char* match_pattern_graph() const // pattern text: 7 operators / 6 operands; aten::addmm consumes (input, mat1, mat2, beta, alpha) and produces out
+    {
+        return R"PNNXIR(7767517
+7 6
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 mat1
+pnnx.Input input_2 0 1 mat2
+pnnx.Input input_3 0 1 beta
+pnnx.Input input_4 0 1 alpha
+aten::addmm op_0 5 1 input mat1 mat2 beta alpha out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const // operator type string used for the matched subgraph
+    {
+        return "torch.addmm";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_addmm, 20) // NOTE(review): the second argument (20) is presumably the pass priority - confirm against the macro definition
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_ncnn/torch_addmm.cpp b/tools/pnnx/src/pass_ncnn/torch_addmm.cpp
new file mode 100644
index 000000000..9a039092e
--- /dev/null
+++ b/tools/pnnx/src/pass_ncnn/torch_addmm.cpp
@@ -0,0 +1,196 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_ncnn.h"
+
+namespace pnnx {
+
+namespace ncnn {
+
+class torch_addmm : public GraphRewriterPass // ncnn pass: torch.addmm whose bias and weight operands are both pnnx.Attribute nodes is emitted as an InnerProduct operator
+{
+public:
+    const char* match_pattern_graph() const // pattern text: one runtime input (mat1); bias and weight come from attributes; alpha/beta are captured as parameters
+    {
+        return R"PNNXIR(7767517
+5 4
+pnnx.Input input_0 0 1 mat1
+pnnx.Attribute op_bias 0 1 bias @qwq
+pnnx.Attribute op_weight 0 1 weight @qwq
+torch.addmm op_0 3 1 bias mat1 weight out alpha=%alpha beta=%beta
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const // operator type of the replacement
+    {
+        return "InnerProduct";
+    }
+
+    const char* name_str() const // name string reported for the replacement operator
+    {
+        return "addmm";
+    }
+
+    bool match(const std::map& captured_params, const std::map& captured_attrs) const // accepts only alpha == beta == 1, a 2-D weight, a 1-D bias, and weight.shape[1] == bias.shape[0]
+    {
+        float alpha = 1.f;
+        float beta = 1.f;
+
+        if (captured_params.at("alpha").type == 2) // type 2: the value is read from the integer field .i
+        {
+            alpha = captured_params.at("alpha").i;
+        }
+        if (captured_params.at("alpha").type == 3) // type 3: the value is read from the float field .f
+        {
+            alpha = captured_params.at("alpha").f;
+        }
+
+        if (captured_params.at("beta").type == 2)
+        {
+            beta = captured_params.at("beta").i;
+        }
+        if (captured_params.at("beta").type == 3)
+        {
+            beta = captured_params.at("beta").f;
+        }
+
+        if (alpha != 1.f || beta != 1.f) // any scaling is rejected here; no scale factor is written by write() below
+            return false;
+
+        Attribute weight;
+        Attribute bias;
+        for (const auto& x : captured_attrs) // captured attributes are located by key prefix
+        {
+            if (x.first.substr(0, 10) == "op_weight.")
+                weight = x.second;
+            if (x.first.substr(0, 8) == "op_bias.")
+                bias = x.second;
+        }
+
+        if (weight.shape.size() != 2 || bias.shape.size() != 1)
+            return false;
+
+        if (weight.shape[1] != bias.shape[0]) // bias length must equal the weight's second dimension (used as outch in write())
+            return false;
+
+        return true;
+    }
+
+    void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const // fills op->params "0".."2" and op->attrs "0".."2" from the captured weight and bias
+    {
+        Attribute weight;
+        Attribute bias;
+        for (const auto& x : captured_attrs)
+        {
+            if (x.first.substr(0, 10) == "op_weight.")
+                weight = x.second;
+            if (x.first.substr(0, 8) == "op_bias.")
+                bias = x.second;
+        }
+
+        // transpose weight inch-outch to outch-inch (shape[0] is read as inch, shape[1] as outch)
+        const int inch = weight.shape[0];
+        const int outch = weight.shape[1];
+        std::vector new_weight;
+        {
+            const float* w = (const float*)weight.data.data(); // the raw attribute bytes are reinterpreted as float - assumes the weight is stored as float32, TODO confirm
+
+            new_weight.resize(outch * inch);
+            float* w2 = (float*)new_weight.data();
+
+            // reorder weight from inch-outch to outch-inch: w2[i][j] = w[j][i]
+            for (int i = 0; i < outch; i++)
+            {
+                for (int j = 0; j < inch; j++)
+                {
+                    w2[i * inch + j] = w[j * outch + i];
+                }
+            }
+        }
+
+        op->params["0"] = outch; // presumably InnerProduct num_output - confirm against the ncnn operator table
+        op->params["1"] = 1; // presumably bias_term - confirm
+        op->params["2"] = (int)(weight.data.size() / sizeof(float)); // number of float elements in the captured weight
+
+        op->attrs["0"] = Attribute();
+        op->attrs["0"].data = {0, 0, 0, 0}; // four zero bytes; NOTE(review): presumably a tag/header stored ahead of the weight data - confirm
+        op->attrs["1"] = Attribute({outch, inch}, new_weight);
+        op->attrs["2"] = bias;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_addmm, 20) // NOTE(review): 20 vs. 22 below presumably orders this pass relative to torch_addmm_1 - confirm against the macro definition
+
+class torch_addmm_1 : public GraphRewriterPass // ncnn pass: torch.addmm with three runtime inputs is emitted as a Gemm operator
+{
+public:
+    const char* match_pattern_graph() const // pattern text: input, mat1 and mat2 are all graph inputs; alpha/beta are captured as parameters
+    {
+        return R"PNNXIR(7767517
+5 4
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 mat1
+pnnx.Input input_2 0 1 mat2
+torch.addmm op_0 3 1 input mat1 mat2 out alpha=%alpha beta=%beta
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const // operator type of the replacement
+    {
+        return "Gemm";
+    }
+
+    const char* name_str() const // name string reported for the replacement operator
+    {
+        return "addmm";
+    }
+
+    void write(Operator* op, const std::map& captured_params) const // reorders the three inputs and writes params "0" = alpha and "1" = beta / alpha
+    {
+        std::swap(op->inputs[0], op->inputs[1]); // (input, mat1, mat2) -> (mat1, input, mat2)
+        std::swap(op->inputs[1], op->inputs[2]); // -> (mat1, mat2, input)
+
+        float alpha = 1.f;
+        float beta = 1.f;
+
+        if (captured_params.at("alpha").type == 2) // type 2: the value is read from the integer field .i
+        {
+            alpha = captured_params.at("alpha").i;
+        }
+        if (captured_params.at("alpha").type == 3) // type 3: the value is read from the float field .f
+        {
+            alpha = captured_params.at("alpha").f;
+        }
+
+        if (captured_params.at("beta").type == 2)
+        {
+            beta = captured_params.at("beta").i;
+        }
+        if (captured_params.at("beta").type == 3)
+        {
+            beta = captured_params.at("beta").f;
+        }
+
+        op->params["0"] = alpha;
+        op->params["1"] = beta / alpha; // NOTE(review): presumably the Gemm layer applies alpha to the whole sum, hence the rescaled beta - confirm; alpha == 0 is not guarded and divides by zero
+    }
+};
+
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_addmm_1, 22)
+
+} // namespace ncnn
+
+} // namespace pnnx
diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt
index 4eea37e1d..400ce8388 100644
--- a/tools/pnnx/tests/CMakeLists.txt
+++ b/tools/pnnx/tests/CMakeLists.txt
@@ -162,6 +162,7 @@ pnnx_add_test(Tensor_select)
 pnnx_add_test(Tensor_slice)
 pnnx_add_test(Tensor_view)
 
+pnnx_add_test(torch_addmm)
 pnnx_add_test(torch_amax)
 pnnx_add_test(torch_amin)
 pnnx_add_test(torch_argmax)
diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt
index bcbfd4fa6..b551647f9 100644
--- a/tools/pnnx/tests/ncnn/CMakeLists.txt
+++ b/tools/pnnx/tests/ncnn/CMakeLists.txt
@@ -125,6 +125,7 @@ pnnx_ncnn_add_test(Tensor_reshape)
 pnnx_ncnn_add_test(Tensor_slice)
 pnnx_ncnn_add_test(Tensor_view)
 
+pnnx_ncnn_add_test(torch_addmm)
 pnnx_ncnn_add_test(torch_amax)
 pnnx_ncnn_add_test(torch_amin)
 pnnx_ncnn_add_test(torch_cat)
diff --git a/tools/pnnx/tests/ncnn/test_torch_addmm.py b/tools/pnnx/tests/ncnn/test_torch_addmm.py
new file mode 100644
index 000000000..61c402a8c
--- /dev/null
+++ b/tools/pnnx/tests/ncnn/test_torch_addmm.py
@@ -0,0 +1,72 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):  # three torch.addmm calls: default scaling, explicit beta/alpha, and parameter-backed bias/weight
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.c0 = nn.Parameter(torch.rand(12))  # 1-D bias operand for case c
+        self.c2 = nn.Parameter(torch.rand(48, 12))  # 2-D mat2 operand for case c
+
+    def forward(self, a0, a1, a2, b0, b1, b2, c1):  # returns the tuple (a, b, c)
+        a = torch.addmm(a0, a1, a2)  # all operands are runtime inputs, default beta/alpha
+        b = torch.addmm(b0, b1, b2, beta=1.4, alpha=0.7)  # runtime inputs with non-default scaling
+        c = torch.addmm(self.c0, c1, self.c2, beta=1, alpha=1)  # bias and mat2 are module parameters, scaling of 1
+        return a, b, c
+
+def test():  # traces Model, converts it with pnnx, runs the generated ncnn script and compares outputs; returns True on match
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)  # seeds the inputs below; the Model parameters were already created above
+    a0 = torch.rand(23)  # 1-D addend whose length equals the column count of a1 @ a2
+    a1 = torch.rand(13, 16)
+    a2 = torch.rand(16, 23)
+    b0 = torch.rand(7, 33)
+    b1 = torch.rand(7, 26)
+    b2 = torch.rand(26, 33)
+    c1 = torch.rand(16, 48)
+
+    a = net(a0, a1, a2, b0, b1, b2, c1)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (a0, a1, a2, b0, b1, b2, c1))
+    mod.save("test_torch_addmm.pt")
+
+    # torchscript to pnnx (inputshape lists the seven input shapes in forward() order; the exit status of os.system is not checked)
+    import os
+    os.system("../../src/pnnx test_torch_addmm.pt inputshape=[23],[13,16],[16,23],[7,33],[7,26],[26,33],[16,48]")
+
+    # ncnn inference (imports test_torch_addmm_ncnn, presumably written by the pnnx run above - confirm)
+    import test_torch_addmm_ncnn
+    b = test_torch_addmm_ncnn.test_inference()
+
+    for a0, b0 in zip(a, b):  # a0/b0 are reused here as loop variables for reference/converted output pairs
+        if not torch.allclose(a0, b0, 1e-4, 1e-4):  # positional arguments are rtol and atol
+            print(a0.shape)
+            print(b0.shape)
+            print(a0)
+            print(b0)
+            return False
+    return True
+
+if __name__ == "__main__":  # process exit code 0 on success, 1 on mismatch
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/test_torch_addmm.py b/tools/pnnx/tests/test_torch_addmm.py
new file mode 100644
index 000000000..0a38cabc2
--- /dev/null
+++ b/tools/pnnx/tests/test_torch_addmm.py
@@ -0,0 +1,68 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):  # three torch.addmm calls: default scaling, explicit beta/alpha, and parameter-backed bias/weight
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.c0 = nn.Parameter(torch.rand(12))  # 1-D bias operand for case c
+        self.c2 = nn.Parameter(torch.rand(48, 12))  # 2-D mat2 operand for case c
+
+    def forward(self, a0, a1, a2, b0, b1, b2, c1):  # returns the tuple (a, b, c)
+        a = torch.addmm(a0, a1, a2)  # all operands are runtime inputs, default beta/alpha
+        b = torch.addmm(b0, b1, b2, beta=1.4, alpha=0.7)  # runtime inputs with non-default scaling
+        c = torch.addmm(self.c0, c1, self.c2, beta=1, alpha=1)  # bias and mat2 are module parameters, scaling of 1
+        return a, b, c
+
+def test():  # traces Model, converts it with pnnx, runs the generated pnnx script and requires exactly equal outputs; returns True on match
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)  # seeds the inputs below; the Model parameters were already created above
+    a0 = torch.rand(13, 1)  # column-shaped addend here, unlike the 1-D a0 used in the ncnn variant of this test
+    a1 = torch.rand(13, 16)
+    a2 = torch.rand(16, 23)
+    b0 = torch.rand(7, 33)
+    b1 = torch.rand(7, 26)
+    b2 = torch.rand(26, 33)
+    c1 = torch.rand(16, 48)
+
+    a = net(a0, a1, a2, b0, b1, b2, c1)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (a0, a1, a2, b0, b1, b2, c1))
+    mod.save("test_torch_addmm.pt")
+
+    # torchscript to pnnx (inputshape lists the seven input shapes in forward() order; the exit status of os.system is not checked)
+    import os
+    os.system("../src/pnnx test_torch_addmm.pt inputshape=[13,1],[13,16],[16,23],[7,33],[7,26],[26,33],[16,48]")
+
+    # pnnx inference (imports test_torch_addmm_pnnx, presumably written by the pnnx run above - confirm)
+    import test_torch_addmm_pnnx
+    b = test_torch_addmm_pnnx.test_inference()
+
+    for a0, b0 in zip(a, b):  # a0/b0 are reused here as loop variables for reference/converted output pairs
+        if not torch.equal(a0, b0):  # exact equality, no tolerance
+            return False
+    return True
+
+if __name__ == "__main__":  # process exit code 0 on success, 1 on mismatch
+    if test():
+        exit(0)
+    else:
+        exit(1)