| @@ -230,6 +230,7 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/torch_max.cpp | |||
| pass_level2/torch_mean.cpp | |||
| pass_level2/torch_min.cpp | |||
| pass_level2/torch_mm.cpp | |||
| pass_level2/torch_ne.cpp | |||
| pass_level2/torch_norm.cpp | |||
| pass_level2/torch_normal.cpp | |||
| @@ -511,6 +512,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/torch_max.cpp | |||
| pass_ncnn/torch_mean.cpp | |||
| pass_ncnn/torch_min.cpp | |||
| pass_ncnn/torch_mm.cpp | |||
| pass_ncnn/torch_norm.cpp | |||
| pass_ncnn/torch_permute.cpp | |||
| pass_ncnn/torch_prod.cpp | |||
| @@ -0,0 +1,41 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_mm : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 mat2 | |||
| aten::mm op_0 2 1 input mat2 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.mm"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mm, 20) | |||
| } // namespace pnnx | |||
| @@ -89,8 +89,11 @@ void convert_attribute(Graph& graph) | |||
| op->params["2"] = new_shape[0]; | |||
| } | |||
| op->attrs["0"] = data; | |||
| op->attrs.erase(key); | |||
| if (key != "0") | |||
| { | |||
| op->attrs["0"] = data; | |||
| op->attrs.erase(key); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,50 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class torch_mm : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 mat2 | |||
| torch.mm op_0 2 1 input mat2 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Gemm"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "mm"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mm, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -205,6 +205,7 @@ pnnx_add_test(torch_matmul) | |||
| pnnx_add_test(torch_max) | |||
| pnnx_add_test(torch_mean) | |||
| pnnx_add_test(torch_min) | |||
| pnnx_add_test(torch_mm) | |||
| pnnx_add_test(torch_ne) | |||
| pnnx_add_test(torch_norm) | |||
| pnnx_add_test(torch_ones) | |||
| @@ -148,6 +148,7 @@ pnnx_ncnn_add_test(torch_matmul) | |||
| pnnx_ncnn_add_test(torch_max) | |||
| pnnx_ncnn_add_test(torch_mean) | |||
| pnnx_ncnn_add_test(torch_min) | |||
| pnnx_ncnn_add_test(torch_mm) | |||
| pnnx_ncnn_add_test(torch_norm) | |||
| pnnx_ncnn_add_test(torch_permute) | |||
| pnnx_ncnn_add_test(torch_prod) | |||
| @@ -0,0 +1,55 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, a0, a1): | |||
| a = torch.mm(a0, a1) | |||
| return a | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| a0 = torch.rand(23, 14) | |||
| a1 = torch.rand(14, 35) | |||
| a = net(a0, a1) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (a0, a1)) | |||
| mod.save("test_torch_mm.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_mm.pt inputshape=[23,14],[14,35]") | |||
| # ncnn inference | |||
| import test_torch_mm_ncnn | |||
| b = test_torch_mm_ncnn.test_inference() | |||
| return torch.allclose(a, b, 1e-4, 1e-4) | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,55 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, a0, a1): | |||
| a = torch.mm(a0, a1) | |||
| return a | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| a0 = torch.rand(23, 14) | |||
| a1 = torch.rand(14, 35) | |||
| a = net(a0, a1) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (a0, a1)) | |||
| mod.save("test_torch_mm.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_mm.pt inputshape=[23,14],[14,35]") | |||
| # pnnx inference | |||
| import test_torch_mm_pnnx | |||
| b = test_torch_mm_pnnx.test_inference() | |||
| return torch.equal(a, b) | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||