From f7af84f0010293be416585bc9b3b621786fe1a4e Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 20 Apr 2023 16:47:25 +0800 Subject: [PATCH] pnnx reset maxpool return_indices if only output found, drop convtransposed output_size arg for static output padding (#4654) --- tools/pnnx/src/CMakeLists.txt | 1 + .../src/pass_level1/nn_ConvTranspose1d.cpp | 11 +++ .../src/pass_level1/nn_ConvTranspose2d.cpp | 11 +++ .../src/pass_level1/nn_ConvTranspose3d.cpp | 11 +++ tools/pnnx/src/pass_level3.cpp | 3 + .../src/pass_level3/fuse_maxpool_unpack.cpp | 87 +++++++++++++++++++ .../src/pass_level3/fuse_maxpool_unpack.h | 21 +++++ .../tests/ncnn/test_nn_ConvTranspose1d.py | 6 ++ .../tests/ncnn/test_nn_ConvTranspose2d.py | 6 ++ .../tests/ncnn/test_nn_ConvTranspose3d.py | 6 ++ tools/pnnx/tests/ncnn/test_nn_MaxPool1d.py | 9 +- tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py | 6 +- tools/pnnx/tests/ncnn/test_nn_MaxPool3d.py | 6 +- tools/pnnx/tests/test_nn_ConvTranspose1d.py | 6 ++ tools/pnnx/tests/test_nn_ConvTranspose2d.py | 6 ++ tools/pnnx/tests/test_nn_ConvTranspose3d.py | 6 ++ tools/pnnx/tests/test_nn_MaxPool1d.py | 9 +- tools/pnnx/tests/test_nn_MaxPool2d.py | 6 +- tools/pnnx/tests/test_nn_MaxPool3d.py | 6 +- 19 files changed, 205 insertions(+), 18 deletions(-) create mode 100644 tools/pnnx/src/pass_level3/fuse_maxpool_unpack.cpp create mode 100644 tools/pnnx/src/pass_level3/fuse_maxpool_unpack.h diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index da2c4a4fc..a16fca3cb 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -289,6 +289,7 @@ set(pnnx_pass_level3_SRCS pass_level3/fuse_einsum_operands.cpp pass_level3/fuse_expression.cpp pass_level3/fuse_index_expression.cpp + pass_level3/fuse_maxpool_unpack.cpp pass_level3/fuse_multiheadattention_unpack.cpp pass_level3/fuse_rnn_unpack.cpp pass_level3/rename_F_conv_transposend.cpp diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp index c6f2ce9b4..6677b832e 100644 --- a/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp @@ -52,6 +52,17 @@ public: { op->attrs["bias"] = mod.attr("bias").toTensor(); } + + if (op->inputs.size() > 1) + { + fprintf(stderr, "ConvTranspose1d arg output_size detected and dropped !\n"); + + for (size_t i = 1; i < op->inputs.size(); i++) + { + op->inputs[i]->remove_consumer(op); + } + op->inputs.resize(1); + } } }; diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp index 32b55f5d3..a481cb154 100644 --- a/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp @@ -52,6 +52,17 @@ public: { op->attrs["bias"] = mod.attr("bias").toTensor(); } + + if (op->inputs.size() > 1) + { + fprintf(stderr, "ConvTranspose2d arg output_size detected and dropped !\n"); + + for (size_t i = 1; i < op->inputs.size(); i++) + { + op->inputs[i]->remove_consumer(op); + } + op->inputs.resize(1); + } } }; diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp index 1f414efad..b8fd141de 100644 --- a/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp @@ -52,6 +52,17 @@ public: { op->attrs["bias"] = mod.attr("bias").toTensor(); } + + if (op->inputs.size() > 1) + { + fprintf(stderr, "ConvTranspose3d arg output_size detected and dropped !\n"); + + for (size_t i = 1; i < op->inputs.size(); i++) + { + op->inputs[i]->remove_consumer(op); + } + op->inputs.resize(1); + } } }; diff --git a/tools/pnnx/src/pass_level3.cpp b/tools/pnnx/src/pass_level3.cpp index 085e5ff5e..5f05ace8f 100644 --- a/tools/pnnx/src/pass_level3.cpp +++ b/tools/pnnx/src/pass_level3.cpp @@ -23,6 +23,7 @@ #include "pass_level3/fuse_einsum_operands.h" #include "pass_level3/fuse_expression.h" #include "pass_level3/fuse_index_expression.h" +#include "pass_level3/fuse_maxpool_unpack.h" #include "pass_level3/fuse_multiheadattention_unpack.h" #include "pass_level3/fuse_rnn_unpack.h" #include "pass_level3/rename_F_conv_transposend.h" @@ -45,6 +46,8 @@ void pass_level3(Graph& g, const std::set& foldable_constants, cons fuse_einsum_operands(g); + fuse_maxpool_unpack(g); + fuse_multiheadattention_unpack(g); fuse_rnn_unpack(g); diff --git a/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.cpp b/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.cpp new file mode 100644 index 000000000..5cba43c00 --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.cpp @@ -0,0 +1,87 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_maxpool_unpack.h" +#include +#include "pass_level2.h" + +namespace pnnx { + +void fuse_maxpool_unpack(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "nn.MaxPool1d" && op->type != "nn.MaxPool2d" && op->type != "nn.MaxPool3d") + continue; + + Operator* op2 = op->outputs[0]->consumers[0]; + + if (op->outputs.size() == 1 && op2->type != "prim::TupleUnpack") + { + if (op->params.find("return_indices") == op->params.end()) + continue; + + if (op->params.at("return_indices").b == false) + continue; + + matched = true; + + // no indices returned actually + op->params["return_indices"] = false; + break; + } + + if (op->outputs.size() != 1) + continue; + + if (op->outputs[0]->consumers.size() != 1) + continue; + + if (op2->type != "prim::TupleUnpack") + continue; + + matched = true; + + op->outputs[0]->producer = 0; + op->outputs[0]->remove_consumer(op2); + + for (auto& x : op2->outputs) + { + x->producer = op; + } + + op->outputs = op2->outputs; + + op2->inputs.clear(); + op2->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + + delete op2; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.h b/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.h new file mode 100644 index 000000000..19bbf1442 --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_maxpool_unpack(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/tests/ncnn/test_nn_ConvTranspose1d.py b/tools/pnnx/tests/ncnn/test_nn_ConvTranspose1d.py index 0f201e747..d222ba9ca 100644 --- a/tools/pnnx/tests/ncnn/test_nn_ConvTranspose1d.py +++ b/tools/pnnx/tests/ncnn/test_nn_ConvTranspose1d.py @@ -29,6 +29,9 @@ class Model(nn.Module): self.deconv_6 = nn.ConvTranspose1d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) self.deconv_7 = nn.ConvTranspose1d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(6), output_padding=(1), dilation=2, groups=1, bias=True) + self.downsample = nn.Conv1d(24, 16, 3, stride=2, padding=1) + self.upsample = nn.ConvTranspose1d(16, 24, 3, stride=2, padding=1) + def forward(self, x): x = self.deconv_0(x) x = self.deconv_1(x) @@ -39,6 +42,9 @@ class Model(nn.Module): x = self.deconv_6(x) x = self.deconv_7(x) + y = self.downsample(x) + x = self.upsample(y, output_size=x.size()) + return x def test(): diff --git a/tools/pnnx/tests/ncnn/test_nn_ConvTranspose2d.py b/tools/pnnx/tests/ncnn/test_nn_ConvTranspose2d.py index f13c7c866..9b0f35dc7 100644 --- a/tools/pnnx/tests/ncnn/test_nn_ConvTranspose2d.py +++ b/tools/pnnx/tests/ncnn/test_nn_ConvTranspose2d.py @@ -29,6 +29,9 @@ class Model(nn.Module): self.deconv_6 = nn.ConvTranspose2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) self.deconv_7 = nn.ConvTranspose2d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), output_padding=(1,0), dilation=2, groups=1, bias=True) + self.downsample = nn.Conv2d(24, 16, 3, stride=2, padding=1) + self.upsample = nn.ConvTranspose2d(16, 24, 3, stride=2, padding=1) + def forward(self, x): x = self.deconv_0(x) x = self.deconv_1(x) @@ -39,6 +42,9 @@ class Model(nn.Module): x = self.deconv_6(x) x = self.deconv_7(x) + y = self.downsample(x) + x = self.upsample(y, output_size=x.size()) + return x def test(): diff --git a/tools/pnnx/tests/ncnn/test_nn_ConvTranspose3d.py b/tools/pnnx/tests/ncnn/test_nn_ConvTranspose3d.py index 683c44755..4284d86ee 100644 --- a/tools/pnnx/tests/ncnn/test_nn_ConvTranspose3d.py +++ b/tools/pnnx/tests/ncnn/test_nn_ConvTranspose3d.py @@ -29,6 +29,9 @@ class Model(nn.Module): self.deconv_6 = nn.ConvTranspose3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) self.deconv_7 = nn.ConvTranspose3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6,7), output_padding=(1,0,1), dilation=2, groups=1, bias=True) + self.downsample = nn.Conv3d(24, 16, 3, stride=2, padding=1) + self.upsample = nn.ConvTranspose3d(16, 24, 3, stride=2, padding=1) + def forward(self, x): x = self.deconv_0(x) x = self.deconv_1(x) @@ -39,6 +42,9 @@ class Model(nn.Module): x = self.deconv_6(x) x = self.deconv_7(x) + y = self.downsample(x) + x = self.upsample(y, output_size=x.size()) + return x def test(): diff --git a/tools/pnnx/tests/ncnn/test_nn_MaxPool1d.py b/tools/pnnx/tests/ncnn/test_nn_MaxPool1d.py index f41552d97..348dbc661 100644 --- a/tools/pnnx/tests/ncnn/test_nn_MaxPool1d.py +++ b/tools/pnnx/tests/ncnn/test_nn_MaxPool1d.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version class Model(nn.Module): def __init__(self): @@ -25,7 +26,7 @@ class Model(nn.Module): self.pool_2 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) self.pool_3 = nn.MaxPool1d(kernel_size=5, stride=2, padding=2, dilation=1, return_indices=False, ceil_mode=True) self.pool_4 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) - self.pool_5 = nn.MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + self.pool_5 = nn.MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=True, ceil_mode=True) self.pool_6 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2, dilation=1, return_indices=False, ceil_mode=False) def forward(self, x): @@ -36,7 +37,7 @@ class Model(nn.Module): x = self.pool_2(x) x = self.pool_3(x) x = self.pool_4(x) - x = self.pool_5(x) + x, tx = self.pool_5(x) x = self.pool_6(x) y = self.pool_0(y) @@ -44,7 +45,9 @@ class Model(nn.Module): y = self.pool_2(y) y = self.pool_3(y) y = self.pool_4(y) - y = self.pool_5(y) + if version.parse(torch.__version__) < version.parse('1.10'): + y = y.unsqueeze(0) + y, ty = self.pool_5(y) y = self.pool_6(y) return x, y diff --git a/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py b/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py index b1c5df059..f4feb3008 100644 --- a/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py +++ b/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py @@ -25,7 +25,7 @@ class Model(nn.Module): self.pool_2 = nn.MaxPool2d(kernel_size=(1,3), stride=1, padding=(0,1), dilation=1, ceil_mode=False) self.pool_3 = nn.MaxPool2d(kernel_size=(4,5), stride=(1,2), padding=(1,2), dilation=1, ceil_mode=True) self.pool_4 = nn.MaxPool2d(kernel_size=(2,3), stride=1, padding=1, dilation=1, ceil_mode=False) - self.pool_5 = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=True) + self.pool_5 = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=True, ceil_mode=True) self.pool_6 = nn.MaxPool2d(kernel_size=(5,4), stride=1, padding=2, dilation=1, ceil_mode=False) def forward(self, x): @@ -36,7 +36,7 @@ class Model(nn.Module): x = self.pool_2(x) x = self.pool_3(x) x = self.pool_4(x) - x = self.pool_5(x) + x, tx = self.pool_5(x) x = self.pool_6(x) y = self.pool_0(y) @@ -44,7 +44,7 @@ class Model(nn.Module): y = self.pool_2(y) y = self.pool_3(y) y = self.pool_4(y) - y = self.pool_5(y) + y, ty = self.pool_5(y) y = self.pool_6(y) return x, y diff --git a/tools/pnnx/tests/ncnn/test_nn_MaxPool3d.py b/tools/pnnx/tests/ncnn/test_nn_MaxPool3d.py index deb9cdb64..0de797854 100644 --- a/tools/pnnx/tests/ncnn/test_nn_MaxPool3d.py +++ b/tools/pnnx/tests/ncnn/test_nn_MaxPool3d.py @@ -25,7 +25,7 @@ class Model(nn.Module): self.pool_2 = nn.MaxPool3d(kernel_size=(1,2,3), stride=1, padding=(0,0,1), dilation=1, return_indices=False, ceil_mode=False) self.pool_3 = nn.MaxPool3d(kernel_size=(3,4,5), stride=(1,2,2), padding=(1,2,2), dilation=1, return_indices=False, ceil_mode=True) self.pool_4 = nn.MaxPool3d(kernel_size=(2,3,3), stride=1, padding=1, dilation=(1,1,1), return_indices=False, ceil_mode=False) - self.pool_5 = nn.MaxPool3d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + self.pool_5 = nn.MaxPool3d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=True, ceil_mode=True) self.pool_6 = nn.MaxPool3d(kernel_size=(5,4,4), stride=1, padding=2, dilation=1, return_indices=False, ceil_mode=False) def forward(self, x): @@ -36,7 +36,7 @@ class Model(nn.Module): x = self.pool_2(x) x = self.pool_3(x) x = self.pool_4(x) - x = self.pool_5(x) + x, tx = self.pool_5(x) x = self.pool_6(x) y = self.pool_0(y) @@ -44,7 +44,7 @@ class Model(nn.Module): y = self.pool_2(y) y = self.pool_3(y) y = self.pool_4(y) - y = self.pool_5(y) + y, ty = self.pool_5(y) y = self.pool_6(y) return x, y diff --git a/tools/pnnx/tests/test_nn_ConvTranspose1d.py b/tools/pnnx/tests/test_nn_ConvTranspose1d.py index 2219accbc..afa161135 100644 --- a/tools/pnnx/tests/test_nn_ConvTranspose1d.py +++ b/tools/pnnx/tests/test_nn_ConvTranspose1d.py @@ -29,6 +29,9 @@ class Model(nn.Module): self.deconv_6 = nn.ConvTranspose1d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) self.deconv_7 = nn.ConvTranspose1d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(6), output_padding=(1), dilation=2, groups=1, bias=True) + self.downsample = nn.Conv1d(24, 16, 3, stride=2, padding=1) + self.upsample = nn.ConvTranspose1d(16, 24, 3, stride=2, padding=1) + def forward(self, x): x = self.deconv_0(x) x = self.deconv_1(x) @@ -39,6 +42,9 @@ class Model(nn.Module): x = self.deconv_6(x) x = self.deconv_7(x) + y = self.downsample(x) + x = self.upsample(y, output_size=x.size()) + return x def test(): diff --git a/tools/pnnx/tests/test_nn_ConvTranspose2d.py b/tools/pnnx/tests/test_nn_ConvTranspose2d.py index 5b3acd808..d8d264b3f 100644 --- a/tools/pnnx/tests/test_nn_ConvTranspose2d.py +++ b/tools/pnnx/tests/test_nn_ConvTranspose2d.py @@ -29,6 +29,9 @@ class Model(nn.Module): self.deconv_6 = nn.ConvTranspose2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) self.deconv_7 = nn.ConvTranspose2d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), output_padding=(1,0), dilation=2, groups=1, bias=True) + self.downsample = nn.Conv2d(24, 16, 3, stride=2, padding=1) + self.upsample = nn.ConvTranspose2d(16, 24, 3, stride=2, padding=1) + def forward(self, x): x = self.deconv_0(x) x = self.deconv_1(x) @@ -39,6 +42,9 @@ class Model(nn.Module): x = self.deconv_6(x) x = self.deconv_7(x) + y = self.downsample(x) + x = self.upsample(y, output_size=x.size()) + return x def test(): diff --git a/tools/pnnx/tests/test_nn_ConvTranspose3d.py b/tools/pnnx/tests/test_nn_ConvTranspose3d.py index de727228e..3afbb1b03 100644 --- a/tools/pnnx/tests/test_nn_ConvTranspose3d.py +++ b/tools/pnnx/tests/test_nn_ConvTranspose3d.py @@ -29,6 +29,9 @@ class Model(nn.Module): self.deconv_6 = nn.ConvTranspose3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) self.deconv_7 = nn.ConvTranspose3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6,7), output_padding=(1,0,1), dilation=2, groups=1, bias=True) + self.downsample = nn.Conv3d(24, 16, 3, stride=2, padding=1) + self.upsample = nn.ConvTranspose3d(16, 24, 3, stride=2, padding=1) + def forward(self, x): x = self.deconv_0(x) x = self.deconv_1(x) @@ -39,6 +42,9 @@ class Model(nn.Module): x = self.deconv_6(x) x = self.deconv_7(x) + y = self.downsample(x) + x = self.upsample(y, output_size=x.size()) + return x def test(): diff --git a/tools/pnnx/tests/test_nn_MaxPool1d.py b/tools/pnnx/tests/test_nn_MaxPool1d.py index 6e2c05974..c9124d61a 100644 --- a/tools/pnnx/tests/test_nn_MaxPool1d.py +++ b/tools/pnnx/tests/test_nn_MaxPool1d.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version class Model(nn.Module): def __init__(self): @@ -25,7 +26,7 @@ class Model(nn.Module): self.pool_2 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) self.pool_3 = nn.MaxPool1d(kernel_size=5, stride=2, padding=2, dilation=1, return_indices=False, ceil_mode=True) self.pool_4 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=2, return_indices=False, ceil_mode=False) - self.pool_5 = nn.MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + self.pool_5 = nn.MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=True, ceil_mode=True) self.pool_6 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) def forward(self, x, y): @@ -34,7 +35,7 @@ class Model(nn.Module): x = self.pool_2(x) x = self.pool_3(x) x = self.pool_4(x) - x = self.pool_5(x) + x, tx = self.pool_5(x) x, indices = self.pool_6(x) y = self.pool_0(y) @@ -42,7 +43,9 @@ class Model(nn.Module): y = self.pool_2(y) y = self.pool_3(y) y = self.pool_4(y) - y = self.pool_5(y) + if version.parse(torch.__version__) < version.parse('1.10'): + y = y.unsqueeze(0) + y, ty = self.pool_5(y) return x, indices, y diff --git a/tools/pnnx/tests/test_nn_MaxPool2d.py b/tools/pnnx/tests/test_nn_MaxPool2d.py index f171d4967..6fbc708a6 100644 --- a/tools/pnnx/tests/test_nn_MaxPool2d.py +++ b/tools/pnnx/tests/test_nn_MaxPool2d.py @@ -25,7 +25,7 @@ class Model(nn.Module): self.pool_2 = nn.MaxPool2d(kernel_size=(1,3), stride=1, padding=(0,1), dilation=1, return_indices=False, ceil_mode=False) self.pool_3 = nn.MaxPool2d(kernel_size=(4,5), stride=(1,2), padding=(1,2), dilation=1, return_indices=False, ceil_mode=True) self.pool_4 = nn.MaxPool2d(kernel_size=(2,3), stride=1, padding=1, dilation=(1,2), return_indices=False, ceil_mode=False) - self.pool_5 = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + self.pool_5 = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=True, ceil_mode=True) self.pool_6 = nn.MaxPool2d(kernel_size=(5,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) def forward(self, x, y): @@ -34,7 +34,7 @@ class Model(nn.Module): x = self.pool_2(x) x = self.pool_3(x) x = self.pool_4(x) - x = self.pool_5(x) + x, tx = self.pool_5(x) x, indices = self.pool_6(x) y = self.pool_0(y) @@ -42,7 +42,7 @@ class Model(nn.Module): y = self.pool_2(y) y = self.pool_3(y) y = self.pool_4(y) - y = self.pool_5(y) + y, ty = self.pool_5(y) return x, indices, y def test(): diff --git a/tools/pnnx/tests/test_nn_MaxPool3d.py b/tools/pnnx/tests/test_nn_MaxPool3d.py index a0b7063a8..02c058e64 100644 --- a/tools/pnnx/tests/test_nn_MaxPool3d.py +++ b/tools/pnnx/tests/test_nn_MaxPool3d.py @@ -25,7 +25,7 @@ class Model(nn.Module): self.pool_2 = nn.MaxPool3d(kernel_size=(1,2,3), stride=1, padding=(0,0,1), dilation=1, return_indices=False, ceil_mode=False) self.pool_3 = nn.MaxPool3d(kernel_size=(3,4,5), stride=(1,2,2), padding=(1,2,2), dilation=1, return_indices=False, ceil_mode=True) self.pool_4 = nn.MaxPool3d(kernel_size=(2,3,3), stride=1, padding=1, dilation=(1,2,2), return_indices=False, ceil_mode=False) - self.pool_5 = nn.MaxPool3d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + self.pool_5 = nn.MaxPool3d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=True, ceil_mode=True) self.pool_6 = nn.MaxPool3d(kernel_size=(5,4,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) def forward(self, x, y): @@ -34,7 +34,7 @@ class Model(nn.Module): x = self.pool_2(x) x = self.pool_3(x) x = self.pool_4(x) - x = self.pool_5(x) + x, tx = self.pool_5(x) x, indices = self.pool_6(x) y = self.pool_0(y) @@ -42,7 +42,7 @@ class Model(nn.Module): y = self.pool_2(y) y = self.pool_3(y) y = self.pool_4(y) - y = self.pool_5(y) + y, ty = self.pool_5(y) return x, indices, y def test():