diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 8b13ef702..211c0385c 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -230,6 +230,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_max.cpp pass_level2/torch_mean.cpp pass_level2/torch_min.cpp + pass_level2/torch_mm.cpp pass_level2/torch_ne.cpp pass_level2/torch_norm.cpp pass_level2/torch_normal.cpp @@ -511,6 +512,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torch_max.cpp pass_ncnn/torch_mean.cpp pass_ncnn/torch_min.cpp + pass_ncnn/torch_mm.cpp pass_ncnn/torch_norm.cpp pass_ncnn/torch_permute.cpp pass_ncnn/torch_prod.cpp diff --git a/tools/pnnx/src/pass_level2/torch_mm.cpp b/tools/pnnx/src/pass_level2/torch_mm.cpp new file mode 100644 index 000000000..59988fbd6 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_mm.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_mm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 mat2 +aten::mm op_0 2 1 input mat2 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.mm"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mm, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_attribute.cpp b/tools/pnnx/src/pass_ncnn/convert_attribute.cpp index 240134ba6..dc8694e42 100644 --- a/tools/pnnx/src/pass_ncnn/convert_attribute.cpp +++ b/tools/pnnx/src/pass_ncnn/convert_attribute.cpp @@ -89,8 +89,11 @@ void convert_attribute(Graph& graph) op->params["2"] = new_shape[0]; } - op->attrs["0"] = data; - op->attrs.erase(key); + if (key != "0") + { + op->attrs["0"] = data; + op->attrs.erase(key); + } } } diff --git a/tools/pnnx/src/pass_ncnn/torch_mm.cpp b/tools/pnnx/src/pass_ncnn/torch_mm.cpp new file mode 100644 index 000000000..d24112907 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_mm.cpp @@ -0,0 +1,50 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_mm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 mat2 +torch.mm op_0 2 1 input mat2 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Gemm"; + } + + const char* name_str() const + { + return "mm"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mm, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index ab7358041..45ad247cd 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -205,6 +205,7 @@ pnnx_add_test(torch_matmul) pnnx_add_test(torch_max) pnnx_add_test(torch_mean) pnnx_add_test(torch_min) +pnnx_add_test(torch_mm) pnnx_add_test(torch_ne) pnnx_add_test(torch_norm) pnnx_add_test(torch_ones) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index d456c0da6..03ca51692 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -148,6 +148,7 @@ pnnx_ncnn_add_test(torch_matmul) pnnx_ncnn_add_test(torch_max) pnnx_ncnn_add_test(torch_mean) pnnx_ncnn_add_test(torch_min) +pnnx_ncnn_add_test(torch_mm) pnnx_ncnn_add_test(torch_norm) pnnx_ncnn_add_test(torch_permute) pnnx_ncnn_add_test(torch_prod) diff --git a/tools/pnnx/tests/ncnn/test_torch_mm.py b/tools/pnnx/tests/ncnn/test_torch_mm.py new file mode 100644 index 000000000..778ca983d --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_mm.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, a0, a1): + a = torch.mm(a0, a1) + return a + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + a0 = torch.rand(23, 14) + a1 = torch.rand(14, 35) + + a = net(a0, a1) + + # export torchscript + mod = torch.jit.trace(net, (a0, a1)) + mod.save("test_torch_mm.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_mm.pt inputshape=[23,14],[14,35]") + + # ncnn inference + import test_torch_mm_ncnn + b = test_torch_mm_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_mm.py b/tools/pnnx/tests/test_torch_mm.py new file mode 100644 index 000000000..a59643b53 --- /dev/null +++ b/tools/pnnx/tests/test_torch_mm.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, a0, a1): + a = torch.mm(a0, a1) + return a + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + a0 = torch.rand(23, 14) + a1 = torch.rand(14, 35) + + a = net(a0, a1) + + # export torchscript + mod = torch.jit.trace(net, (a0, a1)) + mod.save("test_torch_mm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_mm.pt inputshape=[23,14],[14,35]") + + # pnnx inference + import test_torch_mm_pnnx + b = test_torch_mm_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)