From 76f48c8fcbf2a65e6e00fc0cd42e953ac5789579 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 23 Apr 2025 19:27:08 +0800 Subject: [PATCH] pnnx pass level1 wrapper enabling faster build (#6014) --- tools/pnnx/src/CMakeLists.txt | 30 +-- tools/pnnx/src/ir.h | 4 + tools/pnnx/src/load_torchscript.cpp | 19 +- tools/pnnx/src/pass_level0/inline_block.cpp | 2 +- tools/pnnx/src/pass_level1.cpp | 43 +--- tools/pnnx/src/pass_level1.h | 29 --- .../pnnx/src/pass_level1/fuse_module_pass.cpp | 223 ++++++++++++++++++ tools/pnnx/src/pass_level1/fuse_module_pass.h | 151 ++++++++++++ .../src/pass_level1/nn_AdaptiveAvgPool1d.cpp | 8 +- .../src/pass_level1/nn_AdaptiveAvgPool2d.cpp | 8 +- .../src/pass_level1/nn_AdaptiveAvgPool3d.cpp | 8 +- .../src/pass_level1/nn_AdaptiveMaxPool1d.cpp | 12 +- .../src/pass_level1/nn_AdaptiveMaxPool2d.cpp | 12 +- .../src/pass_level1/nn_AdaptiveMaxPool3d.cpp | 12 +- .../pnnx/src/pass_level1/nn_AlphaDropout.cpp | 4 +- tools/pnnx/src/pass_level1/nn_AvgPool1d.cpp | 8 +- tools/pnnx/src/pass_level1/nn_AvgPool2d.cpp | 8 +- tools/pnnx/src/pass_level1/nn_AvgPool3d.cpp | 8 +- tools/pnnx/src/pass_level1/nn_BatchNorm1d.cpp | 16 +- tools/pnnx/src/pass_level1/nn_BatchNorm2d.cpp | 16 +- tools/pnnx/src/pass_level1/nn_BatchNorm3d.cpp | 16 +- tools/pnnx/src/pass_level1/nn_CELU.cpp | 8 +- .../src/pass_level1/nn_ChannelShuffle.cpp | 8 +- .../pnnx/src/pass_level1/nn_ConstantPad1d.cpp | 10 +- .../pnnx/src/pass_level1/nn_ConstantPad2d.cpp | 10 +- .../pnnx/src/pass_level1/nn_ConstantPad3d.cpp | 10 +- tools/pnnx/src/pass_level1/nn_Conv1d.cpp | 22 +- tools/pnnx/src/pass_level1/nn_Conv2d.cpp | 26 +- tools/pnnx/src/pass_level1/nn_Conv3d.cpp | 22 +- .../src/pass_level1/nn_ConvTranspose1d.cpp | 12 +- .../src/pass_level1/nn_ConvTranspose2d.cpp | 12 +- .../src/pass_level1/nn_ConvTranspose3d.cpp | 12 +- tools/pnnx/src/pass_level1/nn_Dropout.cpp | 4 +- tools/pnnx/src/pass_level1/nn_Dropout2d.cpp | 4 +- tools/pnnx/src/pass_level1/nn_Dropout3d.cpp | 4 +- tools/pnnx/src/pass_level1/nn_ELU.cpp | 8 +- tools/pnnx/src/pass_level1/nn_Embedding.cpp | 10 +- tools/pnnx/src/pass_level1/nn_Fold.cpp | 8 +- tools/pnnx/src/pass_level1/nn_GELU.cpp | 8 +- tools/pnnx/src/pass_level1/nn_GLU.cpp | 8 +- tools/pnnx/src/pass_level1/nn_GRU.cpp | 32 ++- tools/pnnx/src/pass_level1/nn_GroupNorm.cpp | 12 +- tools/pnnx/src/pass_level1/nn_Hardshrink.cpp | 8 +- tools/pnnx/src/pass_level1/nn_Hardsigmoid.cpp | 2 +- tools/pnnx/src/pass_level1/nn_Hardswish.cpp | 2 +- tools/pnnx/src/pass_level1/nn_Hardtanh.cpp | 8 +- .../src/pass_level1/nn_InstanceNorm1d.cpp | 16 +- .../src/pass_level1/nn_InstanceNorm2d.cpp | 16 +- .../src/pass_level1/nn_InstanceNorm3d.cpp | 16 +- tools/pnnx/src/pass_level1/nn_LPPool1d.cpp | 21 +- tools/pnnx/src/pass_level1/nn_LPPool2d.cpp | 16 +- tools/pnnx/src/pass_level1/nn_LSTM.cpp | 38 ++- tools/pnnx/src/pass_level1/nn_LayerNorm.cpp | 12 +- tools/pnnx/src/pass_level1/nn_LeakyReLU.cpp | 10 +- tools/pnnx/src/pass_level1/nn_Linear.cpp | 16 +- .../src/pass_level1/nn_LocalResponseNorm.cpp | 25 +- tools/pnnx/src/pass_level1/nn_LogSigmoid.cpp | 4 +- tools/pnnx/src/pass_level1/nn_LogSoftmax.cpp | 8 +- tools/pnnx/src/pass_level1/nn_MaxPool1d.cpp | 10 +- tools/pnnx/src/pass_level1/nn_MaxPool2d.cpp | 10 +- tools/pnnx/src/pass_level1/nn_MaxPool3d.cpp | 10 +- tools/pnnx/src/pass_level1/nn_Mish.cpp | 2 +- .../src/pass_level1/nn_MultiheadAttention.cpp | 75 +++--- tools/pnnx/src/pass_level1/nn_PReLU.cpp | 8 +- .../pnnx/src/pass_level1/nn_PixelShuffle.cpp | 8 +- .../src/pass_level1/nn_PixelUnshuffle.cpp | 8 +- tools/pnnx/src/pass_level1/nn_RMSNorm.cpp | 10 +- tools/pnnx/src/pass_level1/nn_RNN.cpp | 34 ++- tools/pnnx/src/pass_level1/nn_RReLU.cpp | 8 +- tools/pnnx/src/pass_level1/nn_ReLU.cpp | 2 +- tools/pnnx/src/pass_level1/nn_ReLU6.cpp | 2 +- .../src/pass_level1/nn_ReflectionPad1d.cpp | 10 +- .../src/pass_level1/nn_ReflectionPad2d.cpp | 10 +- .../src/pass_level1/nn_ReplicationPad1d.cpp | 10 +- .../src/pass_level1/nn_ReplicationPad2d.cpp | 10 +- .../src/pass_level1/nn_ReplicationPad3d.cpp | 10 +- tools/pnnx/src/pass_level1/nn_SELU.cpp | 4 +- tools/pnnx/src/pass_level1/nn_SiLU.cpp | 2 +- tools/pnnx/src/pass_level1/nn_Sigmoid.cpp | 2 +- tools/pnnx/src/pass_level1/nn_Softmax.cpp | 8 +- tools/pnnx/src/pass_level1/nn_Softmax2d.cpp | 4 +- tools/pnnx/src/pass_level1/nn_Softmin.cpp | 8 +- tools/pnnx/src/pass_level1/nn_Softplus.cpp | 8 +- tools/pnnx/src/pass_level1/nn_Softshrink.cpp | 8 +- tools/pnnx/src/pass_level1/nn_Softsign.cpp | 2 +- tools/pnnx/src/pass_level1/nn_Tanh.cpp | 2 +- tools/pnnx/src/pass_level1/nn_Tanhshrink.cpp | 2 +- tools/pnnx/src/pass_level1/nn_Threshold.cpp | 8 +- tools/pnnx/src/pass_level1/nn_Unfold.cpp | 8 +- tools/pnnx/src/pass_level1/nn_Upsample.cpp | 39 +-- .../pass_level1/nn_UpsamplingBilinear2d.cpp | 8 +- .../pass_level1/nn_UpsamplingNearest2d.cpp | 8 +- tools/pnnx/src/pass_level1/nn_ZeroPad2d.cpp | 10 +- .../src/pass_level1/nn_quantized_Conv2d.cpp | 18 +- .../pass_level1/nn_quantized_DeQuantize.cpp | 30 ++- .../src/pass_level1/nn_quantized_Linear.cpp | 20 +- .../src/pass_level1/nn_quantized_Quantize.cpp | 8 +- .../pass_level1/torchvision_DeformConv2d.cpp | 12 +- .../src/pass_level1/torchvision_RoIAlign.cpp | 10 +- tools/pnnx/src/utils.h | 14 -- 100 files changed, 882 insertions(+), 703 deletions(-) create mode 100644 tools/pnnx/src/pass_level1/fuse_module_pass.cpp create mode 100644 tools/pnnx/src/pass_level1/fuse_module_pass.h diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index fa579294e..e2c2e26fc 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -11,6 +11,8 @@ set(pnnx_pass_level0_SRCS ) set(pnnx_pass_level1_SRCS + pass_level1/fuse_module_pass.cpp + pass_level1/nn_AdaptiveAvgPool1d.cpp pass_level1/nn_AdaptiveAvgPool2d.cpp pass_level1/nn_AdaptiveAvgPool3d.cpp @@ -688,20 +690,20 @@ if(onnxruntime_FOUND) pass_onnx/shape_inference.cpp pass_onnx/fuse_constant_as_attribute.cpp - pass_onnx/nn_AdaptiveAvgPool2d.cpp - pass_onnx/nn_AdaptiveAvgPool3d.cpp - pass_onnx/nn_AvgPool2d.cpp - pass_onnx/nn_AvgPool3d.cpp - pass_onnx/nn_BatchNorm2d.cpp - pass_onnx/nn_BatchNorm3d.cpp - pass_onnx/nn_Conv2d.cpp - pass_onnx/nn_Conv3d.cpp - pass_onnx/nn_GELU.cpp - pass_onnx/nn_LayerNorm.cpp - pass_onnx/nn_Linear.cpp - pass_onnx/nn_MaxPool2d.cpp - pass_onnx/nn_MaxPool3d.cpp - pass_onnx/nn_MultiheadAttention.cpp + # pass_onnx/nn_AdaptiveAvgPool2d.cpp + # pass_onnx/nn_AdaptiveAvgPool3d.cpp + # pass_onnx/nn_AvgPool2d.cpp + # pass_onnx/nn_AvgPool3d.cpp + # pass_onnx/nn_BatchNorm2d.cpp + # pass_onnx/nn_BatchNorm3d.cpp + # pass_onnx/nn_Conv2d.cpp + # pass_onnx/nn_Conv3d.cpp + # pass_onnx/nn_GELU.cpp + # pass_onnx/nn_LayerNorm.cpp + # pass_onnx/nn_Linear.cpp + # pass_onnx/nn_MaxPool2d.cpp + # pass_onnx/nn_MaxPool3d.cpp + # pass_onnx/nn_MultiheadAttention.cpp ) set(onnx2pnnx_SRCS diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 6ab52eb0a..d1d200dd6 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -34,6 +34,9 @@ struct Node; namespace at { class Tensor; } +namespace pnnx { +class TorchTensorProxy; +} // namespace pnnx #endif // BUILD_TORCH2PNNX #if BUILD_ONNX2PNNX @@ -230,6 +233,7 @@ public: #if BUILD_TORCH2PNNX Attribute(const at::Tensor& t); + Attribute(const TorchTensorProxy& t); #endif #if BUILD_ONNX2PNNX Attribute(const onnx::TensorProto& t); diff --git a/tools/pnnx/src/load_torchscript.cpp b/tools/pnnx/src/load_torchscript.cpp index 647a9f4ba..cb78b3b8a 100644 --- a/tools/pnnx/src/load_torchscript.cpp +++ b/tools/pnnx/src/load_torchscript.cpp @@ -31,6 +31,7 @@ int64_t cuda_version(); #include "pass_level0.h" #include "pass_level1.h" +#include "pass_level1/fuse_module_pass.h" namespace pnnx { @@ -372,6 +373,11 @@ Attribute::Attribute(const at::Tensor& t) } } +Attribute::Attribute(const TorchTensorProxy& t) + : Attribute(t.t()) +{ +} + Operand* Graph::new_operand(const torch::jit::Value* v) { // Operand* r = new Operand; @@ -442,17 +448,6 @@ static const char* get_at_tensor_type_str(const at::ScalarType& st) return ""; } -const torch::jit::Node* find_node_by_kind(const std::shared_ptr& graph, const std::string& kind) -{ - for (const auto& n : graph->nodes()) - { - if (n->kind().toDisplayString() == kind) - return n; - } - - return 0; -} - static void print_shape_list(const std::vector >& shapes, const std::vector& types) { for (size_t i = 0; i < shapes.size(); i++) @@ -508,7 +503,7 @@ static void get_traced_input_shape(const std::string& ptpath, std::vector diff --git a/tools/pnnx/src/pass_level1.cpp b/tools/pnnx/src/pass_level1.cpp index aa61fac00..0696149c5 100644 --- a/tools/pnnx/src/pass_level1.cpp +++ b/tools/pnnx/src/pass_level1.cpp @@ -12,43 +12,16 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include +#include +#include #include +#include #include "pass_level1.h" -namespace pnnx { - -FuseModulePass::~FuseModulePass() -{ -} - -void FuseModulePass::write(Operator* /*op*/, const std::shared_ptr& /*graph*/) const -{ -} - -void FuseModulePass::write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& /*mod*/) const -{ - write(op, graph); -} - -static std::vector g_global_pnnx_fuse_module_passes; - -const std::vector& get_global_pnnx_fuse_module_passes() -{ - return g_global_pnnx_fuse_module_passes; -} +#include "pass_level1/fuse_module_pass.h" -FuseModulePassRegister::FuseModulePassRegister(const FuseModulePass* _pass) - : pass(_pass) -{ - g_global_pnnx_fuse_module_passes.push_back(pass); -} - -FuseModulePassRegister::~FuseModulePassRegister() -{ - delete pass; -} +namespace pnnx { static void fuse_moduleop_unpack(Graph& graph, const std::vector& module_operators) { @@ -399,10 +372,12 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptrname = wrapped_name; #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 11) - ow->write(op, toGraphFunction(function).graph(), sub_mod); + TorchGraphProxy graph_proxy(toGraphFunction(function).graph()); #else - ow->write(op, function.graph(), sub_mod); + TorchGraphProxy graph_proxy(function.graph()); #endif + TorchModuleProxy sub_mod_proxy(sub_mod); + ow->write(op, graph_proxy, sub_mod_proxy); break; } diff --git a/tools/pnnx/src/pass_level1.h b/tools/pnnx/src/pass_level1.h index 1eb5a7ab9..13daca961 100644 --- a/tools/pnnx/src/pass_level1.h +++ b/tools/pnnx/src/pass_level1.h @@ -15,39 +15,10 @@ #ifndef PNNX_PASS_LEVEL1_H #define PNNX_PASS_LEVEL1_H -#include -#include #include "ir.h" namespace pnnx { -class FuseModulePass -{ -public: - virtual ~FuseModulePass(); - - virtual const char* match_type_str() const = 0; - - virtual const char* type_str() const = 0; - - virtual void write(Operator* op, const std::shared_ptr& graph) const; - - virtual void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const; -}; - -class FuseModulePassRegister -{ -public: - FuseModulePassRegister(const FuseModulePass* pass); - ~FuseModulePassRegister(); - const FuseModulePass* pass; -}; - -const std::vector& get_global_pnnx_fuse_module_passes(); - -#define REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(CLASS) \ - static FuseModulePassRegister g_global_pnnx_fusemodulepass_##CLASS##_register(new CLASS); - void pass_level1(const torch::jit::Module& mod, const std::shared_ptr& g, const std::vector& module_operators, Graph& pg); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/fuse_module_pass.cpp b/tools/pnnx/src/pass_level1/fuse_module_pass.cpp new file mode 100644 index 000000000..007a900aa --- /dev/null +++ b/tools/pnnx/src/pass_level1/fuse_module_pass.cpp @@ -0,0 +1,223 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_module_pass.h" + +#include +#include +#include + +namespace pnnx { + +std::string TorchNodeProxy::kind() const +{ + return node->kind().toDisplayString(); +} + +bool TorchNodeProxy::hasNamedInput(const std::string& name) const +{ + return node->hasNamedInput(name); +} + +const torch::jit::Value* TorchNodeProxy::namedInput(const std::string& name) const +{ + return node->namedInput(name); +} + +int TorchNodeProxy::input_count() const +{ + return node->inputs().size(); +} + +const torch::jit::Value* TorchNodeProxy::input(int i) const +{ + return node->input(i); +} + +int TorchNodeProxy::output_count() const +{ + return node->outputs().size(); +} + +const torch::jit::Value* TorchNodeProxy::output(int i) const +{ + return node->output(i); +} + +bool TorchNodeProxy::is_input_none(int i) const +{ + return node->input(i)->type()->kind() == c10::TypeKind::NoneType; +} + +TorchGraphProxy::TorchGraphProxy(const std::shared_ptr _graph) + : graph(_graph) +{ + for (const auto& n : graph->nodes()) + { + nodes.push_back(n); + } +} + +const TorchNodeProxy* TorchGraphProxy::find_node_by_kind(const std::string& kind) const +{ + for (const auto& n : nodes) + { + if (n.node->kind().toDisplayString() == kind) + return &n; + } + + return 0; +} + +const TorchNodeProxy* TorchGraphProxy::find_producer_node_by_value(const torch::jit::Value* value) const +{ + for (const auto& n : nodes) + { + if (n.node == value->node()) + return &n; + } + + fprintf(stderr, "TorchGraphProxy find_producer_node_by_value failed\n"); + return 0; +} + +int TorchGraphProxy::input_count() const +{ + return std::as_const(*graph).inputs().size(); +} + +const torch::jit::Value* TorchGraphProxy::input(int i) const +{ + return std::as_const(*graph).inputs()[i]; +} + +int TorchGraphProxy::output_count() const +{ + return std::as_const(*graph).outputs().size(); +} + +const torch::jit::Value* TorchGraphProxy::output(int i) const +{ + return std::as_const(*graph).outputs()[i]; +} + +void TorchGraphProxy::dump() const +{ + graph->dump(); +} + +class TorchTensorProxyPrivate +{ +public: + at::Tensor t; +}; + +TorchTensorProxy::TorchTensorProxy(const at::Tensor& _t) + : d(new TorchTensorProxyPrivate) +{ + d->t = _t; +} + +TorchTensorProxy::~TorchTensorProxy() +{ + delete d; +} + +const at::Tensor& TorchTensorProxy::t() const +{ + return d->t; +} + +int TorchTensorProxy::size(size_t i) const +{ + return d->t.size(i); +} + +TorchModuleProxy::TorchModuleProxy(const torch::jit::Module& _mod) + : mod(_mod) +{ + const std::vector& attributes = mod._ivalue()->type()->getAttributes(); + for (size_t i = 0; i < attributes.size(); i++) + { + const std::string& name = attributes[i].getName(); + const c10::IValue& ivalue = mod._ivalue()->getSlot(i); + + if (name.empty()) + continue; + + if (ivalue.isTensor()) + attrs.emplace(name, ivalue.toTensor()); + + if (ivalue.isModule()) + { + const torch::jit::Module submod = ivalue.toModule(); + + const std::vector& sub_attributes = submod._ivalue()->type()->getAttributes(); + for (size_t j = 0; j < sub_attributes.size(); j++) + { + const std::string& sub_name = sub_attributes[j].getName(); + const c10::IValue& sub_ivalue = submod._ivalue()->getSlot(j); + + if (sub_name.empty()) + continue; + + if (sub_ivalue.isTensor()) + attrs.emplace(name + "." + sub_name, sub_ivalue.toTensor()); + } + } + } +} + +bool TorchModuleProxy::hasattr(const std::string& name) const +{ + return attrs.find(name) != attrs.end(); +} + +const TorchTensorProxy& TorchModuleProxy::attr(const std::string& name) const +{ + return attrs.at(name); +} + +FuseModulePass::~FuseModulePass() +{ +} + +void FuseModulePass::write(Operator* /*op*/, const TorchGraphProxy& /*graph*/) const +{ +} + +void FuseModulePass::write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& /*mod*/) const +{ + write(op, graph); +} + +static std::vector g_global_pnnx_fuse_module_passes; + +const std::vector& get_global_pnnx_fuse_module_passes() +{ + return g_global_pnnx_fuse_module_passes; +} + +FuseModulePassRegister::FuseModulePassRegister(const FuseModulePass* _pass) + : pass(_pass) +{ + g_global_pnnx_fuse_module_passes.push_back(pass); +} + +FuseModulePassRegister::~FuseModulePassRegister() +{ + delete pass; +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/fuse_module_pass.h b/tools/pnnx/src/pass_level1/fuse_module_pass.h new file mode 100644 index 000000000..68ae8179b --- /dev/null +++ b/tools/pnnx/src/pass_level1/fuse_module_pass.h @@ -0,0 +1,151 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_FUSE_MODULE_PASS_H +#define PNNX_FUSE_MODULE_PASS_H + +#include +#include +#include +#include + +#include "ir.h" + +namespace torch { +namespace jit { +struct Graph; +struct Module; +struct Node; +struct Value; +} // namespace jit +} // namespace torch +namespace at { +struct Tensor; +} // namespace at + +namespace pnnx { + +class TorchNodeProxy +{ +public: + TorchNodeProxy(const torch::jit::Node* _node) + : node(_node) + { + } + + std::string kind() const; + + bool hasNamedInput(const std::string& name) const; + const torch::jit::Value* namedInput(const std::string& name) const; + + int input_count() const; + const torch::jit::Value* input(int i) const; + + int output_count() const; + const torch::jit::Value* output(int i) const; + + bool is_input_none(int i) const; + +public: + const torch::jit::Node* node; +}; + +class TorchGraphProxy +{ +public: + TorchGraphProxy(const std::shared_ptr _graph); + + // bool has_node(const std::string& name) const; + const TorchNodeProxy* find_node_by_kind(const std::string& kind) const; + + const TorchNodeProxy* find_producer_node_by_value(const torch::jit::Value* value) const; + + int input_count() const; + const torch::jit::Value* input(int i) const; + + int output_count() const; + const torch::jit::Value* output(int i) const; + + void dump() const; + +public: + const std::shared_ptr graph; + +public: + std::vector nodes; +}; + +class TorchTensorProxyPrivate; +class TorchTensorProxy +{ +public: + TorchTensorProxy(const at::Tensor& _t); + ~TorchTensorProxy(); + + TorchTensorProxy(const TorchTensorProxy&) = delete; + TorchTensorProxy& operator=(const TorchTensorProxy&) = delete; + + const at::Tensor& t() const; + + int size(size_t i) const; + +private: + TorchTensorProxyPrivate* const d; +}; + +class TorchModuleProxy +{ +public: + TorchModuleProxy(const torch::jit::Module& _mod); + + bool hasattr(const std::string& name) const; + const TorchTensorProxy& attr(const std::string& name) const; + +public: + const torch::jit::Module& mod; + +private: + std::unordered_map attrs; +}; + +class FuseModulePass +{ +public: + virtual ~FuseModulePass(); + + virtual const char* match_type_str() const = 0; + + virtual const char* type_str() const = 0; + + virtual void write(Operator* op, const TorchGraphProxy& graph) const; + + virtual void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const; +}; + +class FuseModulePassRegister +{ +public: + FuseModulePassRegister(const FuseModulePass* pass); + ~FuseModulePassRegister(); + const FuseModulePass* pass; +}; + +const std::vector& get_global_pnnx_fuse_module_passes(); + +#define REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(CLASS) \ + static FuseModulePassRegister g_global_pnnx_fusemodulepass_##CLASS##_register(new CLASS); + +} // namespace pnnx + +#endif // PNNX_FUSE_MODULE_PASS_H diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool1d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool1d.cpp index b7ee5241d..b13150f51 100644 --- a/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.AdaptiveAvgPool1d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* adaptive_avg_pool1d = find_node_by_kind(graph, "aten::adaptive_avg_pool1d"); + const TorchNodeProxy* adaptive_avg_pool1d = graph.find_node_by_kind("aten::adaptive_avg_pool1d"); op->params["output_size"] = adaptive_avg_pool1d->namedInput("output_size"); } diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool2d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool2d.cpp index 7987285a8..57d239ca0 100644 --- a/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.AdaptiveAvgPool2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* adaptive_avg_pool2d = find_node_by_kind(graph, "aten::adaptive_avg_pool2d"); + const TorchNodeProxy* adaptive_avg_pool2d = graph.find_node_by_kind("aten::adaptive_avg_pool2d"); op->params["output_size"] = adaptive_avg_pool2d->namedInput("output_size"); } diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool3d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool3d.cpp index 9f34be877..153ea3f17 100644 --- a/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool3d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.AdaptiveAvgPool3d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* adaptive_avg_pool3d = find_node_by_kind(graph, "aten::adaptive_avg_pool3d"); + const TorchNodeProxy* adaptive_avg_pool3d = graph.find_node_by_kind("aten::adaptive_avg_pool3d"); op->params["output_size"] = adaptive_avg_pool3d->namedInput("output_size"); } diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool1d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool1d.cpp index aaf7a224b..2019555ae 100644 --- a/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,12 +29,14 @@ public: return "nn.AdaptiveMaxPool1d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* adaptive_max_pool1d = find_node_by_kind(graph, "aten::adaptive_max_pool1d"); + const TorchNodeProxy* adaptive_max_pool1d = graph.find_node_by_kind("aten::adaptive_max_pool1d"); + + const TorchNodeProxy* graph_out = graph.find_producer_node_by_value(graph.output(0)); op->params["output_size"] = adaptive_max_pool1d->namedInput("output_size"); - op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false; + op->params["return_indices"] = graph_out->kind() == "prim::TupleConstruct" ? true : false; } }; diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool2d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool2d.cpp index 72286d58e..98d00d8e0 100644 --- a/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,12 +29,14 @@ public: return "nn.AdaptiveMaxPool2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* adaptive_max_pool2d = find_node_by_kind(graph, "aten::adaptive_max_pool2d"); + const TorchNodeProxy* adaptive_max_pool2d = graph.find_node_by_kind("aten::adaptive_max_pool2d"); + + const TorchNodeProxy* graph_out = graph.find_producer_node_by_value(graph.output(0)); op->params["output_size"] = adaptive_max_pool2d->namedInput("output_size"); - op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false; + op->params["return_indices"] = graph_out->kind() == "prim::TupleConstruct" ? true : false; } }; diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool3d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool3d.cpp index faff21131..0acb1287a 100644 --- a/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool3d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,12 +29,14 @@ public: return "nn.AdaptiveMaxPool3d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* adaptive_max_pool3d = find_node_by_kind(graph, "aten::adaptive_max_pool3d"); + const TorchNodeProxy* adaptive_max_pool3d = graph.find_node_by_kind("aten::adaptive_max_pool3d"); + + const TorchNodeProxy* graph_out = graph.find_producer_node_by_value(graph.output(0)); op->params["output_size"] = adaptive_max_pool3d->namedInput("output_size"); - op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false; + op->params["return_indices"] = graph_out->kind() == "prim::TupleConstruct" ? true : false; } }; diff --git a/tools/pnnx/src/pass_level1/nn_AlphaDropout.cpp b/tools/pnnx/src/pass_level1/nn_AlphaDropout.cpp index 7bc7d6eab..191842abf 100644 --- a/tools/pnnx/src/pass_level1/nn_AlphaDropout.cpp +++ b/tools/pnnx/src/pass_level1/nn_AlphaDropout.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_AvgPool1d.cpp b/tools/pnnx/src/pass_level1/nn_AvgPool1d.cpp index 99eb692af..0efb54d63 100644 --- a/tools/pnnx/src/pass_level1/nn_AvgPool1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_AvgPool1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.AvgPool1d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* avg_pool1d = find_node_by_kind(graph, "aten::avg_pool1d"); + const TorchNodeProxy* avg_pool1d = graph.find_node_by_kind("aten::avg_pool1d"); op->params["kernel_size"] = avg_pool1d->namedInput("kernel_size"); op->params["stride"] = avg_pool1d->namedInput("stride"); diff --git a/tools/pnnx/src/pass_level1/nn_AvgPool2d.cpp b/tools/pnnx/src/pass_level1/nn_AvgPool2d.cpp index bc75cee2f..594627045 100644 --- a/tools/pnnx/src/pass_level1/nn_AvgPool2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_AvgPool2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.AvgPool2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* avg_pool2d = find_node_by_kind(graph, "aten::avg_pool2d"); + const TorchNodeProxy* avg_pool2d = graph.find_node_by_kind("aten::avg_pool2d"); op->params["kernel_size"] = avg_pool2d->namedInput("kernel_size"); op->params["stride"] = avg_pool2d->namedInput("stride"); diff --git a/tools/pnnx/src/pass_level1/nn_AvgPool3d.cpp b/tools/pnnx/src/pass_level1/nn_AvgPool3d.cpp index 0b6c10fd6..49958ee49 100644 --- a/tools/pnnx/src/pass_level1/nn_AvgPool3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_AvgPool3d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.AvgPool3d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* avg_pool3d = find_node_by_kind(graph, "aten::avg_pool3d"); + const TorchNodeProxy* avg_pool3d = graph.find_node_by_kind("aten::avg_pool3d"); op->params["kernel_size"] = avg_pool3d->namedInput("kernel_size"); op->params["stride"] = avg_pool3d->namedInput("stride"); diff --git a/tools/pnnx/src/pass_level1/nn_BatchNorm1d.cpp b/tools/pnnx/src/pass_level1/nn_BatchNorm1d.cpp index afe10649e..1d70abd1a 100644 --- a/tools/pnnx/src/pass_level1/nn_BatchNorm1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_BatchNorm1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,12 +29,12 @@ public: return "nn.BatchNorm1d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { - const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm"); + const TorchNodeProxy* bn = graph.find_node_by_kind("aten::batch_norm"); - const auto& running_mean = mod.attr("running_mean").toTensor(); - const auto& running_var = mod.attr("running_var").toTensor(); + const TorchTensorProxy& running_mean = mod.attr("running_mean"); + const TorchTensorProxy& running_var = mod.attr("running_var"); op->params["num_features"] = running_mean.size(0); op->params["eps"] = bn->namedInput("eps"); @@ -46,8 +44,8 @@ public: op->attrs["running_var"] = running_var; if (mod.hasattr("weight") && mod.hasattr("bias")) { - op->attrs["weight"] = mod.attr("weight").toTensor(); - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["weight"] = mod.attr("weight"); + op->attrs["bias"] = mod.attr("bias"); } } }; diff --git a/tools/pnnx/src/pass_level1/nn_BatchNorm2d.cpp b/tools/pnnx/src/pass_level1/nn_BatchNorm2d.cpp index 7642c0331..b8b85a433 100644 --- a/tools/pnnx/src/pass_level1/nn_BatchNorm2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_BatchNorm2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,12 +29,12 @@ public: return "nn.BatchNorm2d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { - const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm"); + const TorchNodeProxy* bn = graph.find_node_by_kind("aten::batch_norm"); - const auto& running_mean = mod.attr("running_mean").toTensor(); - const auto& running_var = mod.attr("running_var").toTensor(); + const TorchTensorProxy& running_mean = mod.attr("running_mean"); + const TorchTensorProxy& running_var = mod.attr("running_var"); op->params["num_features"] = running_mean.size(0); op->params["eps"] = bn->namedInput("eps"); @@ -46,8 +44,8 @@ public: op->attrs["running_var"] = running_var; if (mod.hasattr("weight") && mod.hasattr("bias")) { - op->attrs["weight"] = mod.attr("weight").toTensor(); - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["weight"] = mod.attr("weight"); + op->attrs["bias"] = mod.attr("bias"); } } }; diff --git a/tools/pnnx/src/pass_level1/nn_BatchNorm3d.cpp b/tools/pnnx/src/pass_level1/nn_BatchNorm3d.cpp index 20a832fc3..93e47326a 100644 --- a/tools/pnnx/src/pass_level1/nn_BatchNorm3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_BatchNorm3d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,12 +29,12 @@ public: return "nn.BatchNorm3d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { - const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm"); + const TorchNodeProxy* bn = graph.find_node_by_kind("aten::batch_norm"); - const auto& running_mean = mod.attr("running_mean").toTensor(); - const auto& running_var = mod.attr("running_var").toTensor(); + const TorchTensorProxy& running_mean = mod.attr("running_mean"); + const TorchTensorProxy& running_var = mod.attr("running_var"); op->params["num_features"] = running_mean.size(0); op->params["eps"] = bn->namedInput("eps"); @@ -46,8 +44,8 @@ public: op->attrs["running_var"] = running_var; if (mod.hasattr("weight") && mod.hasattr("bias")) { - op->attrs["weight"] = mod.attr("weight").toTensor(); - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["weight"] = mod.attr("weight"); + op->attrs["bias"] = mod.attr("bias"); } } }; diff --git a/tools/pnnx/src/pass_level1/nn_CELU.cpp b/tools/pnnx/src/pass_level1/nn_CELU.cpp index dc50b92f5..1c9292ba2 100644 --- a/tools/pnnx/src/pass_level1/nn_CELU.cpp +++ b/tools/pnnx/src/pass_level1/nn_CELU.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.CELU"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* celu = find_node_by_kind(graph, "aten::celu"); + const TorchNodeProxy* celu = graph.find_node_by_kind("aten::celu"); op->params["alpha"] = celu->namedInput("alpha"); } diff --git a/tools/pnnx/src/pass_level1/nn_ChannelShuffle.cpp b/tools/pnnx/src/pass_level1/nn_ChannelShuffle.cpp index 84ecf8410..f46548397 100644 --- a/tools/pnnx/src/pass_level1/nn_ChannelShuffle.cpp +++ b/tools/pnnx/src/pass_level1/nn_ChannelShuffle.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.ChannelShuffle"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* channel_shuffle = find_node_by_kind(graph, "aten::channel_shuffle"); + const TorchNodeProxy* channel_shuffle = graph.find_node_by_kind("aten::channel_shuffle"); op->params["groups"] = channel_shuffle->namedInput("groups"); } diff --git a/tools/pnnx/src/pass_level1/nn_ConstantPad1d.cpp b/tools/pnnx/src/pass_level1/nn_ConstantPad1d.cpp index 4698fb21a..c0130b944 100644 --- a/tools/pnnx/src/pass_level1/nn_ConstantPad1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConstantPad1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.ConstantPad1d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd"); if (!pad) { diff --git a/tools/pnnx/src/pass_level1/nn_ConstantPad2d.cpp b/tools/pnnx/src/pass_level1/nn_ConstantPad2d.cpp index 21f9a34a4..2fd748692 100644 --- a/tools/pnnx/src/pass_level1/nn_ConstantPad2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConstantPad2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.ConstantPad2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd"); if (!pad) { diff --git a/tools/pnnx/src/pass_level1/nn_ConstantPad3d.cpp b/tools/pnnx/src/pass_level1/nn_ConstantPad3d.cpp index 5b0acbee0..573839c38 100644 --- a/tools/pnnx/src/pass_level1/nn_ConstantPad3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConstantPad3d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.ConstantPad3d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd"); if (!pad) { diff --git a/tools/pnnx/src/pass_level1/nn_Conv1d.cpp b/tools/pnnx/src/pass_level1/nn_Conv1d.cpp index a74db848b..df3727f17 100644 --- a/tools/pnnx/src/pass_level1/nn_Conv1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_Conv1d.cpp @@ -12,11 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -// #include "../pass_level3/fuse_expression.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -33,7 +29,7 @@ public: return "nn.Conv1d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // { // pnnx::Graph pnnx_graph; @@ -45,18 +41,18 @@ public: // pnnx_graph.save("tmp.param", "tmp.bin"); // } - const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); - const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode"); - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* reflection_pad1d = find_node_by_kind(graph, "aten::reflection_pad1d"); - const torch::jit::Node* replication_pad1d = find_node_by_kind(graph, "aten::replication_pad1d"); + const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); + const TorchNodeProxy* convolution_mode = graph.find_node_by_kind("aten::_convolution_mode"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* reflection_pad1d = graph.find_node_by_kind("aten::reflection_pad1d"); + const TorchNodeProxy* replication_pad1d = graph.find_node_by_kind("aten::replication_pad1d"); if (convolution_mode) { convolution = convolution_mode; } - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["groups"] = convolution->namedInput("groups"); op->params["in_channels"] = weight.size(1) * op->params["groups"].i; @@ -131,7 +127,7 @@ public: op->attrs["weight"] = weight; if (mod.hasattr("bias")) { - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } } }; diff --git a/tools/pnnx/src/pass_level1/nn_Conv2d.cpp b/tools/pnnx/src/pass_level1/nn_Conv2d.cpp index 26b5f86f1..4a9f968e9 100644 --- a/tools/pnnx/src/pass_level1/nn_Conv2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_Conv2d.cpp @@ -12,11 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -// #include "../pass_level3/fuse_expression.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -33,7 +29,7 @@ public: return "nn.Conv2d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // { // pnnx::Graph pnnx_graph; @@ -45,18 +41,18 @@ public: // pnnx_graph.save("tmp.param", "tmp.bin"); // } - const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); - const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode"); - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* reflection_pad2d = find_node_by_kind(graph, "aten::reflection_pad2d"); - const torch::jit::Node* replication_pad2d = find_node_by_kind(graph, "aten::replication_pad2d"); + const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); + const TorchNodeProxy* convolution_mode = graph.find_node_by_kind("aten::_convolution_mode"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* reflection_pad2d = graph.find_node_by_kind("aten::reflection_pad2d"); + const TorchNodeProxy* replication_pad2d = graph.find_node_by_kind("aten::replication_pad2d"); if (convolution_mode) { convolution = convolution_mode; } - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["groups"] = convolution->namedInput("groups"); op->params["in_channels"] = weight.size(1) * op->params["groups"].i; @@ -126,12 +122,12 @@ public: op->params["padding"] = convolution->namedInput("padding"); } op->params["dilation"] = convolution->namedInput("dilation"); - op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor(); + op->params["bias"] = mod.hasattr("bias"); op->attrs["weight"] = weight; - if (mod.hasattr("bias") && mod.attr("bias").isTensor()) + if (mod.hasattr("bias")) { - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } } }; diff --git a/tools/pnnx/src/pass_level1/nn_Conv3d.cpp b/tools/pnnx/src/pass_level1/nn_Conv3d.cpp index 271c024d9..8e4472433 100644 --- a/tools/pnnx/src/pass_level1/nn_Conv3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_Conv3d.cpp @@ -12,11 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -// #include "../pass_level3/fuse_expression.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -33,7 +29,7 @@ public: return "nn.Conv3d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // { // pnnx::Graph pnnx_graph; @@ -45,18 +41,18 @@ public: // pnnx_graph.save("tmp.param", "tmp.bin"); // } - const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); - const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode"); - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* reflection_pad3d = find_node_by_kind(graph, "aten::reflection_pad3d"); - const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d"); + const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); + const TorchNodeProxy* convolution_mode = graph.find_node_by_kind("aten::_convolution_mode"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* reflection_pad3d = graph.find_node_by_kind("aten::reflection_pad3d"); + const TorchNodeProxy* replication_pad3d = graph.find_node_by_kind("aten::replication_pad3d"); if (convolution_mode) { convolution = convolution_mode; } - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["groups"] = convolution->namedInput("groups"); op->params["in_channels"] = weight.size(1) * op->params["groups"].i; @@ -131,7 +127,7 @@ public: op->attrs["weight"] = weight; if (mod.hasattr("bias")) { - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } } }; diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp index 6677b832e..35a06bb86 100644 --- a/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,11 +29,11 @@ public: return "nn.ConvTranspose1d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { - const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); + const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["groups"] = convolution->namedInput("groups"); op->params["in_channels"] = weight.size(0); @@ -50,7 +48,7 @@ public: op->attrs["weight"] = weight; if (mod.hasattr("bias")) { - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } if (op->inputs.size() > 1) diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp index a481cb154..c47d19710 100644 --- a/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,11 +29,11 @@ public: return "nn.ConvTranspose2d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { - const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); + const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["groups"] = convolution->namedInput("groups"); op->params["in_channels"] = weight.size(0); @@ -50,7 +48,7 @@ public: op->attrs["weight"] = weight; if (mod.hasattr("bias")) { - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } if (op->inputs.size() > 1) diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp index b8fd141de..d762bbd89 100644 --- a/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,11 +29,11 @@ public: return "nn.ConvTranspose3d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { - const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); + const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["groups"] = convolution->namedInput("groups"); op->params["in_channels"] = weight.size(0); @@ -50,7 +48,7 @@ public: op->attrs["weight"] = weight; if (mod.hasattr("bias")) { - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } if (op->inputs.size() > 1) diff --git a/tools/pnnx/src/pass_level1/nn_Dropout.cpp b/tools/pnnx/src/pass_level1/nn_Dropout.cpp index 1c18dd453..4302b8677 100644 --- a/tools/pnnx/src/pass_level1/nn_Dropout.cpp +++ b/tools/pnnx/src/pass_level1/nn_Dropout.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_Dropout2d.cpp b/tools/pnnx/src/pass_level1/nn_Dropout2d.cpp index 4bbaa4c05..c3c470871 100644 --- a/tools/pnnx/src/pass_level1/nn_Dropout2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_Dropout2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_Dropout3d.cpp b/tools/pnnx/src/pass_level1/nn_Dropout3d.cpp index 06b4d7d91..44793afa3 100644 --- a/tools/pnnx/src/pass_level1/nn_Dropout3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_Dropout3d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_ELU.cpp b/tools/pnnx/src/pass_level1/nn_ELU.cpp index a5b309ee3..1a77339e2 100644 --- a/tools/pnnx/src/pass_level1/nn_ELU.cpp +++ b/tools/pnnx/src/pass_level1/nn_ELU.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.ELU"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* elu = find_node_by_kind(graph, "aten::elu"); + const TorchNodeProxy* elu = graph.find_node_by_kind("aten::elu"); op->params["alpha"] = elu->namedInput("alpha"); } diff --git a/tools/pnnx/src/pass_level1/nn_Embedding.cpp b/tools/pnnx/src/pass_level1/nn_Embedding.cpp index 8e1c76bef..2c6cd24ee 100644 --- a/tools/pnnx/src/pass_level1/nn_Embedding.cpp +++ b/tools/pnnx/src/pass_level1/nn_Embedding.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,11 +29,11 @@ public: return "nn.Embedding"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { - const torch::jit::Node* embedding = find_node_by_kind(graph, "aten::embedding"); + const TorchNodeProxy* embedding = graph.find_node_by_kind("aten::embedding"); - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["num_embeddings"] = weight.size(0); op->params["embedding_dim"] = weight.size(1); diff --git a/tools/pnnx/src/pass_level1/nn_Fold.cpp b/tools/pnnx/src/pass_level1/nn_Fold.cpp index 045c1f6f1..ab1cf85ce 100644 --- a/tools/pnnx/src/pass_level1/nn_Fold.cpp +++ b/tools/pnnx/src/pass_level1/nn_Fold.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.Fold"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* col2im = find_node_by_kind(graph, "aten::col2im"); + const TorchNodeProxy* col2im = graph.find_node_by_kind("aten::col2im"); op->params["output_size"] = col2im->namedInput("output_size"); op->params["kernel_size"] = col2im->namedInput("kernel_size"); diff --git a/tools/pnnx/src/pass_level1/nn_GELU.cpp b/tools/pnnx/src/pass_level1/nn_GELU.cpp index 8aae7d061..789863d50 100644 --- a/tools/pnnx/src/pass_level1/nn_GELU.cpp +++ b/tools/pnnx/src/pass_level1/nn_GELU.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.GELU"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* gelu = find_node_by_kind(graph, "aten::gelu"); + const TorchNodeProxy* gelu = graph.find_node_by_kind("aten::gelu"); if (gelu->hasNamedInput("approximate")) { diff --git a/tools/pnnx/src/pass_level1/nn_GLU.cpp b/tools/pnnx/src/pass_level1/nn_GLU.cpp index 72af2f3f0..3739f1621 100644 --- a/tools/pnnx/src/pass_level1/nn_GLU.cpp +++ b/tools/pnnx/src/pass_level1/nn_GLU.cpp @@ -13,9 +13,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -32,9 +30,9 @@ public: return "nn.GLU"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* glu = find_node_by_kind(graph, "aten::glu"); + const TorchNodeProxy* glu = graph.find_node_by_kind("aten::glu"); op->params["dim"] = glu->namedInput("dim"); } diff --git a/tools/pnnx/src/pass_level1/nn_GRU.cpp b/tools/pnnx/src/pass_level1/nn_GRU.cpp index 8ed972c84..6ac4f213a 100644 --- a/tools/pnnx/src/pass_level1/nn_GRU.cpp +++ b/tools/pnnx/src/pass_level1/nn_GRU.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,17 +29,17 @@ public: return "nn.GRU"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // mod.dump(true, true, true); // graph->dump(); - const torch::jit::Node* gru = find_node_by_kind(graph, "aten::gru"); + const TorchNodeProxy* gru = graph.find_node_by_kind("aten::gru"); - const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); - if (return_tuple && return_tuple->inputs().size() == 2 && gru->outputs().size() == 2 - && return_tuple->inputs()[0] == gru->outputs()[1] && return_tuple->inputs()[1] == gru->outputs()[0]) + const TorchNodeProxy* return_tuple = graph.find_node_by_kind("prim::TupleConstruct"); + if (return_tuple && return_tuple->input_count() == 2 && gru->output_count() == 2 + && return_tuple->input(0) == gru->output(1) && return_tuple->input(1) == gru->output(0)) { // mark the swapped output tuple // we would restore the fine order in pass_level3/fuse_rnn_unpack @@ -54,7 +52,7 @@ public: // fprintf(stderr, "arg %s\n", aa.name().c_str()); // } - const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); + const TorchTensorProxy& weight_ih_l0 = mod.attr("weight_ih_l0"); op->params["input_size"] = weight_ih_l0.size(1); op->params["hidden_size"] = weight_ih_l0.size(0) / 3; @@ -72,16 +70,16 @@ public: std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); - op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); - op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); + op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key); + op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key); if (bias) { std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); - op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); - op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); + op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key); + op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key); } if (bidirectional) @@ -89,16 +87,16 @@ public: std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; - op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); - op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); + op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key); + op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key); if (bias) { std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; - op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); - op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); + op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key); + op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key); } } } diff --git a/tools/pnnx/src/pass_level1/nn_GroupNorm.cpp b/tools/pnnx/src/pass_level1/nn_GroupNorm.cpp index 3947aead6..063f729f4 100644 --- a/tools/pnnx/src/pass_level1/nn_GroupNorm.cpp +++ b/tools/pnnx/src/pass_level1/nn_GroupNorm.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,11 +29,11 @@ public: return "nn.GroupNorm"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // graph->dump(); - const torch::jit::Node* gn = find_node_by_kind(graph, "aten::group_norm"); + const TorchNodeProxy* gn = graph.find_node_by_kind("aten::group_norm"); // for (auto aa : gn->schema().arguments()) // { @@ -48,12 +46,12 @@ public: if (mod.hasattr("weight") && mod.hasattr("bias")) { - const auto& weight = mod.attr("weight").toTensor(); + const auto& weight = mod.attr("weight"); op->params["num_channels"] = weight.size(0); op->attrs["weight"] = weight; - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } else { diff --git a/tools/pnnx/src/pass_level1/nn_Hardshrink.cpp b/tools/pnnx/src/pass_level1/nn_Hardshrink.cpp index ee230a9fa..0ef45be02 100644 --- a/tools/pnnx/src/pass_level1/nn_Hardshrink.cpp +++ b/tools/pnnx/src/pass_level1/nn_Hardshrink.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.Hardshrink"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* hardshrink = find_node_by_kind(graph, "aten::hardshrink"); + const TorchNodeProxy* hardshrink = graph.find_node_by_kind("aten::hardshrink"); op->params["lambd"] = hardshrink->namedInput("lambd"); } diff --git a/tools/pnnx/src/pass_level1/nn_Hardsigmoid.cpp b/tools/pnnx/src/pass_level1/nn_Hardsigmoid.cpp index ba6fe78e9..c4827495f 100644 --- a/tools/pnnx/src/pass_level1/nn_Hardsigmoid.cpp +++ b/tools/pnnx/src/pass_level1/nn_Hardsigmoid.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_Hardswish.cpp b/tools/pnnx/src/pass_level1/nn_Hardswish.cpp index 1061832a1..940ff2cfb 100644 --- a/tools/pnnx/src/pass_level1/nn_Hardswish.cpp +++ b/tools/pnnx/src/pass_level1/nn_Hardswish.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_Hardtanh.cpp b/tools/pnnx/src/pass_level1/nn_Hardtanh.cpp index 1c4ee37ab..fac05a575 100644 --- a/tools/pnnx/src/pass_level1/nn_Hardtanh.cpp +++ b/tools/pnnx/src/pass_level1/nn_Hardtanh.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.Hardtanh"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* hardtanh = find_node_by_kind(graph, "aten::hardtanh"); + const TorchNodeProxy* hardtanh = graph.find_node_by_kind("aten::hardtanh"); op->params["min_val"] = hardtanh->namedInput("min_val"); op->params["max_val"] = hardtanh->namedInput("max_val"); diff --git a/tools/pnnx/src/pass_level1/nn_InstanceNorm1d.cpp b/tools/pnnx/src/pass_level1/nn_InstanceNorm1d.cpp index a6c371bad..9dfabf5e5 100644 --- a/tools/pnnx/src/pass_level1/nn_InstanceNorm1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_InstanceNorm1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,11 +29,11 @@ public: return "nn.InstanceNorm1d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // graph->dump(); - const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm"); + const TorchNodeProxy* in = graph.find_node_by_kind("aten::instance_norm"); // for (auto aa : in->schema().arguments()) // { @@ -48,22 +46,22 @@ public: if (mod.hasattr("weight") && mod.hasattr("bias")) { - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["num_features"] = weight.size(0); op->attrs["weight"] = weight; - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } if (mod.hasattr("running_mean") && mod.hasattr("running_var")) { - const auto& running_mean = mod.attr("running_mean").toTensor(); + const TorchTensorProxy& running_mean = mod.attr("running_mean"); op->params["num_features"] = running_mean.size(0); op->attrs["running_mean"] = running_mean; - op->attrs["running_var"] = mod.attr("running_var").toTensor(); + op->attrs["running_var"] = mod.attr("running_var"); } // take num_features from input shape diff --git a/tools/pnnx/src/pass_level1/nn_InstanceNorm2d.cpp b/tools/pnnx/src/pass_level1/nn_InstanceNorm2d.cpp index 3eb86b9df..163825031 100644 --- a/tools/pnnx/src/pass_level1/nn_InstanceNorm2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_InstanceNorm2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,11 +29,11 @@ public: return "nn.InstanceNorm2d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // graph->dump(); - const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm"); + const TorchNodeProxy* in = graph.find_node_by_kind("aten::instance_norm"); // for (auto aa : in->schema().arguments()) // { @@ -48,22 +46,22 @@ public: if (mod.hasattr("weight") && mod.hasattr("bias")) { - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["num_features"] = weight.size(0); op->attrs["weight"] = weight; - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } if (mod.hasattr("running_mean") && mod.hasattr("running_var")) { - const auto& running_mean = mod.attr("running_mean").toTensor(); + const TorchTensorProxy& running_mean = mod.attr("running_mean"); op->params["num_features"] = running_mean.size(0); op->attrs["running_mean"] = running_mean; - op->attrs["running_var"] = mod.attr("running_var").toTensor(); + op->attrs["running_var"] = mod.attr("running_var"); } // take num_features from input shape diff --git a/tools/pnnx/src/pass_level1/nn_InstanceNorm3d.cpp b/tools/pnnx/src/pass_level1/nn_InstanceNorm3d.cpp index 8881db311..ae871d36b 100644 --- a/tools/pnnx/src/pass_level1/nn_InstanceNorm3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_InstanceNorm3d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,11 +29,11 @@ public: return "nn.InstanceNorm3d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // graph->dump(); - const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm"); + const TorchNodeProxy* in = graph.find_node_by_kind("aten::instance_norm"); // for (auto aa : in->schema().arguments()) // { @@ -48,22 +46,22 @@ public: if (mod.hasattr("weight") && mod.hasattr("bias")) { - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["num_features"] = weight.size(0); op->attrs["weight"] = weight; - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } if (mod.hasattr("running_mean") && mod.hasattr("running_var")) { - const auto& running_mean = mod.attr("running_mean").toTensor(); + const TorchTensorProxy& running_mean = mod.attr("running_mean"); op->params["num_features"] = running_mean.size(0); op->attrs["running_mean"] = running_mean; - op->attrs["running_var"] = mod.attr("running_var").toTensor(); + op->attrs["running_var"] = mod.attr("running_var"); } // take num_features from input shape diff --git a/tools/pnnx/src/pass_level1/nn_LPPool1d.cpp b/tools/pnnx/src/pass_level1/nn_LPPool1d.cpp index f7b737576..0993bedb6 100644 --- a/tools/pnnx/src/pass_level1/nn_LPPool1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_LPPool1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,21 +29,24 @@ public: return "nn.LPPool1d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow"); - op->params["norm_type"] = pow->inputs()[1]; + const TorchNodeProxy* pow = graph.find_node_by_kind("aten::pow"); + op->params["norm_type"] = pow->input(1); + + const TorchNodeProxy* avg_pool1d = graph.find_node_by_kind("aten::avg_pool1d"); - const torch::jit::Node* avg_pool1d = find_node_by_kind(graph, "aten::avg_pool1d"); + const TorchNodeProxy* kernel_size = graph.find_producer_node_by_value(avg_pool1d->namedInput("kernel_size")); + const TorchNodeProxy* stride = graph.find_producer_node_by_value(avg_pool1d->namedInput("stride")); - op->params["kernel_size"] = avg_pool1d->namedInput("kernel_size")->node()->inputs()[0]; - if (avg_pool1d->namedInput("stride")->node()->inputs().size() == 0) + op->params["kernel_size"] = kernel_size->input(0); + if (stride->input_count() == 0) { op->params["stride"] = op->params["kernel_size"]; } else { - op->params["stride"] = avg_pool1d->namedInput("stride")->node()->inputs()[0]; + op->params["stride"] = stride->input(0); } op->params["ceil_mode"] = avg_pool1d->namedInput("ceil_mode"); } diff --git a/tools/pnnx/src/pass_level1/nn_LPPool2d.cpp b/tools/pnnx/src/pass_level1/nn_LPPool2d.cpp index d843704f7..702e4c875 100644 --- a/tools/pnnx/src/pass_level1/nn_LPPool2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_LPPool2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,15 +29,17 @@ public: return "nn.LPPool2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow"); - op->params["norm_type"] = pow->inputs()[1]; + const TorchNodeProxy* pow = graph.find_node_by_kind("aten::pow"); + op->params["norm_type"] = pow->input(1); + + const TorchNodeProxy* avg_pool2d = graph.find_node_by_kind("aten::avg_pool2d"); - const torch::jit::Node* avg_pool2d = find_node_by_kind(graph, "aten::avg_pool2d"); + const TorchNodeProxy* stride = graph.find_producer_node_by_value(avg_pool2d->namedInput("stride")); op->params["kernel_size"] = avg_pool2d->namedInput("kernel_size"); - if (avg_pool2d->namedInput("stride")->node()->inputs().size() == 0) + if (stride->input_count() == 0) { op->params["stride"] = op->params["kernel_size"]; } diff --git a/tools/pnnx/src/pass_level1/nn_LSTM.cpp b/tools/pnnx/src/pass_level1/nn_LSTM.cpp index a2354dfad..44b24b145 100644 --- a/tools/pnnx/src/pass_level1/nn_LSTM.cpp +++ b/tools/pnnx/src/pass_level1/nn_LSTM.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,17 +29,17 @@ public: return "nn.LSTM"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // mod.dump(true, true, true); // // graph->dump(); - const torch::jit::Node* lstm = find_node_by_kind(graph, "aten::lstm"); + const TorchNodeProxy* lstm = graph.find_node_by_kind("aten::lstm"); - const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); - if (return_tuple && return_tuple->inputs().size() == 3 && lstm->outputs().size() == 3 - && return_tuple->inputs()[0] == lstm->outputs()[1] && return_tuple->inputs()[1] == lstm->outputs()[2] && return_tuple->inputs()[2] == lstm->outputs()[0]) + const TorchNodeProxy* return_tuple = graph.find_node_by_kind("prim::TupleConstruct"); + if (return_tuple && return_tuple->input_count() == 3 && lstm->output_count() == 3 + && return_tuple->input(0) == lstm->output(1) && return_tuple->input(1) == lstm->output(2) && return_tuple->input(2) == lstm->output(0)) { // mark the swapped output tuple // we would restore the fine order in pass_level3/fuse_rnn_unpack @@ -54,8 +52,8 @@ public: // fprintf(stderr, "arg %s\n", aa.name().c_str()); // } - const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); - const auto& weight_hh_l0 = mod.attr("weight_hh_l0").toTensor(); + const TorchTensorProxy& weight_ih_l0 = mod.attr("weight_ih_l0"); + const TorchTensorProxy& weight_hh_l0 = mod.attr("weight_hh_l0"); op->params["input_size"] = weight_ih_l0.size(1); op->params["hidden_size"] = weight_ih_l0.size(0) / 4; @@ -75,23 +73,23 @@ public: std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); - op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); - op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); + op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key); + op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key); if (bias) { std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); - op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); - op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); + op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key); + op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key); } if (proj_size > 0) { std::string weight_hr_lk_key = std::string("weight_hr_l") + std::to_string(k); - op->attrs[weight_hr_lk_key] = mod.attr(weight_hr_lk_key).toTensor(); + op->attrs[weight_hr_lk_key] = mod.attr(weight_hr_lk_key); } if (bidirectional) @@ -99,23 +97,23 @@ public: std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; - op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); - op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); + op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key); + op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key); if (bias) { std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; - op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); - op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); + op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key); + op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key); } if (proj_size > 0) { std::string weight_hr_lk_reverse_key = std::string("weight_hr_l") + std::to_string(k) + "_reverse"; - op->attrs[weight_hr_lk_reverse_key] = mod.attr(weight_hr_lk_reverse_key).toTensor(); + op->attrs[weight_hr_lk_reverse_key] = mod.attr(weight_hr_lk_reverse_key); } } } diff --git a/tools/pnnx/src/pass_level1/nn_LayerNorm.cpp b/tools/pnnx/src/pass_level1/nn_LayerNorm.cpp index 5faa8d8c0..51ce1c73c 100644 --- a/tools/pnnx/src/pass_level1/nn_LayerNorm.cpp +++ b/tools/pnnx/src/pass_level1/nn_LayerNorm.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.LayerNorm"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { - const torch::jit::Node* ln = find_node_by_kind(graph, "aten::layer_norm"); + const TorchNodeProxy* ln = graph.find_node_by_kind("aten::layer_norm"); op->params["normalized_shape"] = ln->namedInput("normalized_shape"); op->params["eps"] = ln->namedInput("eps"); @@ -41,8 +39,8 @@ public: if (mod.hasattr("weight") && mod.hasattr("bias")) { - op->attrs["weight"] = mod.attr("weight").toTensor(); - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["weight"] = mod.attr("weight"); + op->attrs["bias"] = mod.attr("bias"); } } }; diff --git a/tools/pnnx/src/pass_level1/nn_LeakyReLU.cpp b/tools/pnnx/src/pass_level1/nn_LeakyReLU.cpp index 689f1c665..161299a09 100644 --- a/tools/pnnx/src/pass_level1/nn_LeakyReLU.cpp +++ b/tools/pnnx/src/pass_level1/nn_LeakyReLU.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.LeakyReLU"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* leaky_relu = find_node_by_kind(graph, "aten::leaky_relu"); - const torch::jit::Node* leaky_relu_ = find_node_by_kind(graph, "aten::leaky_relu_"); + const TorchNodeProxy* leaky_relu = graph.find_node_by_kind("aten::leaky_relu"); + const TorchNodeProxy* leaky_relu_ = graph.find_node_by_kind("aten::leaky_relu_"); if (leaky_relu_) { diff --git a/tools/pnnx/src/pass_level1/nn_Linear.cpp b/tools/pnnx/src/pass_level1/nn_Linear.cpp index edcb9dc56..ce4c9f5cf 100644 --- a/tools/pnnx/src/pass_level1/nn_Linear.cpp +++ b/tools/pnnx/src/pass_level1/nn_Linear.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,20 +29,20 @@ public: return "nn.Linear"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& /*graph*/, const TorchModuleProxy& mod) const { - const torch::jit::Node* addmm = find_node_by_kind(graph, "aten::addmm"); + // const TorchNodeProxy* addmm = graph.find_node_by_kind("aten::addmm"); - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["in_features"] = weight.size(1); op->params["out_features"] = weight.size(0); - op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor(); + op->params["bias"] = mod.hasattr("bias"); op->attrs["weight"] = weight; - if (mod.hasattr("bias") && mod.attr("bias").isTensor()) + if (mod.hasattr("bias")) { - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } } }; diff --git a/tools/pnnx/src/pass_level1/nn_LocalResponseNorm.cpp b/tools/pnnx/src/pass_level1/nn_LocalResponseNorm.cpp index 1f88f78f0..e119cce84 100644 --- a/tools/pnnx/src/pass_level1/nn_LocalResponseNorm.cpp +++ b/tools/pnnx/src/pass_level1/nn_LocalResponseNorm.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,26 +29,27 @@ public: return "nn.LocalResponseNorm"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* avg_pool = find_node_by_kind(graph, "aten::avg_pool2d"); - const torch::jit::Node* avg_pool3d = find_node_by_kind(graph, "aten::avg_pool3d"); + const TorchNodeProxy* avg_pool = graph.find_node_by_kind("aten::avg_pool2d"); + const TorchNodeProxy* avg_pool3d = graph.find_node_by_kind("aten::avg_pool3d"); if (avg_pool3d) { avg_pool = avg_pool3d; } - op->params["size"] = avg_pool->namedInput("kernel_size")->node()->inputs()[0]; + const TorchNodeProxy* kernel_size = graph.find_producer_node_by_value(avg_pool->namedInput("kernel_size")); + op->params["size"] = kernel_size->input(0); - const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow"); - op->params["beta"] = pow->inputs()[1]; + const TorchNodeProxy* pow = graph.find_node_by_kind("aten::pow"); + op->params["beta"] = pow->input(1); - const torch::jit::Node* add = pow->inputs()[0]->node(); - op->params["k"] = add->inputs()[1]; + const TorchNodeProxy* add = graph.find_producer_node_by_value(pow->input(0)); + op->params["k"] = add->input(1); - const torch::jit::Node* mul = add->inputs()[0]->node(); - op->params["alpha"] = mul->inputs()[1]; + const TorchNodeProxy* mul = graph.find_producer_node_by_value(add->input(0)); + op->params["alpha"] = mul->input(1); } }; diff --git a/tools/pnnx/src/pass_level1/nn_LogSigmoid.cpp b/tools/pnnx/src/pass_level1/nn_LogSigmoid.cpp index d0a164604..949bd3579 100644 --- a/tools/pnnx/src/pass_level1/nn_LogSigmoid.cpp +++ b/tools/pnnx/src/pass_level1/nn_LogSigmoid.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_LogSoftmax.cpp b/tools/pnnx/src/pass_level1/nn_LogSoftmax.cpp index e5ecd673c..de696d1d5 100644 --- a/tools/pnnx/src/pass_level1/nn_LogSoftmax.cpp +++ b/tools/pnnx/src/pass_level1/nn_LogSoftmax.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.LogSoftmax"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* log_softmax = find_node_by_kind(graph, "aten::log_softmax"); + const TorchNodeProxy* log_softmax = graph.find_node_by_kind("aten::log_softmax"); op->params["dim"] = log_softmax->namedInput("dim"); } diff --git a/tools/pnnx/src/pass_level1/nn_MaxPool1d.cpp b/tools/pnnx/src/pass_level1/nn_MaxPool1d.cpp index 5f42ca5b7..18b619d52 100644 --- a/tools/pnnx/src/pass_level1/nn_MaxPool1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_MaxPool1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.MaxPool1d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* max_pool1d = find_node_by_kind(graph, "aten::max_pool1d"); - const torch::jit::Node* max_pool1d_with_indices = find_node_by_kind(graph, "aten::max_pool1d_with_indices"); + const TorchNodeProxy* max_pool1d = graph.find_node_by_kind("aten::max_pool1d"); + const TorchNodeProxy* max_pool1d_with_indices = graph.find_node_by_kind("aten::max_pool1d_with_indices"); if (max_pool1d_with_indices) { diff --git a/tools/pnnx/src/pass_level1/nn_MaxPool2d.cpp b/tools/pnnx/src/pass_level1/nn_MaxPool2d.cpp index 8a8064415..5cdc45fc9 100644 --- a/tools/pnnx/src/pass_level1/nn_MaxPool2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_MaxPool2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.MaxPool2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* max_pool2d = find_node_by_kind(graph, "aten::max_pool2d"); - const torch::jit::Node* max_pool2d_with_indices = find_node_by_kind(graph, "aten::max_pool2d_with_indices"); + const TorchNodeProxy* max_pool2d = graph.find_node_by_kind("aten::max_pool2d"); + const TorchNodeProxy* max_pool2d_with_indices = graph.find_node_by_kind("aten::max_pool2d_with_indices"); if (max_pool2d_with_indices) { diff --git a/tools/pnnx/src/pass_level1/nn_MaxPool3d.cpp b/tools/pnnx/src/pass_level1/nn_MaxPool3d.cpp index e53e4a7fd..46e845c76 100644 --- a/tools/pnnx/src/pass_level1/nn_MaxPool3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_MaxPool3d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.MaxPool3d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* max_pool3d = find_node_by_kind(graph, "aten::max_pool3d"); - const torch::jit::Node* max_pool3d_with_indices = find_node_by_kind(graph, "aten::max_pool3d_with_indices"); + const TorchNodeProxy* max_pool3d = graph.find_node_by_kind("aten::max_pool3d"); + const TorchNodeProxy* max_pool3d_with_indices = graph.find_node_by_kind("aten::max_pool3d_with_indices"); if (max_pool3d_with_indices) { diff --git a/tools/pnnx/src/pass_level1/nn_Mish.cpp b/tools/pnnx/src/pass_level1/nn_Mish.cpp index c65fb6a89..837bbc258 100644 --- a/tools/pnnx/src/pass_level1/nn_Mish.cpp +++ b/tools/pnnx/src/pass_level1/nn_Mish.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_MultiheadAttention.cpp b/tools/pnnx/src/pass_level1/nn_MultiheadAttention.cpp index f894638d3..6a0ca41e0 100644 --- a/tools/pnnx/src/pass_level1/nn_MultiheadAttention.cpp +++ b/tools/pnnx/src/pass_level1/nn_MultiheadAttention.cpp @@ -12,11 +12,13 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include +// #include "pass_level1.h" +// +// #include +// +// #include "../utils.h" -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -33,19 +35,19 @@ public: return "nn.MultiheadAttention"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // mod.dump(false, false, false); // graph->dump(); - const torch::jit::Node* multi_head_attention = find_node_by_kind(graph, "aten::_native_multi_head_attention"); + const TorchNodeProxy* multi_head_attention = graph.find_node_by_kind("aten::_native_multi_head_attention"); if (multi_head_attention) { op->params["num_heads"] = multi_head_attention->namedInput("num_head"); op->params["batch_first"] = true; op->params["add_zero_attn"] = false; - if (multi_head_attention->hasNamedInput("mask") && multi_head_attention->namedInput("mask") == graph->inputs()[graph->inputs().size() - 1]) + if (multi_head_attention->hasNamedInput("mask") && multi_head_attention->namedInput("mask") == graph.input(graph.input_count() - 1)) { size_t input_count = op->inputs.size(); op->inputnames.resize(input_count); @@ -54,25 +56,28 @@ public: } else { - const torch::jit::Node* div_num_heads = find_node_by_kind(graph, "aten::div"); - const torch::jit::Node* div_num_heads_18 = find_node_by_kind(graph, "aten::floor_divide"); + const TorchNodeProxy* div_num_heads = graph.find_node_by_kind("aten::div"); + const TorchNodeProxy* div_num_heads_18 = graph.find_node_by_kind("aten::floor_divide"); if (div_num_heads_18) { div_num_heads = div_num_heads_18; } - op->params["num_heads"] = (int)div_num_heads->input(1)->node()->t(torch::jit::attr::value).item(); + // const TorchNodeProxy* div_num_heads_input_1 = graph.find_producer_node_by_value(div_num_heads->input(1)); + + // op->params["num_heads"] = (int)div_num_heads_input_1->t(torch::jit::attr::value).item(); + op->params["num_heads"] = div_num_heads->input(1); - const torch::jit::Node* transpose_batch_seq = find_node_by_kind(graph, "aten::transpose"); + const TorchNodeProxy* transpose_batch_seq = graph.find_node_by_kind("aten::transpose"); - int transpose_dim0 = transpose_batch_seq->input(1)->node()->i(torch::jit::attr::value); - int transpose_dim1 = transpose_batch_seq->input(2)->node()->i(torch::jit::attr::value); - if (transpose_dim0 == 1 && transpose_dim1 == 0) + Parameter transpose_dim0 = transpose_batch_seq->input(1); + Parameter transpose_dim1 = transpose_batch_seq->input(2); + if (transpose_dim0.i == 1 && transpose_dim1.i == 0) { op->params["batch_first"] = true; } - const torch::jit::Node* add_zero_attn = find_node_by_kind(graph, "aten::zeros"); + const TorchNodeProxy* add_zero_attn = graph.find_node_by_kind("aten::zeros"); if (add_zero_attn) { op->params["add_zero_attn"] = true; @@ -82,10 +87,10 @@ public: op->params["add_zero_attn"] = false; } - const torch::jit::Node* scaled_dot_product_attention = find_node_by_kind(graph, "aten::scaled_dot_product_attention"); + const TorchNodeProxy* scaled_dot_product_attention = graph.find_node_by_kind("aten::scaled_dot_product_attention"); if (scaled_dot_product_attention) { - if (scaled_dot_product_attention->input(3)->type()->kind() != c10::TypeKind::NoneType) + if (!scaled_dot_product_attention->is_input_none(3)) { size_t input_count = op->inputs.size(); op->inputnames.resize(input_count); @@ -94,7 +99,7 @@ public: } // find attention mask addition pattern pre torch-2.1 - const torch::jit::Node* has_attn_mask = find_node_by_kind(graph, "aten::baddbmm"); + const TorchNodeProxy* has_attn_mask = graph.find_node_by_kind("aten::baddbmm"); if (has_attn_mask) { size_t input_count = op->inputs.size(); @@ -106,14 +111,14 @@ public: // attn = torch.bmm(Q, K) // input0 = torch.add_(attn, attn_mask) // attn0 = torch.softmax(input0, -1) - const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax"); + const TorchNodeProxy* softmax = graph.find_node_by_kind("aten::softmax"); if (softmax) { - const torch::jit::Node* add_ = softmax->input(0)->node(); - if (add_ && add_->kind().toDisplayString() == std::string("aten::add_")) + const TorchNodeProxy* add_ = graph.find_producer_node_by_value(softmax->input(0)); + if (add_ && add_->kind() == "aten::add_") { - const torch::jit::Node* bmm = add_->input(0)->node(); - if (bmm && bmm->kind().toDisplayString() == std::string("aten::bmm")) + const TorchNodeProxy* bmm = graph.find_producer_node_by_value(add_->input(0)); + if (bmm && bmm->kind() == "aten::bmm") { size_t input_count = op->inputs.size(); op->inputnames.resize(input_count); @@ -125,7 +130,7 @@ public: if (mod.hasattr("in_proj_weight")) { - const auto& in_proj_weight = mod.attr("in_proj_weight").toTensor(); + const TorchTensorProxy& in_proj_weight = mod.attr("in_proj_weight"); op->params["embed_dim"] = in_proj_weight.size(1); op->params["kdim"] = in_proj_weight.size(1); @@ -134,9 +139,9 @@ public: } else { - const auto& q_proj_weight = mod.attr("q_proj_weight").toTensor(); - const auto& k_proj_weight = mod.attr("k_proj_weight").toTensor(); - const auto& v_proj_weight = mod.attr("v_proj_weight").toTensor(); + const TorchTensorProxy& q_proj_weight = mod.attr("q_proj_weight"); + const TorchTensorProxy& k_proj_weight = mod.attr("k_proj_weight"); + const TorchTensorProxy& v_proj_weight = mod.attr("v_proj_weight"); op->params["embed_dim"] = q_proj_weight.size(1); op->params["kdim"] = k_proj_weight.size(1); @@ -146,15 +151,15 @@ public: op->attrs["v_proj_weight"] = v_proj_weight; } - const auto& out_proj_weight = mod.attr("out_proj").toModule().attr("weight").toTensor(); + const TorchTensorProxy& out_proj_weight = mod.attr("out_proj.weight"); op->attrs["out_proj.weight"] = out_proj_weight; - if (mod.hasattr("in_proj_bias") && mod.attr("out_proj").toModule().hasattr("bias")) + if (mod.hasattr("in_proj_bias") && mod.hasattr("out_proj.bias")) { // bias=True - const auto& in_proj_bias = mod.attr("in_proj_bias").toTensor(); - const auto& out_proj_bias = mod.attr("out_proj").toModule().attr("bias").toTensor(); + const TorchTensorProxy& in_proj_bias = mod.attr("in_proj_bias"); + const TorchTensorProxy& out_proj_bias = mod.attr("out_proj.bias"); op->params["bias"] = true; op->attrs["in_proj_bias"] = in_proj_bias; @@ -166,9 +171,9 @@ public: // the output projection bias always there no matter bias is False in pytorch 1.8 // this behavior changes since https://github.com/pytorch/pytorch/commit/58d1b3639bc07f9519de18e5a18e575f260c7eeb - if (mod.attr("out_proj").toModule().hasattr("bias")) + if (mod.hasattr("out_proj.bias")) { - const auto& out_proj_bias = mod.attr("out_proj").toModule().attr("bias").toTensor(); + const TorchTensorProxy& out_proj_bias = mod.attr("out_proj.bias"); op->attrs["out_proj.bias"] = out_proj_bias; } } @@ -176,8 +181,8 @@ public: if (mod.hasattr("bias_k") && mod.hasattr("bias_v")) { // add_bias_kv=True - const auto& bias_k = mod.attr("bias_k").toTensor(); - const auto& bias_v = mod.attr("bias_v").toTensor(); + const TorchTensorProxy& bias_k = mod.attr("bias_k"); + const TorchTensorProxy& bias_v = mod.attr("bias_v"); op->params["add_bias_kv"] = true; op->attrs["bias_k"] = bias_k; diff --git a/tools/pnnx/src/pass_level1/nn_PReLU.cpp b/tools/pnnx/src/pass_level1/nn_PReLU.cpp index 52b3f2497..7ae2f7b8c 100644 --- a/tools/pnnx/src/pass_level1/nn_PReLU.cpp +++ b/tools/pnnx/src/pass_level1/nn_PReLU.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.PReLU"; } - void write(Operator* op, const std::shared_ptr& /*graph*/, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& /*graph*/, const TorchModuleProxy& mod) const { - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); op->params["num_parameters"] = weight.size(0); diff --git a/tools/pnnx/src/pass_level1/nn_PixelShuffle.cpp b/tools/pnnx/src/pass_level1/nn_PixelShuffle.cpp index f9c1bbc6b..2f58e85bf 100644 --- a/tools/pnnx/src/pass_level1/nn_PixelShuffle.cpp +++ b/tools/pnnx/src/pass_level1/nn_PixelShuffle.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.PixelShuffle"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pixel_shuffle = find_node_by_kind(graph, "aten::pixel_shuffle"); + const TorchNodeProxy* pixel_shuffle = graph.find_node_by_kind("aten::pixel_shuffle"); op->params["upscale_factor"] = pixel_shuffle->namedInput("upscale_factor"); } diff --git a/tools/pnnx/src/pass_level1/nn_PixelUnshuffle.cpp b/tools/pnnx/src/pass_level1/nn_PixelUnshuffle.cpp index 73154db9a..fedcefa8b 100644 --- a/tools/pnnx/src/pass_level1/nn_PixelUnshuffle.cpp +++ b/tools/pnnx/src/pass_level1/nn_PixelUnshuffle.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.PixelUnshuffle"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pixel_unshuffle = find_node_by_kind(graph, "aten::pixel_unshuffle"); + const TorchNodeProxy* pixel_unshuffle = graph.find_node_by_kind("aten::pixel_unshuffle"); op->params["downscale_factor"] = pixel_unshuffle->namedInput("downscale_factor"); } diff --git a/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp b/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp index 498f0453c..e564ec4b7 100644 --- a/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp +++ b/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.RMSNorm"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { - const torch::jit::Node* rmsn = find_node_by_kind(graph, "aten::rms_norm"); + const TorchNodeProxy* rmsn = graph.find_node_by_kind("aten::rms_norm"); op->params["normalized_shape"] = rmsn->namedInput("normalized_shape"); op->params["eps"] = rmsn->namedInput("eps"); @@ -41,7 +39,7 @@ public: if (mod.hasattr("weight")) { - op->attrs["weight"] = mod.attr("weight").toTensor(); + op->attrs["weight"] = mod.attr("weight"); } } }; diff --git a/tools/pnnx/src/pass_level1/nn_RNN.cpp b/tools/pnnx/src/pass_level1/nn_RNN.cpp index ab6483929..a13abd496 100644 --- a/tools/pnnx/src/pass_level1/nn_RNN.cpp +++ b/tools/pnnx/src/pass_level1/nn_RNN.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,23 +29,23 @@ public: return "nn.RNN"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // mod.dump(true, true, true); // graph->dump(); - const torch::jit::Node* rnn = find_node_by_kind(graph, "aten::rnn_tanh"); - const torch::jit::Node* rnn_relu = find_node_by_kind(graph, "aten::rnn_relu"); + const TorchNodeProxy* rnn = graph.find_node_by_kind("aten::rnn_tanh"); + const TorchNodeProxy* rnn_relu = graph.find_node_by_kind("aten::rnn_relu"); if (rnn_relu) { rnn = rnn_relu; } - const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); - if (return_tuple && return_tuple->inputs().size() == 2 && rnn->outputs().size() == 2 - && return_tuple->inputs()[0] == rnn->outputs()[1] && return_tuple->inputs()[1] == rnn->outputs()[0]) + const TorchNodeProxy* return_tuple = graph.find_node_by_kind("prim::TupleConstruct"); + if (return_tuple && return_tuple->input_count() == 2 && rnn->output_count() == 2 + && return_tuple->input(0) == rnn->output(1) && return_tuple->input(1) == rnn->output(0)) { // mark the swapped output tuple // we would restore the fine order in pass_level3/fuse_rnn_unpack @@ -60,7 +58,7 @@ public: // fprintf(stderr, "arg %s\n", aa.name().c_str()); // } - const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); + const TorchTensorProxy& weight_ih_l0 = mod.attr("weight_ih_l0"); op->params["input_size"] = weight_ih_l0.size(1); op->params["hidden_size"] = weight_ih_l0.size(0); @@ -79,16 +77,16 @@ public: std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); - op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); - op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); + op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key); + op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key); if (bias) { std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); - op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); - op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); + op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key); + op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key); } if (bidirectional) @@ -96,16 +94,16 @@ public: std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; - op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); - op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); + op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key); + op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key); if (bias) { std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; - op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); - op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); + op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key); + op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key); } } } diff --git a/tools/pnnx/src/pass_level1/nn_RReLU.cpp b/tools/pnnx/src/pass_level1/nn_RReLU.cpp index 538552ad2..35f4c240d 100644 --- a/tools/pnnx/src/pass_level1/nn_RReLU.cpp +++ b/tools/pnnx/src/pass_level1/nn_RReLU.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.RReLU"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* rrelu = find_node_by_kind(graph, "aten::rrelu"); + const TorchNodeProxy* rrelu = graph.find_node_by_kind("aten::rrelu"); op->params["lower"] = rrelu->namedInput("lower"); op->params["upper"] = rrelu->namedInput("upper"); diff --git a/tools/pnnx/src/pass_level1/nn_ReLU.cpp b/tools/pnnx/src/pass_level1/nn_ReLU.cpp index bc213c64d..fbd45a8fc 100644 --- a/tools/pnnx/src/pass_level1/nn_ReLU.cpp +++ b/tools/pnnx/src/pass_level1/nn_ReLU.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_ReLU6.cpp b/tools/pnnx/src/pass_level1/nn_ReLU6.cpp index 69bc36a6f..5616ee413 100644 --- a/tools/pnnx/src/pass_level1/nn_ReLU6.cpp +++ b/tools/pnnx/src/pass_level1/nn_ReLU6.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_ReflectionPad1d.cpp b/tools/pnnx/src/pass_level1/nn_ReflectionPad1d.cpp index 42e3d4da6..dae365c23 100644 --- a/tools/pnnx/src/pass_level1/nn_ReflectionPad1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ReflectionPad1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.ReflectionPad1d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* reflection_pad1d = find_node_by_kind(graph, "aten::reflection_pad1d"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* reflection_pad1d = graph.find_node_by_kind("aten::reflection_pad1d"); if (pad) { diff --git a/tools/pnnx/src/pass_level1/nn_ReflectionPad2d.cpp b/tools/pnnx/src/pass_level1/nn_ReflectionPad2d.cpp index 57ff7da17..efa71a6bb 100644 --- a/tools/pnnx/src/pass_level1/nn_ReflectionPad2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ReflectionPad2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.ReflectionPad2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* reflection_pad2d = find_node_by_kind(graph, "aten::reflection_pad2d"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* reflection_pad2d = graph.find_node_by_kind("aten::reflection_pad2d"); if (pad) { diff --git a/tools/pnnx/src/pass_level1/nn_ReplicationPad1d.cpp b/tools/pnnx/src/pass_level1/nn_ReplicationPad1d.cpp index 1a5cf230d..99c6d3fce 100644 --- a/tools/pnnx/src/pass_level1/nn_ReplicationPad1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ReplicationPad1d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.ReplicationPad1d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* replication_pad1d = find_node_by_kind(graph, "aten::replication_pad1d"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* replication_pad1d = graph.find_node_by_kind("aten::replication_pad1d"); if (pad) { diff --git a/tools/pnnx/src/pass_level1/nn_ReplicationPad2d.cpp b/tools/pnnx/src/pass_level1/nn_ReplicationPad2d.cpp index fc8b643bb..f977b5e90 100644 --- a/tools/pnnx/src/pass_level1/nn_ReplicationPad2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ReplicationPad2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.ReplicationPad2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* replication_pad2d = find_node_by_kind(graph, "aten::replication_pad2d"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* replication_pad2d = graph.find_node_by_kind("aten::replication_pad2d"); if (pad) { diff --git a/tools/pnnx/src/pass_level1/nn_ReplicationPad3d.cpp b/tools/pnnx/src/pass_level1/nn_ReplicationPad3d.cpp index e4759fd06..b3eee4555 100644 --- a/tools/pnnx/src/pass_level1/nn_ReplicationPad3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ReplicationPad3d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.ReplicationPad3d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* replication_pad3d = graph.find_node_by_kind("aten::replication_pad3d"); if (pad) { diff --git a/tools/pnnx/src/pass_level1/nn_SELU.cpp b/tools/pnnx/src/pass_level1/nn_SELU.cpp index c6ca3d472..501a3f4a8 100644 --- a/tools/pnnx/src/pass_level1/nn_SELU.cpp +++ b/tools/pnnx/src/pass_level1/nn_SELU.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_SiLU.cpp b/tools/pnnx/src/pass_level1/nn_SiLU.cpp index 816e9aa77..b95c348db 100644 --- a/tools/pnnx/src/pass_level1/nn_SiLU.cpp +++ b/tools/pnnx/src/pass_level1/nn_SiLU.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_Sigmoid.cpp b/tools/pnnx/src/pass_level1/nn_Sigmoid.cpp index f106c553c..26297c67b 100644 --- a/tools/pnnx/src/pass_level1/nn_Sigmoid.cpp +++ b/tools/pnnx/src/pass_level1/nn_Sigmoid.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_Softmax.cpp b/tools/pnnx/src/pass_level1/nn_Softmax.cpp index f0baa0188..5461f9f6a 100644 --- a/tools/pnnx/src/pass_level1/nn_Softmax.cpp +++ b/tools/pnnx/src/pass_level1/nn_Softmax.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.Softmax"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax"); + const TorchNodeProxy* softmax = graph.find_node_by_kind("aten::softmax"); op->params["dim"] = softmax->namedInput("dim"); } diff --git a/tools/pnnx/src/pass_level1/nn_Softmax2d.cpp b/tools/pnnx/src/pass_level1/nn_Softmax2d.cpp index c80404066..5536c7979 100644 --- a/tools/pnnx/src/pass_level1/nn_Softmax2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_Softmax2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_Softmin.cpp b/tools/pnnx/src/pass_level1/nn_Softmin.cpp index 1a1bec249..4dae08d0b 100644 --- a/tools/pnnx/src/pass_level1/nn_Softmin.cpp +++ b/tools/pnnx/src/pass_level1/nn_Softmin.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.Softmin"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax"); + const TorchNodeProxy* softmax = graph.find_node_by_kind("aten::softmax"); op->params["dim"] = softmax->namedInput("dim"); } diff --git a/tools/pnnx/src/pass_level1/nn_Softplus.cpp b/tools/pnnx/src/pass_level1/nn_Softplus.cpp index cf470b0e7..b76065a3d 100644 --- a/tools/pnnx/src/pass_level1/nn_Softplus.cpp +++ b/tools/pnnx/src/pass_level1/nn_Softplus.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.Softplus"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* softplus = find_node_by_kind(graph, "aten::softplus"); + const TorchNodeProxy* softplus = graph.find_node_by_kind("aten::softplus"); op->params["beta"] = softplus->namedInput("beta"); op->params["threshold"] = softplus->namedInput("threshold"); diff --git a/tools/pnnx/src/pass_level1/nn_Softshrink.cpp b/tools/pnnx/src/pass_level1/nn_Softshrink.cpp index 4c16ef814..f61f4b04f 100644 --- a/tools/pnnx/src/pass_level1/nn_Softshrink.cpp +++ b/tools/pnnx/src/pass_level1/nn_Softshrink.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.Softshrink"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* softshrink = find_node_by_kind(graph, "aten::softshrink"); + const TorchNodeProxy* softshrink = graph.find_node_by_kind("aten::softshrink"); op->params["lambd"] = softshrink->namedInput("lambd"); } diff --git a/tools/pnnx/src/pass_level1/nn_Softsign.cpp b/tools/pnnx/src/pass_level1/nn_Softsign.cpp index 05be02954..e28f73ccd 100644 --- a/tools/pnnx/src/pass_level1/nn_Softsign.cpp +++ b/tools/pnnx/src/pass_level1/nn_Softsign.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_Tanh.cpp b/tools/pnnx/src/pass_level1/nn_Tanh.cpp index 802b91c35..cdbfc25f3 100644 --- a/tools/pnnx/src/pass_level1/nn_Tanh.cpp +++ b/tools/pnnx/src/pass_level1/nn_Tanh.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_Tanhshrink.cpp b/tools/pnnx/src/pass_level1/nn_Tanhshrink.cpp index 78f4f4d13..650363c5d 100644 --- a/tools/pnnx/src/pass_level1/nn_Tanhshrink.cpp +++ b/tools/pnnx/src/pass_level1/nn_Tanhshrink.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" namespace pnnx { diff --git a/tools/pnnx/src/pass_level1/nn_Threshold.cpp b/tools/pnnx/src/pass_level1/nn_Threshold.cpp index ee52f6cbd..1bbb1fe27 100644 --- a/tools/pnnx/src/pass_level1/nn_Threshold.cpp +++ b/tools/pnnx/src/pass_level1/nn_Threshold.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.Threshold"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* threshold = find_node_by_kind(graph, "aten::threshold"); + const TorchNodeProxy* threshold = graph.find_node_by_kind("aten::threshold"); op->params["threshold"] = threshold->namedInput("threshold"); op->params["value"] = threshold->namedInput("value"); diff --git a/tools/pnnx/src/pass_level1/nn_Unfold.cpp b/tools/pnnx/src/pass_level1/nn_Unfold.cpp index 1abf6201a..8be9f6be7 100644 --- a/tools/pnnx/src/pass_level1/nn_Unfold.cpp +++ b/tools/pnnx/src/pass_level1/nn_Unfold.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.Unfold"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* im2col = find_node_by_kind(graph, "aten::im2col"); + const TorchNodeProxy* im2col = graph.find_node_by_kind("aten::im2col"); op->params["kernel_size"] = im2col->namedInput("kernel_size"); op->params["stride"] = im2col->namedInput("stride"); diff --git a/tools/pnnx/src/pass_level1/nn_Upsample.cpp b/tools/pnnx/src/pass_level1/nn_Upsample.cpp index f665eaa3f..f355919a2 100644 --- a/tools/pnnx/src/pass_level1/nn_Upsample.cpp +++ b/tools/pnnx/src/pass_level1/nn_Upsample.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,19 +29,19 @@ public: return "nn.Upsample"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* upsample_nearest1d = find_node_by_kind(graph, "aten::upsample_nearest1d"); - const torch::jit::Node* upsample_linear1d = find_node_by_kind(graph, "aten::upsample_linear1d"); + const TorchNodeProxy* upsample_nearest1d = graph.find_node_by_kind("aten::upsample_nearest1d"); + const TorchNodeProxy* upsample_linear1d = graph.find_node_by_kind("aten::upsample_linear1d"); - const torch::jit::Node* upsample_nearest2d = find_node_by_kind(graph, "aten::upsample_nearest2d"); - const torch::jit::Node* upsample_bilinear2d = find_node_by_kind(graph, "aten::upsample_bilinear2d"); - const torch::jit::Node* upsample_bicubic2d = find_node_by_kind(graph, "aten::upsample_bicubic2d"); + const TorchNodeProxy* upsample_nearest2d = graph.find_node_by_kind("aten::upsample_nearest2d"); + const TorchNodeProxy* upsample_bilinear2d = graph.find_node_by_kind("aten::upsample_bilinear2d"); + const TorchNodeProxy* upsample_bicubic2d = graph.find_node_by_kind("aten::upsample_bicubic2d"); - const torch::jit::Node* upsample_nearest3d = find_node_by_kind(graph, "aten::upsample_nearest3d"); - const torch::jit::Node* upsample_trilinear3d = find_node_by_kind(graph, "aten::upsample_trilinear3d"); + const TorchNodeProxy* upsample_nearest3d = graph.find_node_by_kind("aten::upsample_nearest3d"); + const TorchNodeProxy* upsample_trilinear3d = graph.find_node_by_kind("aten::upsample_trilinear3d"); - const torch::jit::Node* upsample = 0; + const TorchNodeProxy* upsample = 0; if (upsample_nearest1d) { upsample = upsample_nearest1d; @@ -136,12 +134,17 @@ public: std::vector scale_factor; try { - const torch::jit::Node* size_list = find_node_by_kind(graph, "prim::ListConstruct"); - for (auto x : size_list->inputs()) + const TorchNodeProxy* size_list = graph.find_node_by_kind("prim::ListConstruct"); + const int size_list_input_count = size_list->input_count(); + for (int i = 0; i < size_list_input_count; i++) { - auto scale_tensor = x->node()->inputs()[0]->node()->inputs()[0]->node()->inputs()[0]->node()->inputs()[1]->node()->inputs()[0]->node()->inputs()[0]->node(); - auto t = scale_tensor->t(torch::jit::attr::value); - float s = (float)t.item(); + const TorchNodeProxy* scale_tensor = graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(size_list->input(i))->input(0))->input(0))->input(0))->input(1))->input(0))->input(0)); + + // auto t = scale_tensor->t(torch::jit::attr::value); + // float s = (float)t.item(); + Parameter ps = scale_tensor->node; + float s = ps.f; + scale_factor.push_back(s); } @@ -150,7 +153,7 @@ public: catch (...) { fprintf(stderr, "unhandled upsample recompute_scale_factor graph"); - graph->dump(); + graph.dump(); } } } diff --git a/tools/pnnx/src/pass_level1/nn_UpsamplingBilinear2d.cpp b/tools/pnnx/src/pass_level1/nn_UpsamplingBilinear2d.cpp index 2b0441d44..07a8c22d3 100644 --- a/tools/pnnx/src/pass_level1/nn_UpsamplingBilinear2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_UpsamplingBilinear2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.UpsamplingBilinear2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* upsample = find_node_by_kind(graph, "aten::upsample_bilinear2d"); + const TorchNodeProxy* upsample = graph.find_node_by_kind("aten::upsample_bilinear2d"); if (upsample->hasNamedInput("output_size")) { diff --git a/tools/pnnx/src/pass_level1/nn_UpsamplingNearest2d.cpp b/tools/pnnx/src/pass_level1/nn_UpsamplingNearest2d.cpp index 20f8e32fd..2b1856134 100644 --- a/tools/pnnx/src/pass_level1/nn_UpsamplingNearest2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_UpsamplingNearest2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,9 +29,9 @@ public: return "nn.UpsamplingNearest2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* upsample = find_node_by_kind(graph, "aten::upsample_nearest2d"); + const TorchNodeProxy* upsample = graph.find_node_by_kind("aten::upsample_nearest2d"); if (upsample->hasNamedInput("output_size")) { diff --git a/tools/pnnx/src/pass_level1/nn_ZeroPad2d.cpp b/tools/pnnx/src/pass_level1/nn_ZeroPad2d.cpp index b7e7b1457..4886774df 100644 --- a/tools/pnnx/src/pass_level1/nn_ZeroPad2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ZeroPad2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,10 +29,10 @@ public: return "nn.ZeroPad2d"; } - void write(Operator* op, const std::shared_ptr& graph) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); - const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); + const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); + const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd"); if (!pad) { diff --git a/tools/pnnx/src/pass_level1/nn_quantized_Conv2d.cpp b/tools/pnnx/src/pass_level1/nn_quantized_Conv2d.cpp index f57dfcd45..90e6daee8 100644 --- a/tools/pnnx/src/pass_level1/nn_quantized_Conv2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_quantized_Conv2d.cpp @@ -12,9 +12,11 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" -#include "../utils.h" +#include +#include +#include namespace pnnx { @@ -31,11 +33,13 @@ public: return "nn.quantized.Conv2d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const { + const auto& mod = _mod.mod; + // graph->dump(); - const torch::jit::Node* quantized_convolution = find_node_by_kind(graph, "quantized::conv2d"); + const TorchNodeProxy* quantized_convolution = graph.find_node_by_kind("quantized::conv2d"); // for (auto aa : quantized_convolution->schema().arguments()) // { @@ -113,11 +117,13 @@ public: return "nn.intrinsic.quantized.ConvReLU2d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const { + const auto& mod = _mod.mod; + // graph->dump(); - const torch::jit::Node* quantized_convolution = find_node_by_kind(graph, "quantized::conv2d_relu"); + const TorchNodeProxy* quantized_convolution = graph.find_node_by_kind("quantized::conv2d_relu"); // for (auto aa : quantized_convolution->schema().arguments()) // { diff --git a/tools/pnnx/src/pass_level1/nn_quantized_DeQuantize.cpp b/tools/pnnx/src/pass_level1/nn_quantized_DeQuantize.cpp index 15100839d..7cf389fd8 100644 --- a/tools/pnnx/src/pass_level1/nn_quantized_DeQuantize.cpp +++ b/tools/pnnx/src/pass_level1/nn_quantized_DeQuantize.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,19 +29,19 @@ public: return "nn.quantized.DeQuantize"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const - { - // mod.dump(true, false, false); - - // graph->dump(); - - const torch::jit::Node* dequantize = find_node_by_kind(graph, "aten::dequantize"); - - // for (auto aa : dequantize->schema().arguments()) - // { - // fprintf(stderr, "arg %s\n", aa.name().c_str()); - // } - } + // void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const + // { + // // mod.dump(true, false, false); + // + // // graph->dump(); + // + // const torch::jit::Node* dequantize = find_node_by_kind(graph, "aten::dequantize"); + // + // // for (auto aa : dequantize->schema().arguments()) + // // { + // // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // // } + // } }; REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(DeQuantize) diff --git a/tools/pnnx/src/pass_level1/nn_quantized_Linear.cpp b/tools/pnnx/src/pass_level1/nn_quantized_Linear.cpp index 1c82c51ca..ddedcc35b 100644 --- a/tools/pnnx/src/pass_level1/nn_quantized_Linear.cpp +++ b/tools/pnnx/src/pass_level1/nn_quantized_Linear.cpp @@ -12,9 +12,11 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" +#include "fuse_module_pass.h" -#include "../utils.h" +#include +#include +#include namespace pnnx { @@ -31,13 +33,15 @@ public: return "nn.quantized.Linear"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const { + const auto& mod = _mod.mod; + // mod.dump(true, false, false); // graph->dump(); - const torch::jit::Node* quantized_linear = find_node_by_kind(graph, "quantized::linear"); + const TorchNodeProxy* quantized_linear = graph.find_node_by_kind("quantized::linear"); // for (auto aa : quantized_linear->schema().arguments()) // { @@ -99,13 +103,15 @@ public: return "nn.intrinsic.quantized.LinearReLU"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const { + const auto& mod = _mod.mod; + // mod.dump(true, false, false); - graph->dump(); + graph.dump(); - const torch::jit::Node* quantized_linear = find_node_by_kind(graph, "quantized::linear_relu"); + const TorchNodeProxy* quantized_linear = graph.find_node_by_kind("quantized::linear_relu"); // for (auto aa : quantized_linear->schema().arguments()) // { diff --git a/tools/pnnx/src/pass_level1/nn_quantized_Quantize.cpp b/tools/pnnx/src/pass_level1/nn_quantized_Quantize.cpp index f19a02c00..eb3ab819d 100644 --- a/tools/pnnx/src/pass_level1/nn_quantized_Quantize.cpp +++ b/tools/pnnx/src/pass_level1/nn_quantized_Quantize.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,13 +29,13 @@ public: return "nn.quantized.Quantize"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph) const { // mod.dump(true, false, false); // graph->dump(); - const torch::jit::Node* quantize_per_tensor = find_node_by_kind(graph, "aten::quantize_per_tensor"); + const TorchNodeProxy* quantize_per_tensor = graph.find_node_by_kind("aten::quantize_per_tensor"); // for (auto aa : quantize_per_tensor->schema().arguments()) // { diff --git a/tools/pnnx/src/pass_level1/torchvision_DeformConv2d.cpp b/tools/pnnx/src/pass_level1/torchvision_DeformConv2d.cpp index 85e7e64ac..dbc3fbf31 100644 --- a/tools/pnnx/src/pass_level1/torchvision_DeformConv2d.cpp +++ b/tools/pnnx/src/pass_level1/torchvision_DeformConv2d.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,11 +29,11 @@ public: return "torchvision.ops.DeformConv2d"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { - const torch::jit::Node* deform_conv2d = find_node_by_kind(graph, "torchvision::deform_conv2d"); + const TorchNodeProxy* deform_conv2d = graph.find_node_by_kind("torchvision::deform_conv2d"); - const auto& weight = mod.attr("weight").toTensor(); + const TorchTensorProxy& weight = mod.attr("weight"); const Parameter stride_w = deform_conv2d->namedInput("stride_w"); const Parameter stride_h = deform_conv2d->namedInput("stride_h"); @@ -56,7 +54,7 @@ public: op->attrs["weight"] = weight; if (mod.hasattr("bias")) { - op->attrs["bias"] = mod.attr("bias").toTensor(); + op->attrs["bias"] = mod.attr("bias"); } } }; diff --git a/tools/pnnx/src/pass_level1/torchvision_RoIAlign.cpp b/tools/pnnx/src/pass_level1/torchvision_RoIAlign.cpp index 4a81ee45c..ed85d85a5 100644 --- a/tools/pnnx/src/pass_level1/torchvision_RoIAlign.cpp +++ b/tools/pnnx/src/pass_level1/torchvision_RoIAlign.cpp @@ -12,9 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "pass_level1.h" - -#include "../utils.h" +#include "fuse_module_pass.h" namespace pnnx { @@ -31,11 +29,11 @@ public: return "torchvision.ops.RoIAlign"; } - void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& /*mod*/) const + void write(Operator* op, const TorchGraphProxy& graph) const { - const torch::jit::Node* roi_align = find_node_by_kind(graph, "torchvision::roi_align"); + const TorchNodeProxy* roi_align = graph.find_node_by_kind("torchvision::roi_align"); - if (roi_align->inputs()[0] == graph->inputs()[2] && roi_align->inputs()[1] == graph->inputs()[1]) + if (roi_align->input(0) == graph.input(2) && roi_align->input(1) == graph.input(1)) { fprintf(stderr, "roi_align inputs swapped detected !\n"); std::swap(op->inputs[0], op->inputs[1]); diff --git a/tools/pnnx/src/utils.h b/tools/pnnx/src/utils.h index 323d5ab48..d3ae5a1ec 100644 --- a/tools/pnnx/src/utils.h +++ b/tools/pnnx/src/utils.h @@ -15,22 +15,8 @@ #ifndef PNNX_UTILS_H #define PNNX_UTILS_H -#if BUILD_TORCH2PNNX -#include -namespace torch { -namespace jit { -struct Graph; -struct Node; -} // namespace jit -} // namespace torch -#endif - namespace pnnx { -#if BUILD_TORCH2PNNX -const torch::jit::Node* find_node_by_kind(const std::shared_ptr& graph, const std::string& kind); -#endif - unsigned short float32_to_float16(float value); float float16_to_float32(unsigned short value);