From b3fbbccd8be1026f0349af0a488e8a2d0b32dd58 Mon Sep 17 00:00:00 2001 From: zyt1024 <42999008+zyt1024@users.noreply.github.com> Date: Mon, 4 Sep 2023 11:09:30 +0800 Subject: [PATCH] pnnx convert torch narrow (#4918) --- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/pass_level2/torch_narrow.cpp | 43 ++++++++++++++ tools/pnnx/tests/CMakeLists.txt | 1 + tools/pnnx/tests/test_torch_narrow.py | 63 +++++++++++++++++++++ 4 files changed, 108 insertions(+) create mode 100644 tools/pnnx/src/pass_level2/torch_narrow.cpp create mode 100644 tools/pnnx/tests/test_torch_narrow.py diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 43f1a0ce9..f4a170a4b 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -237,6 +237,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_mean.cpp pass_level2/torch_min.cpp pass_level2/torch_mm.cpp + pass_level2/torch_narrow.cpp pass_level2/torch_ne.cpp pass_level2/torch_norm.cpp pass_level2/torch_normal.cpp diff --git a/tools/pnnx/src/pass_level2/torch_narrow.cpp b/tools/pnnx/src/pass_level2/torch_narrow.cpp new file mode 100644 index 000000000..c827d4612 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_narrow.cpp @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_narrow : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +pnnx.Input input_2 0 1 start +pnnx.Input input_3 0 1 length +aten::narrow op_0 4 1 input dim start length out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.narrow"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_narrow, 20) + +} // namespace pnnx \ No newline at end of file diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 9c2c66864..ca209cb2d 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -212,6 +212,7 @@ pnnx_add_test(torch_max) pnnx_add_test(torch_mean) pnnx_add_test(torch_min) pnnx_add_test(torch_mm) +pnnx_add_test(torch_narrow) pnnx_add_test(torch_ne) pnnx_add_test(torch_norm) pnnx_add_test(torch_ones) diff --git a/tools/pnnx/tests/test_torch_narrow.py b/tools/pnnx/tests/test_torch_narrow.py new file mode 100644 index 000000000..8acce84a0 --- /dev/null +++ b/tools/pnnx/tests/test_torch_narrow.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + out0 = torch.narrow(x, 0, 0, 2) + out1 = torch.narrow(x, 1, 1, 2) + out2 = torch.narrow(y, 0, 0, 2) + out3 = torch.narrow(y, 1, 1, 2) + out4 = torch.narrow(z, 0, 0, 2) + out5 = torch.narrow(z, 1, 1, 2) + return out0, out1, out2, out3, out4, out5 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 3) + y = torch.rand(5, 3) + z = torch.rand(3, 5) + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_narrow.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_narrow.pt inputshape=[3,3],[5,3],[3,5]") + + # pnnx inference + import test_torch_narrow_pnnx + b = test_torch_narrow_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)