From 42e71609508fde1bd54d9d9de6ca5522ee3bcf37 Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 22 Dec 2021 16:48:54 +0800 Subject: [PATCH] eliminate pnnx dropout --- tools/pnnx/src/CMakeLists.txt | 19 ++--- tools/pnnx/src/pass_level5.cpp | 3 + .../src/pass_level5/eliminate_dropout.cpp | 76 +++++++++++++++++++ .../eliminate_dropout.h} | 32 +------- tools/pnnx/src/pass_ncnn/F_alpha_dropout.cpp | 49 ------------ tools/pnnx/src/pass_ncnn/F_dropout.cpp | 49 ------------ tools/pnnx/src/pass_ncnn/F_dropout2d.cpp | 49 ------------ tools/pnnx/src/pass_ncnn/F_dropout3d.cpp | 49 ------------ .../src/pass_ncnn/F_feature_alpha_dropout.cpp | 49 ------------ tools/pnnx/src/pass_ncnn/nn_AlphaDropout.cpp | 49 ------------ tools/pnnx/src/pass_ncnn/nn_Dropout2d.cpp | 49 ------------ tools/pnnx/src/pass_ncnn/nn_Dropout3d.cpp | 49 ------------ 12 files changed, 91 insertions(+), 431 deletions(-) create mode 100644 tools/pnnx/src/pass_level5/eliminate_dropout.cpp rename tools/pnnx/src/{pass_ncnn/nn_Dropout.cpp => pass_level5/eliminate_dropout.h} (57%) delete mode 100644 tools/pnnx/src/pass_ncnn/F_alpha_dropout.cpp delete mode 100644 tools/pnnx/src/pass_ncnn/F_dropout.cpp delete mode 100644 tools/pnnx/src/pass_ncnn/F_dropout2d.cpp delete mode 100644 tools/pnnx/src/pass_ncnn/F_dropout3d.cpp delete mode 100644 tools/pnnx/src/pass_ncnn/F_feature_alpha_dropout.cpp delete mode 100644 tools/pnnx/src/pass_ncnn/nn_AlphaDropout.cpp delete mode 100644 tools/pnnx/src/pass_ncnn/nn_Dropout2d.cpp delete mode 100644 tools/pnnx/src/pass_ncnn/nn_Dropout3d.cpp diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 5645b6384..996aa3bc7 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -208,6 +208,7 @@ set(pnnx_pass_level4_SRCS ) set(pnnx_pass_level5_SRCS + pass_level5/eliminate_dropout.cpp pass_level5/eliminate_maxpool_indices.cpp pass_level5/eliminate_slice.cpp pass_level5/eliminate_view_reshape.cpp @@ -251,7 +252,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/F_adaptive_max_pool1d.cpp pass_ncnn/F_adaptive_max_pool2d.cpp pass_ncnn/F_adaptive_max_pool3d.cpp - pass_ncnn/F_alpha_dropout.cpp + #pass_ncnn/F_alpha_dropout.cpp pass_ncnn/F_avg_pool1d.cpp pass_ncnn/F_avg_pool2d.cpp pass_ncnn/F_avg_pool3d.cpp @@ -260,12 +261,12 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/F_conv1d.cpp pass_ncnn/F_conv2d.cpp pass_ncnn/F_conv3d.cpp - pass_ncnn/F_dropout.cpp - pass_ncnn/F_dropout2d.cpp - pass_ncnn/F_dropout3d.cpp + #pass_ncnn/F_dropout.cpp + #pass_ncnn/F_dropout2d.cpp + #pass_ncnn/F_dropout3d.cpp pass_ncnn/F_elu.cpp pass_ncnn/F_embedding.cpp - pass_ncnn/F_feature_alpha_dropout.cpp + #pass_ncnn/F_feature_alpha_dropout.cpp pass_ncnn/F_gelu.cpp pass_ncnn/F_group_norm.cpp pass_ncnn/F_hardsigmoid.cpp @@ -302,7 +303,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/nn_AdaptiveMaxPool1d.cpp pass_ncnn/nn_AdaptiveMaxPool2d.cpp pass_ncnn/nn_AdaptiveMaxPool3d.cpp - pass_ncnn/nn_AlphaDropout.cpp + #pass_ncnn/nn_AlphaDropout.cpp pass_ncnn/nn_AvgPool1d.cpp pass_ncnn/nn_AvgPool2d.cpp pass_ncnn/nn_AvgPool3d.cpp @@ -317,9 +318,9 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/nn_Conv2d.cpp pass_ncnn/nn_Conv3d.cpp pass_ncnn/nn_ConvTranspose2d.cpp - pass_ncnn/nn_Dropout.cpp - pass_ncnn/nn_Dropout2d.cpp - pass_ncnn/nn_Dropout3d.cpp + #pass_ncnn/nn_Dropout.cpp + #pass_ncnn/nn_Dropout2d.cpp + #pass_ncnn/nn_Dropout3d.cpp pass_ncnn/nn_ELU.cpp pass_ncnn/nn_Embedding.cpp pass_ncnn/nn_GELU.cpp diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index 54d338218..6c83d2d00 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -14,6 +14,7 @@ #include "pass_level5.h" +#include "pass_level5/eliminate_dropout.h" #include "pass_level5/eliminate_slice.h" #include "pass_level5/eliminate_view_reshape.h" #include "pass_level5/eval_expression.h" @@ -49,6 +50,8 @@ void pass_level5(Graph& g) eliminate_view_reshape(g); + eliminate_dropout(g); + fuse_channel_shuffle(g); dead_code_elimination(g); diff --git a/tools/pnnx/src/pass_level5/eliminate_dropout.cpp b/tools/pnnx/src/pass_level5/eliminate_dropout.cpp new file mode 100644 index 000000000..a23e2b048 --- /dev/null +++ b/tools/pnnx/src/pass_level5/eliminate_dropout.cpp @@ -0,0 +1,76 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eliminate_dropout.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void eliminate_dropout(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "F.alpha_dropout" && op->type != "F.dropout" && op->type != "F.dropout2d" && op->type != "F.dropout3d" && op->type != "F.feature_alpha_dropout" && op->type != "nn.AlphaDropout" && op->type != "nn.Dropout" && op->type != "nn.Dropout2d" && op->type != "nn.Dropout3d") + continue; + + // delete noop-like dropout + matched = true; + + for (auto& x : op->inputs) + { + x->remove_consumer(op); + } + + Operand* slice_out = op->outputs[0]; + + for (auto& x : slice_out->consumers) + { + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j] == slice_out) + x->inputs[j] = op->inputs[0]; + } + + op->inputs[0]->consumers.push_back(x); + } + + slice_out->producer = 0; + slice_out->consumers.clear(); + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), slice_out)); + delete slice_out; + + op->inputs.clear(); + op->outputs.clear(); + + graph.ops.erase(graph.ops.begin() + i); + delete op; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Dropout.cpp b/tools/pnnx/src/pass_level5/eliminate_dropout.h similarity index 57% rename from tools/pnnx/src/pass_ncnn/nn_Dropout.cpp rename to tools/pnnx/src/pass_level5/eliminate_dropout.h index ca295699b..d3636611a 100644 --- a/tools/pnnx/src/pass_ncnn/nn_Dropout.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_dropout.h @@ -12,38 +12,10 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_ncnn.h" +#include "ir.h" namespace pnnx { -namespace ncnn { - -class nn_Dropout : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -nn.Dropout op_0 1 1 input out -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Noop"; - } - - const char* name_str() const - { - return "dropout"; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Dropout, 20) - -} // namespace ncnn +void eliminate_dropout(Graph& graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_alpha_dropout.cpp b/tools/pnnx/src/pass_ncnn/F_alpha_dropout.cpp deleted file mode 100644 index 49ca6be46..000000000 --- a/tools/pnnx/src/pass_ncnn/F_alpha_dropout.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#include "pass_ncnn.h" - -namespace pnnx { - -namespace ncnn { - -class F_alpha_dropout : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -F.alpha_dropout op_0 1 1 input out p=* training=* -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Noop"; - } - - const char* name_str() const - { - return "alpha_dropout"; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_alpha_dropout, 20) - -} // namespace ncnn - -} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_dropout.cpp b/tools/pnnx/src/pass_ncnn/F_dropout.cpp deleted file mode 100644 index fdd5a1efa..000000000 --- a/tools/pnnx/src/pass_ncnn/F_dropout.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#include "pass_ncnn.h" - -namespace pnnx { - -namespace ncnn { - -class F_dropout : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -F.dropout op_0 1 1 input out p=* training=* -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Noop"; - } - - const char* name_str() const - { - return "dropout"; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_dropout, 20) - -} // namespace ncnn - -} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_dropout2d.cpp b/tools/pnnx/src/pass_ncnn/F_dropout2d.cpp deleted file mode 100644 index 240be5be0..000000000 --- a/tools/pnnx/src/pass_ncnn/F_dropout2d.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#include "pass_ncnn.h" - -namespace pnnx { - -namespace ncnn { - -class F_dropout2d : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -F.dropout2d op_0 1 1 input out p=* training=* -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Noop"; - } - - const char* name_str() const - { - return "dropout2d"; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_dropout2d, 20) - -} // namespace ncnn - -} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_dropout3d.cpp b/tools/pnnx/src/pass_ncnn/F_dropout3d.cpp deleted file mode 100644 index 3d4651f4e..000000000 --- a/tools/pnnx/src/pass_ncnn/F_dropout3d.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#include "pass_ncnn.h" - -namespace pnnx { - -namespace ncnn { - -class F_dropout3d : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -F.dropout3d op_0 1 1 input out p=* training=* -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Noop"; - } - - const char* name_str() const - { - return "dropout3d"; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_dropout3d, 20) - -} // namespace ncnn - -} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_feature_alpha_dropout.cpp b/tools/pnnx/src/pass_ncnn/F_feature_alpha_dropout.cpp deleted file mode 100644 index 7f599cb42..000000000 --- a/tools/pnnx/src/pass_ncnn/F_feature_alpha_dropout.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#include "pass_ncnn.h" - -namespace pnnx { - -namespace ncnn { - -class F_feature_alpha_dropout : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -F.feature_alpha_dropout op_0 1 1 input out p=* training=* -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Noop"; - } - - const char* name_str() const - { - return "feature_alpha_dropout"; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_feature_alpha_dropout, 20) - -} // namespace ncnn - -} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_AlphaDropout.cpp b/tools/pnnx/src/pass_ncnn/nn_AlphaDropout.cpp deleted file mode 100644 index 37a3df98a..000000000 --- a/tools/pnnx/src/pass_ncnn/nn_AlphaDropout.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#include "pass_ncnn.h" - -namespace pnnx { - -namespace ncnn { - -class nn_AlphaDropout : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -nn.AlphaDropout op_0 1 1 input out -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Noop"; - } - - const char* name_str() const - { - return "dropout"; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_AlphaDropout, 20) - -} // namespace ncnn - -} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Dropout2d.cpp b/tools/pnnx/src/pass_ncnn/nn_Dropout2d.cpp deleted file mode 100644 index d402065d7..000000000 --- a/tools/pnnx/src/pass_ncnn/nn_Dropout2d.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#include "pass_ncnn.h" - -namespace pnnx { - -namespace ncnn { - -class nn_Dropout2d : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -nn.Dropout2d op_0 1 1 input out -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Noop"; - } - - const char* name_str() const - { - return "dropout"; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Dropout2d, 20) - -} // namespace ncnn - -} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Dropout3d.cpp b/tools/pnnx/src/pass_ncnn/nn_Dropout3d.cpp deleted file mode 100644 index 97cca570a..000000000 --- a/tools/pnnx/src/pass_ncnn/nn_Dropout3d.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#include "pass_ncnn.h" - -namespace pnnx { - -namespace ncnn { - -class nn_Dropout3d : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input -nn.Dropout3d op_0 1 1 input out -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Noop"; - } - - const char* name_str() const - { - return "dropout"; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Dropout3d, 20) - -} // namespace ncnn - -} // namespace pnnx