Browse Source

pnnx pass level1 wrapper enabling faster build (#6014)

tags/20250428
nihui GitHub 1 year ago
parent
commit
76f48c8fcb
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
100 changed files with 882 additions and 703 deletions
  1. +16
    -14
      tools/pnnx/src/CMakeLists.txt
  2. +4
    -0
      tools/pnnx/src/ir.h
  3. +7
    -12
      tools/pnnx/src/load_torchscript.cpp
  4. +1
    -1
      tools/pnnx/src/pass_level0/inline_block.cpp
  5. +9
    -34
      tools/pnnx/src/pass_level1.cpp
  6. +0
    -29
      tools/pnnx/src/pass_level1.h
  7. +223
    -0
      tools/pnnx/src/pass_level1/fuse_module_pass.cpp
  8. +151
    -0
      tools/pnnx/src/pass_level1/fuse_module_pass.h
  9. +3
    -5
      tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool1d.cpp
  10. +3
    -5
      tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool2d.cpp
  11. +3
    -5
      tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool3d.cpp
  12. +6
    -6
      tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool1d.cpp
  13. +6
    -6
      tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool2d.cpp
  14. +6
    -6
      tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool3d.cpp
  15. +1
    -3
      tools/pnnx/src/pass_level1/nn_AlphaDropout.cpp
  16. +3
    -5
      tools/pnnx/src/pass_level1/nn_AvgPool1d.cpp
  17. +3
    -5
      tools/pnnx/src/pass_level1/nn_AvgPool2d.cpp
  18. +3
    -5
      tools/pnnx/src/pass_level1/nn_AvgPool3d.cpp
  19. +7
    -9
      tools/pnnx/src/pass_level1/nn_BatchNorm1d.cpp
  20. +7
    -9
      tools/pnnx/src/pass_level1/nn_BatchNorm2d.cpp
  21. +7
    -9
      tools/pnnx/src/pass_level1/nn_BatchNorm3d.cpp
  22. +3
    -5
      tools/pnnx/src/pass_level1/nn_CELU.cpp
  23. +3
    -5
      tools/pnnx/src/pass_level1/nn_ChannelShuffle.cpp
  24. +4
    -6
      tools/pnnx/src/pass_level1/nn_ConstantPad1d.cpp
  25. +4
    -6
      tools/pnnx/src/pass_level1/nn_ConstantPad2d.cpp
  26. +4
    -6
      tools/pnnx/src/pass_level1/nn_ConstantPad3d.cpp
  27. +9
    -13
      tools/pnnx/src/pass_level1/nn_Conv1d.cpp
  28. +11
    -15
      tools/pnnx/src/pass_level1/nn_Conv2d.cpp
  29. +9
    -13
      tools/pnnx/src/pass_level1/nn_Conv3d.cpp
  30. +5
    -7
      tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp
  31. +5
    -7
      tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp
  32. +5
    -7
      tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp
  33. +1
    -3
      tools/pnnx/src/pass_level1/nn_Dropout.cpp
  34. +1
    -3
      tools/pnnx/src/pass_level1/nn_Dropout2d.cpp
  35. +1
    -3
      tools/pnnx/src/pass_level1/nn_Dropout3d.cpp
  36. +3
    -5
      tools/pnnx/src/pass_level1/nn_ELU.cpp
  37. +4
    -6
      tools/pnnx/src/pass_level1/nn_Embedding.cpp
  38. +3
    -5
      tools/pnnx/src/pass_level1/nn_Fold.cpp
  39. +3
    -5
      tools/pnnx/src/pass_level1/nn_GELU.cpp
  40. +3
    -5
      tools/pnnx/src/pass_level1/nn_GLU.cpp
  41. +15
    -17
      tools/pnnx/src/pass_level1/nn_GRU.cpp
  42. +5
    -7
      tools/pnnx/src/pass_level1/nn_GroupNorm.cpp
  43. +3
    -5
      tools/pnnx/src/pass_level1/nn_Hardshrink.cpp
  44. +1
    -1
      tools/pnnx/src/pass_level1/nn_Hardsigmoid.cpp
  45. +1
    -1
      tools/pnnx/src/pass_level1/nn_Hardswish.cpp
  46. +3
    -5
      tools/pnnx/src/pass_level1/nn_Hardtanh.cpp
  47. +7
    -9
      tools/pnnx/src/pass_level1/nn_InstanceNorm1d.cpp
  48. +7
    -9
      tools/pnnx/src/pass_level1/nn_InstanceNorm2d.cpp
  49. +7
    -9
      tools/pnnx/src/pass_level1/nn_InstanceNorm3d.cpp
  50. +11
    -10
      tools/pnnx/src/pass_level1/nn_LPPool1d.cpp
  51. +8
    -8
      tools/pnnx/src/pass_level1/nn_LPPool2d.cpp
  52. +18
    -20
      tools/pnnx/src/pass_level1/nn_LSTM.cpp
  53. +5
    -7
      tools/pnnx/src/pass_level1/nn_LayerNorm.cpp
  54. +4
    -6
      tools/pnnx/src/pass_level1/nn_LeakyReLU.cpp
  55. +7
    -9
      tools/pnnx/src/pass_level1/nn_Linear.cpp
  56. +12
    -13
      tools/pnnx/src/pass_level1/nn_LocalResponseNorm.cpp
  57. +1
    -3
      tools/pnnx/src/pass_level1/nn_LogSigmoid.cpp
  58. +3
    -5
      tools/pnnx/src/pass_level1/nn_LogSoftmax.cpp
  59. +4
    -6
      tools/pnnx/src/pass_level1/nn_MaxPool1d.cpp
  60. +4
    -6
      tools/pnnx/src/pass_level1/nn_MaxPool2d.cpp
  61. +4
    -6
      tools/pnnx/src/pass_level1/nn_MaxPool3d.cpp
  62. +1
    -1
      tools/pnnx/src/pass_level1/nn_Mish.cpp
  63. +40
    -35
      tools/pnnx/src/pass_level1/nn_MultiheadAttention.cpp
  64. +3
    -5
      tools/pnnx/src/pass_level1/nn_PReLU.cpp
  65. +3
    -5
      tools/pnnx/src/pass_level1/nn_PixelShuffle.cpp
  66. +3
    -5
      tools/pnnx/src/pass_level1/nn_PixelUnshuffle.cpp
  67. +4
    -6
      tools/pnnx/src/pass_level1/nn_RMSNorm.cpp
  68. +16
    -18
      tools/pnnx/src/pass_level1/nn_RNN.cpp
  69. +3
    -5
      tools/pnnx/src/pass_level1/nn_RReLU.cpp
  70. +1
    -1
      tools/pnnx/src/pass_level1/nn_ReLU.cpp
  71. +1
    -1
      tools/pnnx/src/pass_level1/nn_ReLU6.cpp
  72. +4
    -6
      tools/pnnx/src/pass_level1/nn_ReflectionPad1d.cpp
  73. +4
    -6
      tools/pnnx/src/pass_level1/nn_ReflectionPad2d.cpp
  74. +4
    -6
      tools/pnnx/src/pass_level1/nn_ReplicationPad1d.cpp
  75. +4
    -6
      tools/pnnx/src/pass_level1/nn_ReplicationPad2d.cpp
  76. +4
    -6
      tools/pnnx/src/pass_level1/nn_ReplicationPad3d.cpp
  77. +1
    -3
      tools/pnnx/src/pass_level1/nn_SELU.cpp
  78. +1
    -1
      tools/pnnx/src/pass_level1/nn_SiLU.cpp
  79. +1
    -1
      tools/pnnx/src/pass_level1/nn_Sigmoid.cpp
  80. +3
    -5
      tools/pnnx/src/pass_level1/nn_Softmax.cpp
  81. +1
    -3
      tools/pnnx/src/pass_level1/nn_Softmax2d.cpp
  82. +3
    -5
      tools/pnnx/src/pass_level1/nn_Softmin.cpp
  83. +3
    -5
      tools/pnnx/src/pass_level1/nn_Softplus.cpp
  84. +3
    -5
      tools/pnnx/src/pass_level1/nn_Softshrink.cpp
  85. +1
    -1
      tools/pnnx/src/pass_level1/nn_Softsign.cpp
  86. +1
    -1
      tools/pnnx/src/pass_level1/nn_Tanh.cpp
  87. +1
    -1
      tools/pnnx/src/pass_level1/nn_Tanhshrink.cpp
  88. +3
    -5
      tools/pnnx/src/pass_level1/nn_Threshold.cpp
  89. +3
    -5
      tools/pnnx/src/pass_level1/nn_Unfold.cpp
  90. +21
    -18
      tools/pnnx/src/pass_level1/nn_Upsample.cpp
  91. +3
    -5
      tools/pnnx/src/pass_level1/nn_UpsamplingBilinear2d.cpp
  92. +3
    -5
      tools/pnnx/src/pass_level1/nn_UpsamplingNearest2d.cpp
  93. +4
    -6
      tools/pnnx/src/pass_level1/nn_ZeroPad2d.cpp
  94. +12
    -6
      tools/pnnx/src/pass_level1/nn_quantized_Conv2d.cpp
  95. +14
    -16
      tools/pnnx/src/pass_level1/nn_quantized_DeQuantize.cpp
  96. +13
    -7
      tools/pnnx/src/pass_level1/nn_quantized_Linear.cpp
  97. +3
    -5
      tools/pnnx/src/pass_level1/nn_quantized_Quantize.cpp
  98. +5
    -7
      tools/pnnx/src/pass_level1/torchvision_DeformConv2d.cpp
  99. +4
    -6
      tools/pnnx/src/pass_level1/torchvision_RoIAlign.cpp
  100. +0
    -14
      tools/pnnx/src/utils.h

+ 16
- 14
tools/pnnx/src/CMakeLists.txt View File

@@ -11,6 +11,8 @@ set(pnnx_pass_level0_SRCS
)

set(pnnx_pass_level1_SRCS
pass_level1/fuse_module_pass.cpp

pass_level1/nn_AdaptiveAvgPool1d.cpp
pass_level1/nn_AdaptiveAvgPool2d.cpp
pass_level1/nn_AdaptiveAvgPool3d.cpp
@@ -688,20 +690,20 @@ if(onnxruntime_FOUND)
pass_onnx/shape_inference.cpp
pass_onnx/fuse_constant_as_attribute.cpp

pass_onnx/nn_AdaptiveAvgPool2d.cpp
pass_onnx/nn_AdaptiveAvgPool3d.cpp
pass_onnx/nn_AvgPool2d.cpp
pass_onnx/nn_AvgPool3d.cpp
pass_onnx/nn_BatchNorm2d.cpp
pass_onnx/nn_BatchNorm3d.cpp
pass_onnx/nn_Conv2d.cpp
pass_onnx/nn_Conv3d.cpp
pass_onnx/nn_GELU.cpp
pass_onnx/nn_LayerNorm.cpp
pass_onnx/nn_Linear.cpp
pass_onnx/nn_MaxPool2d.cpp
pass_onnx/nn_MaxPool3d.cpp
pass_onnx/nn_MultiheadAttention.cpp
# pass_onnx/nn_AdaptiveAvgPool2d.cpp
# pass_onnx/nn_AdaptiveAvgPool3d.cpp
# pass_onnx/nn_AvgPool2d.cpp
# pass_onnx/nn_AvgPool3d.cpp
# pass_onnx/nn_BatchNorm2d.cpp
# pass_onnx/nn_BatchNorm3d.cpp
# pass_onnx/nn_Conv2d.cpp
# pass_onnx/nn_Conv3d.cpp
# pass_onnx/nn_GELU.cpp
# pass_onnx/nn_LayerNorm.cpp
# pass_onnx/nn_Linear.cpp
# pass_onnx/nn_MaxPool2d.cpp
# pass_onnx/nn_MaxPool3d.cpp
# pass_onnx/nn_MultiheadAttention.cpp
)

set(onnx2pnnx_SRCS


+ 4
- 0
tools/pnnx/src/ir.h View File

@@ -34,6 +34,9 @@ struct Node;
namespace at {
class Tensor;
}
namespace pnnx {
class TorchTensorProxy;
} // namespace pnnx
#endif // BUILD_TORCH2PNNX

#if BUILD_ONNX2PNNX
@@ -230,6 +233,7 @@ public:

#if BUILD_TORCH2PNNX
Attribute(const at::Tensor& t);
Attribute(const TorchTensorProxy& t);
#endif
#if BUILD_ONNX2PNNX
Attribute(const onnx::TensorProto& t);


+ 7
- 12
tools/pnnx/src/load_torchscript.cpp View File

@@ -31,6 +31,7 @@ int64_t cuda_version();

#include "pass_level0.h"
#include "pass_level1.h"
#include "pass_level1/fuse_module_pass.h"

namespace pnnx {

@@ -372,6 +373,11 @@ Attribute::Attribute(const at::Tensor& t)
}
}

Attribute::Attribute(const TorchTensorProxy& t)
: Attribute(t.t())
{
}

Operand* Graph::new_operand(const torch::jit::Value* v)
{
// Operand* r = new Operand;
@@ -442,17 +448,6 @@ static const char* get_at_tensor_type_str(const at::ScalarType& st)
return "";
}

const torch::jit::Node* find_node_by_kind(const std::shared_ptr<torch::jit::Graph>& graph, const std::string& kind)
{
for (const auto& n : graph->nodes())
{
if (n->kind().toDisplayString() == kind)
return n;
}

return 0;
}

static void print_shape_list(const std::vector<std::vector<int64_t> >& shapes, const std::vector<std::string>& types)
{
for (size_t i = 0; i < shapes.size(); i++)
@@ -508,7 +503,7 @@ static void get_traced_input_shape(const std::string& ptpath, std::vector<std::v
{
// read traced_inputs.pkl
caffe2::serialize::PyTorchStreamReader reader(ptpath);
auto v = torch::jit::readArchiveAndTensors("traced_inputs", "", "traced_inputs/", std::nullopt, std::nullopt, std::nullopt, reader);
auto v = torch::jit::readArchiveAndTensors("traced_inputs", "", "traced_inputs/", c10::nullopt, c10::nullopt, c10::nullopt, reader);

if (!v.isGenericDict())
return;


+ 1
- 1
tools/pnnx/src/pass_level0/inline_block.cpp View File

@@ -13,7 +13,7 @@
// specific language governing permissions and limitations under the License.

#include "inline_block.h"
#include "../pass_level1.h"
#include "../pass_level1/fuse_module_pass.h"

#include <set>



+ 9
- 34
tools/pnnx/src/pass_level1.cpp View File

@@ -12,43 +12,16 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/script.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/api/include/torch/version.h>
#include <torch/csrc/jit/passes/quantization/helper.h>

#include "pass_level1.h"

namespace pnnx {

FuseModulePass::~FuseModulePass()
{
}

void FuseModulePass::write(Operator* /*op*/, const std::shared_ptr<torch::jit::Graph>& /*graph*/) const
{
}

void FuseModulePass::write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& /*mod*/) const
{
write(op, graph);
}

static std::vector<const FuseModulePass*> g_global_pnnx_fuse_module_passes;

const std::vector<const FuseModulePass*>& get_global_pnnx_fuse_module_passes()
{
return g_global_pnnx_fuse_module_passes;
}
#include "pass_level1/fuse_module_pass.h"

FuseModulePassRegister::FuseModulePassRegister(const FuseModulePass* _pass)
: pass(_pass)
{
g_global_pnnx_fuse_module_passes.push_back(pass);
}

FuseModulePassRegister::~FuseModulePassRegister()
{
delete pass;
}
namespace pnnx {

static void fuse_moduleop_unpack(Graph& graph, const std::vector<std::string>& module_operators)
{
@@ -399,10 +372,12 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit
op->name = wrapped_name;

#if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 11)
ow->write(op, toGraphFunction(function).graph(), sub_mod);
TorchGraphProxy graph_proxy(toGraphFunction(function).graph());
#else
ow->write(op, function.graph(), sub_mod);
TorchGraphProxy graph_proxy(function.graph());
#endif
TorchModuleProxy sub_mod_proxy(sub_mod);
ow->write(op, graph_proxy, sub_mod_proxy);

break;
}


+ 0
- 29
tools/pnnx/src/pass_level1.h View File

@@ -15,39 +15,10 @@
#ifndef PNNX_PASS_LEVEL1_H
#define PNNX_PASS_LEVEL1_H

#include <torch/script.h>
#include <torch/csrc/jit/api/module.h>
#include "ir.h"

namespace pnnx {

class FuseModulePass
{
public:
virtual ~FuseModulePass();

virtual const char* match_type_str() const = 0;

virtual const char* type_str() const = 0;

virtual void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const;

virtual void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const;
};

class FuseModulePassRegister
{
public:
FuseModulePassRegister(const FuseModulePass* pass);
~FuseModulePassRegister();
const FuseModulePass* pass;
};

const std::vector<const FuseModulePass*>& get_global_pnnx_fuse_module_passes();

#define REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(CLASS) \
static FuseModulePassRegister g_global_pnnx_fusemodulepass_##CLASS##_register(new CLASS);

void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit::Graph>& g, const std::vector<std::string>& module_operators, Graph& pg);

} // namespace pnnx


+ 223
- 0
tools/pnnx/src/pass_level1/fuse_module_pass.cpp View File

@@ -0,0 +1,223 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_module_pass.h"

#include <torch/script.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/passes/quantization/helper.h>

namespace pnnx {

// Returns the display string of the wrapped node's kind as a std::string,
// so callers can compare it against plain string literals.
std::string TorchNodeProxy::kind() const
{
    std::string display_name = node->kind().toDisplayString();
    return display_name;
}

// Forwards to the wrapped node: true when it reports an input called `name`.
bool TorchNodeProxy::hasNamedInput(const std::string& name) const
{
    const bool present = node->hasNamedInput(name);
    return present;
}

// Forwards to the wrapped node and returns the value bound to the input called `name`.
const torch::jit::Value* TorchNodeProxy::namedInput(const std::string& name) const
{
    const torch::jit::Value* bound_value = node->namedInput(name);
    return bound_value;
}

// Number of inputs of the wrapped node, narrowed to int.
int TorchNodeProxy::input_count() const
{
    const auto count = node->inputs().size();
    return static_cast<int>(count);
}

// Returns the i-th input value of the wrapped node.
const torch::jit::Value* TorchNodeProxy::input(int i) const
{
    const torch::jit::Value* in_value = node->input(i);
    return in_value;
}

// Number of outputs of the wrapped node, narrowed to int.
int TorchNodeProxy::output_count() const
{
    const auto count = node->outputs().size();
    return static_cast<int>(count);
}

// Returns the i-th output value of the wrapped node.
const torch::jit::Value* TorchNodeProxy::output(int i) const
{
    const torch::jit::Value* out_value = node->output(i);
    return out_value;
}

bool TorchNodeProxy::is_input_none(int i) const
{
return node->input(i)->type()->kind() == c10::TypeKind::NoneType;
}

// Keeps a shared reference to the graph and wraps every node of graph->nodes()
// in a TorchNodeProxy, in iteration order. The find_* members search this
// snapshot, not the live graph.
TorchGraphProxy::TorchGraphProxy(const std::shared_ptr<torch::jit::Graph> _graph)
    : graph(_graph)
{
    for (const torch::jit::Node* graph_node : graph->nodes())
    {
        nodes.emplace_back(graph_node);
    }
}

// Returns the first wrapped node whose kind display string equals `kind`,
// or a null pointer when there is none. The returned pointer refers to an
// element of `nodes`.
const TorchNodeProxy* TorchGraphProxy::find_node_by_kind(const std::string& kind) const
{
    for (size_t i = 0; i < nodes.size(); i++)
    {
        const TorchNodeProxy& candidate = nodes[i];
        if (candidate.kind() == kind)
            return &candidate;
    }

    return nullptr;
}

// Returns the wrapped node that produces `value`, i.e. the element of `nodes`
// whose underlying node is value->node().
//
// Returns a null pointer, after printing a message to stderr, when no wrapped
// node matches or when `value` itself is null. The returned pointer refers to
// an element of `nodes`.
const TorchNodeProxy* TorchGraphProxy::find_producer_node_by_value(const torch::jit::Value* value) const
{
    if (value)
    {
        // the producer does not change while scanning, so look it up once
        const torch::jit::Node* producer = value->node();

        for (const auto& n : nodes)
        {
            if (n.node == producer)
                return &n;
        }
    }

    fprintf(stderr, "TorchGraphProxy find_producer_node_by_value failed\n");
    return nullptr;
}

int TorchGraphProxy::input_count() const
{
return std::as_const(*graph).inputs().size();
}

// Returns the i-th graph input, read through a const view of the graph.
const torch::jit::Value* TorchGraphProxy::input(int i) const
{
    const torch::jit::Graph& const_graph = *graph;
    return const_graph.inputs()[i];
}

int TorchGraphProxy::output_count() const
{
return std::as_const(*graph).outputs().size();
}

// Returns the i-th graph output, read through a const view of the graph.
const torch::jit::Value* TorchGraphProxy::output(int i) const
{
    const torch::jit::Graph& const_graph = *graph;
    return const_graph.outputs()[i];
}

void TorchGraphProxy::dump() const
{
graph->dump();
}

// Implementation detail of TorchTensorProxy. It is defined here, in the .cpp,
// so that the header only needs a forward declaration of at::Tensor.
class TorchTensorProxyPrivate
{
public:
    // the tensor held on behalf of the owning TorchTensorProxy
    at::Tensor t;
};

TorchTensorProxy::TorchTensorProxy(const at::Tensor& _t)
: d(new TorchTensorProxyPrivate)
{
d->t = _t;
}

// Frees the private holder allocated by the constructor.
TorchTensorProxy::~TorchTensorProxy()
{
    // d is owned exclusively by this proxy (copying is deleted in the class)
    delete d;
}

// Gives read access to the held tensor. The reference stays valid only
// while this proxy is alive.
const at::Tensor& TorchTensorProxy::t() const
{
    const at::Tensor& held = d->t;
    return held;
}

// Returns the held tensor's size along dimension `i`, narrowed to int.
int TorchTensorProxy::size(size_t i) const
{
    const auto extent = d->t.size(i);
    return static_cast<int>(extent);
}

// Snapshots the tensor attributes of a module into `attrs`.
//
// Tensors stored directly on `_mod` are keyed by their attribute name. For an
// attribute that is itself a module, that submodule's tensors are keyed as
// "<attribute>.<sub_attribute>". Only this single level of nesting is visited,
// and attributes with an empty name are ignored.
//
// `mod` is a reference member bound to `_mod`; nothing is copied.
TorchModuleProxy::TorchModuleProxy(const torch::jit::Module& _mod)
    : mod(_mod)
{
    const std::vector<c10::ClassAttribute>& top_attributes = mod._ivalue()->type()->getAttributes();

    for (size_t slot = 0; slot < top_attributes.size(); ++slot)
    {
        const std::string& attr_name = top_attributes[slot].getName();
        const c10::IValue& slot_value = mod._ivalue()->getSlot(slot);

        if (attr_name.empty())
            continue;

        if (slot_value.isTensor())
        {
            attrs.emplace(attr_name, slot_value.toTensor());
        }

        if (!slot_value.isModule())
            continue;

        // one level down: collect the tensors owned by the child module
        const torch::jit::Module child = slot_value.toModule();
        const std::string key_prefix = attr_name + ".";

        const std::vector<c10::ClassAttribute>& child_attributes = child._ivalue()->type()->getAttributes();
        for (size_t child_slot = 0; child_slot < child_attributes.size(); ++child_slot)
        {
            const std::string& child_name = child_attributes[child_slot].getName();
            const c10::IValue& child_value = child._ivalue()->getSlot(child_slot);

            if (child_name.empty() || !child_value.isTensor())
                continue;

            attrs.emplace(key_prefix + child_name, child_value.toTensor());
        }
    }
}

// True when a tensor attribute was recorded under `name` at construction time.
bool TorchModuleProxy::hasattr(const std::string& name) const
{
    return attrs.count(name) != 0;
}

// Returns the tensor proxy recorded under `name`.
// Lookup uses the map's at(), so an unknown name throws std::out_of_range;
// check hasattr() first when the attribute is optional.
const TorchTensorProxy& TorchModuleProxy::attr(const std::string& name) const
{
    const TorchTensorProxy& recorded = attrs.at(name);
    return recorded;
}

// Out-of-line virtual destructor; there is nothing to release in the base class.
FuseModulePass::~FuseModulePass() = default;

// Default graph-only hook: intentionally does nothing. Passes that only need
// the graph override this overload.
void FuseModulePass::write(Operator*, const TorchGraphProxy&) const
{
}

// Default graph+module hook: ignores the module and forwards to the graph-only
// overload, so passes that do not need module attributes override just that one.
void FuseModulePass::write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy&) const
{
    this->write(op, graph);
}

// Storage for the registered passes.
//
// The list is a function-local static so that it is constructed on first use.
// Registration happens in FuseModulePassRegister's constructor, which runs
// from static initializers in other translation units; the initialization
// order of namespace-scope statics across translation units is not defined,
// so a namespace-scope vector could be used before it is constructed.
static std::vector<const FuseModulePass*>& mutable_global_pnnx_fuse_module_passes()
{
    static std::vector<const FuseModulePass*> passes;
    return passes;
}

// Read-only view of every pass registered so far, in registration order.
const std::vector<const FuseModulePass*>& get_global_pnnx_fuse_module_passes()
{
    return mutable_global_pnnx_fuse_module_passes();
}

// Takes ownership of `_pass` and appends it to the global pass list.
FuseModulePassRegister::FuseModulePassRegister(const FuseModulePass* _pass)
    : pass(_pass)
{
    mutable_global_pnnx_fuse_module_passes().push_back(pass);
}

// Destroys the owned pass. The global list is not updated here, so it must not
// be consulted once register objects have started to be destroyed.
FuseModulePassRegister::~FuseModulePassRegister()
{
    delete pass;
}

} // namespace pnnx

+ 151
- 0
tools/pnnx/src/pass_level1/fuse_module_pass.h View File

@@ -0,0 +1,151 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef PNNX_FUSE_MODULE_PASS_H
#define PNNX_FUSE_MODULE_PASS_H

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ir.h"

// Forward declarations of the torch types that the proxies below refer to by
// pointer or reference only, so this header does not include torch headers.
// The class-keys must match the real definitions: at::Tensor is declared with
// `class` (as ir.h also does), the torch::jit types with `struct`.
namespace torch {
namespace jit {
struct Graph;
struct Module;
struct Node;
struct Value;
} // namespace jit
} // namespace torch
namespace at {
class Tensor;
} // namespace at

namespace pnnx {

// Non-owning wrapper around a torch::jit::Node.
//
// All member functions are defined in the .cpp, so code using this class only
// needs the forward declarations in this header. The wrapped pointer is not
// owned and is never deleted by this class.
class TorchNodeProxy
{
public:
    // Deliberately not explicit: existing code builds proxies from raw node
    // pointers through implicit conversion.
    TorchNodeProxy(const torch::jit::Node* _node)
        : node(_node)
    {
    }

    // kind of the node as a display string
    std::string kind() const;

    // lookup of inputs by name
    bool hasNamedInput(const std::string& name) const;
    const torch::jit::Value* namedInput(const std::string& name) const;

    // positional access to inputs
    int input_count() const;
    const torch::jit::Value* input(int i) const;

    // positional access to outputs
    int output_count() const;
    const torch::jit::Value* output(int i) const;

    // true when the type of input i is NoneType
    bool is_input_none(int i) const;

    // the wrapped node (not owned)
    const torch::jit::Node* node;
};

// Wrapper around a torch::jit::Graph that shares ownership of the graph and
// holds one TorchNodeProxy per graph node, collected at construction.
class TorchGraphProxy
{
public:
    TorchGraphProxy(const std::shared_ptr<torch::jit::Graph> _graph);

    // bool has_node(const std::string& name) const;

    // first node with the given kind string, or null when there is none
    const TorchNodeProxy* find_node_by_kind(const std::string& kind) const;

    // node that produces `value`, or null (with a message on stderr) when not found
    const TorchNodeProxy* find_producer_node_by_value(const torch::jit::Value* value) const;

    // positional access to graph inputs
    int input_count() const;
    const torch::jit::Value* input(int i) const;

    // positional access to graph outputs
    int output_count() const;
    const torch::jit::Value* output(int i) const;

    // forwards to the graph's own dump()
    void dump() const;

    // the wrapped graph; declared before `nodes` because the constructor
    // reads it while filling `nodes`
    const std::shared_ptr<torch::jit::Graph> graph;

    // proxies for the graph's nodes; pointers returned by find_* point into this vector
    std::vector<TorchNodeProxy> nodes;
};

// Holder type defined in the .cpp; it is what actually stores the at::Tensor.
class TorchTensorProxyPrivate;

// Wrapper that owns a handle to an at::Tensor behind an opaque pointer, so
// this header can be used with only a forward declaration of at::Tensor.
class TorchTensorProxy
{
public:
    TorchTensorProxy(const at::Tensor& _t);
    ~TorchTensorProxy();

    // non-copyable: `d` is a uniquely owned raw pointer
    TorchTensorProxy(const TorchTensorProxy&) = delete;
    TorchTensorProxy& operator=(const TorchTensorProxy&) = delete;

    // the held tensor
    const at::Tensor& t() const;

    // size of the held tensor along dimension i
    int size(size_t i) const;

private:
    TorchTensorProxyPrivate* const d;
};

// Wrapper around a torch::jit::Module that exposes the module's tensor
// attributes by name, collected once at construction.
class TorchModuleProxy
{
public:
    TorchModuleProxy(const torch::jit::Module& _mod);

    // whether a tensor attribute with this name was collected
    bool hasattr(const std::string& name) const;

    // the collected tensor attribute with this name
    const TorchTensorProxy& attr(const std::string& name) const;

    // Reference to the wrapped module, not a copy: the module passed to the
    // constructor has to outlive this proxy.
    const torch::jit::Module& mod;

private:
    // tensor attributes keyed by attribute name
    std::unordered_map<std::string, TorchTensorProxy> attrs;
};

// Base class of the level-1 module fusing passes.
//
// A concrete pass names the type it matches and the type it produces, and
// overrides one of the write() overloads to fill in the operator.
class FuseModulePass
{
public:
    virtual ~FuseModulePass();

    // type string this pass matches against
    virtual const char* match_type_str() const = 0;

    // type string of the operator this pass produces
    virtual const char* type_str() const = 0;

    // graph-only hook; the base implementation does nothing
    virtual void write(Operator* op, const TorchGraphProxy& graph) const;

    // graph+module hook; the base implementation forwards to the graph-only overload
    virtual void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const;
};

// Registers a pass in the global pass list for the lifetime of this object.
//
// The constructor adds `pass` to the list returned by
// get_global_pnnx_fuse_module_passes(); the destructor deletes `pass`.
// Instances are normally created through REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS.
class FuseModulePassRegister
{
public:
    FuseModulePassRegister(const FuseModulePass* pass);
    ~FuseModulePassRegister();

    // Non-copyable: the destructor deletes `pass`, so a copy would delete the
    // same object a second time.
    FuseModulePassRegister(const FuseModulePassRegister&) = delete;
    FuseModulePassRegister& operator=(const FuseModulePassRegister&) = delete;

    // the registered pass, owned by this object
    const FuseModulePass* pass;
};

// Returns the list of passes that have been registered through FuseModulePassRegister.
const std::vector<const FuseModulePass*>& get_global_pnnx_fuse_module_passes();

// Defines a file-local static FuseModulePassRegister constructed from `new CLASS`,
// which registers one instance of CLASS during static initialization.
// The expansion already ends with a semicolon.
#define REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(CLASS) \
static FuseModulePassRegister g_global_pnnx_fusemodulepass_##CLASS##_register(new CLASS);

} // namespace pnnx

#endif // PNNX_FUSE_MODULE_PASS_H

+ 3
- 5
tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.AdaptiveAvgPool1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* adaptive_avg_pool1d = find_node_by_kind(graph, "aten::adaptive_avg_pool1d");
const TorchNodeProxy* adaptive_avg_pool1d = graph.find_node_by_kind("aten::adaptive_avg_pool1d");

op->params["output_size"] = adaptive_avg_pool1d->namedInput("output_size");
}


+ 3
- 5
tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.AdaptiveAvgPool2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* adaptive_avg_pool2d = find_node_by_kind(graph, "aten::adaptive_avg_pool2d");
const TorchNodeProxy* adaptive_avg_pool2d = graph.find_node_by_kind("aten::adaptive_avg_pool2d");

op->params["output_size"] = adaptive_avg_pool2d->namedInput("output_size");
}


+ 3
- 5
tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool3d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.AdaptiveAvgPool3d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* adaptive_avg_pool3d = find_node_by_kind(graph, "aten::adaptive_avg_pool3d");
const TorchNodeProxy* adaptive_avg_pool3d = graph.find_node_by_kind("aten::adaptive_avg_pool3d");

op->params["output_size"] = adaptive_avg_pool3d->namedInput("output_size");
}


+ 6
- 6
tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,12 +29,14 @@ public:
return "nn.AdaptiveMaxPool1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* adaptive_max_pool1d = find_node_by_kind(graph, "aten::adaptive_max_pool1d");
const TorchNodeProxy* adaptive_max_pool1d = graph.find_node_by_kind("aten::adaptive_max_pool1d");

const TorchNodeProxy* graph_out = graph.find_producer_node_by_value(graph.output(0));

op->params["output_size"] = adaptive_max_pool1d->namedInput("output_size");
op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false;
op->params["return_indices"] = graph_out->kind() == "prim::TupleConstruct" ? true : false;
}
};



+ 6
- 6
tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,12 +29,14 @@ public:
return "nn.AdaptiveMaxPool2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* adaptive_max_pool2d = find_node_by_kind(graph, "aten::adaptive_max_pool2d");
const TorchNodeProxy* adaptive_max_pool2d = graph.find_node_by_kind("aten::adaptive_max_pool2d");

const TorchNodeProxy* graph_out = graph.find_producer_node_by_value(graph.output(0));

op->params["output_size"] = adaptive_max_pool2d->namedInput("output_size");
op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false;
op->params["return_indices"] = graph_out->kind() == "prim::TupleConstruct" ? true : false;
}
};



+ 6
- 6
tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool3d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,12 +29,14 @@ public:
return "nn.AdaptiveMaxPool3d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* adaptive_max_pool3d = find_node_by_kind(graph, "aten::adaptive_max_pool3d");
const TorchNodeProxy* adaptive_max_pool3d = graph.find_node_by_kind("aten::adaptive_max_pool3d");

const TorchNodeProxy* graph_out = graph.find_producer_node_by_value(graph.output(0));

op->params["output_size"] = adaptive_max_pool3d->namedInput("output_size");
op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false;
op->params["return_indices"] = graph_out->kind() == "prim::TupleConstruct" ? true : false;
}
};



+ 1
- 3
tools/pnnx/src/pass_level1/nn_AlphaDropout.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 3
- 5
tools/pnnx/src/pass_level1/nn_AvgPool1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.AvgPool1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* avg_pool1d = find_node_by_kind(graph, "aten::avg_pool1d");
const TorchNodeProxy* avg_pool1d = graph.find_node_by_kind("aten::avg_pool1d");

op->params["kernel_size"] = avg_pool1d->namedInput("kernel_size");
op->params["stride"] = avg_pool1d->namedInput("stride");


+ 3
- 5
tools/pnnx/src/pass_level1/nn_AvgPool2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.AvgPool2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* avg_pool2d = find_node_by_kind(graph, "aten::avg_pool2d");
const TorchNodeProxy* avg_pool2d = graph.find_node_by_kind("aten::avg_pool2d");

op->params["kernel_size"] = avg_pool2d->namedInput("kernel_size");
op->params["stride"] = avg_pool2d->namedInput("stride");


+ 3
- 5
tools/pnnx/src/pass_level1/nn_AvgPool3d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.AvgPool3d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* avg_pool3d = find_node_by_kind(graph, "aten::avg_pool3d");
const TorchNodeProxy* avg_pool3d = graph.find_node_by_kind("aten::avg_pool3d");

op->params["kernel_size"] = avg_pool3d->namedInput("kernel_size");
op->params["stride"] = avg_pool3d->namedInput("stride");


+ 7
- 9
tools/pnnx/src/pass_level1/nn_BatchNorm1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,12 +29,12 @@ public:
return "nn.BatchNorm1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm");
const TorchNodeProxy* bn = graph.find_node_by_kind("aten::batch_norm");

const auto& running_mean = mod.attr("running_mean").toTensor();
const auto& running_var = mod.attr("running_var").toTensor();
const TorchTensorProxy& running_mean = mod.attr("running_mean");
const TorchTensorProxy& running_var = mod.attr("running_var");

op->params["num_features"] = running_mean.size(0);
op->params["eps"] = bn->namedInput("eps");
@@ -46,8 +44,8 @@ public:
op->attrs["running_var"] = running_var;
if (mod.hasattr("weight") && mod.hasattr("bias"))
{
op->attrs["weight"] = mod.attr("weight").toTensor();
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["weight"] = mod.attr("weight");
op->attrs["bias"] = mod.attr("bias");
}
}
};


+ 7
- 9
tools/pnnx/src/pass_level1/nn_BatchNorm2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,12 +29,12 @@ public:
return "nn.BatchNorm2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm");
const TorchNodeProxy* bn = graph.find_node_by_kind("aten::batch_norm");

const auto& running_mean = mod.attr("running_mean").toTensor();
const auto& running_var = mod.attr("running_var").toTensor();
const TorchTensorProxy& running_mean = mod.attr("running_mean");
const TorchTensorProxy& running_var = mod.attr("running_var");

op->params["num_features"] = running_mean.size(0);
op->params["eps"] = bn->namedInput("eps");
@@ -46,8 +44,8 @@ public:
op->attrs["running_var"] = running_var;
if (mod.hasattr("weight") && mod.hasattr("bias"))
{
op->attrs["weight"] = mod.attr("weight").toTensor();
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["weight"] = mod.attr("weight");
op->attrs["bias"] = mod.attr("bias");
}
}
};


+ 7
- 9
tools/pnnx/src/pass_level1/nn_BatchNorm3d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,12 +29,12 @@ public:
return "nn.BatchNorm3d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm");
const TorchNodeProxy* bn = graph.find_node_by_kind("aten::batch_norm");

const auto& running_mean = mod.attr("running_mean").toTensor();
const auto& running_var = mod.attr("running_var").toTensor();
const TorchTensorProxy& running_mean = mod.attr("running_mean");
const TorchTensorProxy& running_var = mod.attr("running_var");

op->params["num_features"] = running_mean.size(0);
op->params["eps"] = bn->namedInput("eps");
@@ -46,8 +44,8 @@ public:
op->attrs["running_var"] = running_var;
if (mod.hasattr("weight") && mod.hasattr("bias"))
{
op->attrs["weight"] = mod.attr("weight").toTensor();
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["weight"] = mod.attr("weight");
op->attrs["bias"] = mod.attr("bias");
}
}
};


+ 3
- 5
tools/pnnx/src/pass_level1/nn_CELU.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.CELU";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* celu = find_node_by_kind(graph, "aten::celu");
const TorchNodeProxy* celu = graph.find_node_by_kind("aten::celu");

op->params["alpha"] = celu->namedInput("alpha");
}


+ 3
- 5
tools/pnnx/src/pass_level1/nn_ChannelShuffle.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.ChannelShuffle";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* channel_shuffle = find_node_by_kind(graph, "aten::channel_shuffle");
const TorchNodeProxy* channel_shuffle = graph.find_node_by_kind("aten::channel_shuffle");

op->params["groups"] = channel_shuffle->namedInput("groups");
}


+ 4
- 6
tools/pnnx/src/pass_level1/nn_ConstantPad1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.ConstantPad1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd");

if (!pad)
{


+ 4
- 6
tools/pnnx/src/pass_level1/nn_ConstantPad2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.ConstantPad2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd");

if (!pad)
{


+ 4
- 6
tools/pnnx/src/pass_level1/nn_ConstantPad3d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.ConstantPad3d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd");

if (!pad)
{


+ 9
- 13
tools/pnnx/src/pass_level1/nn_Conv1d.cpp View File

@@ -12,11 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

// #include "../pass_level3/fuse_expression.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -33,7 +29,7 @@ public:
return "nn.Conv1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// {
// pnnx::Graph pnnx_graph;
@@ -45,18 +41,18 @@ public:
// pnnx_graph.save("tmp.param", "tmp.bin");
// }

const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution");
const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode");
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* reflection_pad1d = find_node_by_kind(graph, "aten::reflection_pad1d");
const torch::jit::Node* replication_pad1d = find_node_by_kind(graph, "aten::replication_pad1d");
const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution");
const TorchNodeProxy* convolution_mode = graph.find_node_by_kind("aten::_convolution_mode");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* reflection_pad1d = graph.find_node_by_kind("aten::reflection_pad1d");
const TorchNodeProxy* replication_pad1d = graph.find_node_by_kind("aten::replication_pad1d");

if (convolution_mode)
{
convolution = convolution_mode;
}

const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["groups"] = convolution->namedInput("groups");
op->params["in_channels"] = weight.size(1) * op->params["groups"].i;
@@ -131,7 +127,7 @@ public:
op->attrs["weight"] = weight;
if (mod.hasattr("bias"))
{
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}
}
};


+ 11
- 15
tools/pnnx/src/pass_level1/nn_Conv2d.cpp View File

@@ -12,11 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

// #include "../pass_level3/fuse_expression.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -33,7 +29,7 @@ public:
return "nn.Conv2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// {
// pnnx::Graph pnnx_graph;
@@ -45,18 +41,18 @@ public:
// pnnx_graph.save("tmp.param", "tmp.bin");
// }

const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution");
const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode");
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* reflection_pad2d = find_node_by_kind(graph, "aten::reflection_pad2d");
const torch::jit::Node* replication_pad2d = find_node_by_kind(graph, "aten::replication_pad2d");
const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution");
const TorchNodeProxy* convolution_mode = graph.find_node_by_kind("aten::_convolution_mode");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* reflection_pad2d = graph.find_node_by_kind("aten::reflection_pad2d");
const TorchNodeProxy* replication_pad2d = graph.find_node_by_kind("aten::replication_pad2d");

if (convolution_mode)
{
convolution = convolution_mode;
}

const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["groups"] = convolution->namedInput("groups");
op->params["in_channels"] = weight.size(1) * op->params["groups"].i;
@@ -126,12 +122,12 @@ public:
op->params["padding"] = convolution->namedInput("padding");
}
op->params["dilation"] = convolution->namedInput("dilation");
op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor();
op->params["bias"] = mod.hasattr("bias");

op->attrs["weight"] = weight;
if (mod.hasattr("bias") && mod.attr("bias").isTensor())
if (mod.hasattr("bias"))
{
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}
}
};


+ 9
- 13
tools/pnnx/src/pass_level1/nn_Conv3d.cpp View File

@@ -12,11 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

// #include "../pass_level3/fuse_expression.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -33,7 +29,7 @@ public:
return "nn.Conv3d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// {
// pnnx::Graph pnnx_graph;
@@ -45,18 +41,18 @@ public:
// pnnx_graph.save("tmp.param", "tmp.bin");
// }

const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution");
const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode");
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* reflection_pad3d = find_node_by_kind(graph, "aten::reflection_pad3d");
const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d");
const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution");
const TorchNodeProxy* convolution_mode = graph.find_node_by_kind("aten::_convolution_mode");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* reflection_pad3d = graph.find_node_by_kind("aten::reflection_pad3d");
const TorchNodeProxy* replication_pad3d = graph.find_node_by_kind("aten::replication_pad3d");

if (convolution_mode)
{
convolution = convolution_mode;
}

const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["groups"] = convolution->namedInput("groups");
op->params["in_channels"] = weight.size(1) * op->params["groups"].i;
@@ -131,7 +127,7 @@ public:
op->attrs["weight"] = weight;
if (mod.hasattr("bias"))
{
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}
}
};


+ 5
- 7
tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,11 +29,11 @@ public:
return "nn.ConvTranspose1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution");
const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution");

const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["groups"] = convolution->namedInput("groups");
op->params["in_channels"] = weight.size(0);
@@ -50,7 +48,7 @@ public:
op->attrs["weight"] = weight;
if (mod.hasattr("bias"))
{
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}

if (op->inputs.size() > 1)


+ 5
- 7
tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,11 +29,11 @@ public:
return "nn.ConvTranspose2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution");
const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution");

const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["groups"] = convolution->namedInput("groups");
op->params["in_channels"] = weight.size(0);
@@ -50,7 +48,7 @@ public:
op->attrs["weight"] = weight;
if (mod.hasattr("bias"))
{
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}

if (op->inputs.size() > 1)


+ 5
- 7
tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,11 +29,11 @@ public:
return "nn.ConvTranspose3d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution");
const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution");

const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["groups"] = convolution->namedInput("groups");
op->params["in_channels"] = weight.size(0);
@@ -50,7 +48,7 @@ public:
op->attrs["weight"] = weight;
if (mod.hasattr("bias"))
{
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}

if (op->inputs.size() > 1)


+ 1
- 3
tools/pnnx/src/pass_level1/nn_Dropout.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 1
- 3
tools/pnnx/src/pass_level1/nn_Dropout2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 1
- 3
tools/pnnx/src/pass_level1/nn_Dropout3d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 3
- 5
tools/pnnx/src/pass_level1/nn_ELU.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.ELU";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* elu = find_node_by_kind(graph, "aten::elu");
const TorchNodeProxy* elu = graph.find_node_by_kind("aten::elu");

op->params["alpha"] = elu->namedInput("alpha");
}


+ 4
- 6
tools/pnnx/src/pass_level1/nn_Embedding.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,11 +29,11 @@ public:
return "nn.Embedding";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
const torch::jit::Node* embedding = find_node_by_kind(graph, "aten::embedding");
const TorchNodeProxy* embedding = graph.find_node_by_kind("aten::embedding");

const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["num_embeddings"] = weight.size(0);
op->params["embedding_dim"] = weight.size(1);


+ 3
- 5
tools/pnnx/src/pass_level1/nn_Fold.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.Fold";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* col2im = find_node_by_kind(graph, "aten::col2im");
const TorchNodeProxy* col2im = graph.find_node_by_kind("aten::col2im");

op->params["output_size"] = col2im->namedInput("output_size");
op->params["kernel_size"] = col2im->namedInput("kernel_size");


+ 3
- 5
tools/pnnx/src/pass_level1/nn_GELU.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.GELU";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* gelu = find_node_by_kind(graph, "aten::gelu");
const TorchNodeProxy* gelu = graph.find_node_by_kind("aten::gelu");

if (gelu->hasNamedInput("approximate"))
{


+ 3
- 5
tools/pnnx/src/pass_level1/nn_GLU.cpp View File

@@ -13,9 +13,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -32,9 +30,9 @@ public:
return "nn.GLU";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* glu = find_node_by_kind(graph, "aten::glu");
const TorchNodeProxy* glu = graph.find_node_by_kind("aten::glu");

op->params["dim"] = glu->namedInput("dim");
}


+ 15
- 17
tools/pnnx/src/pass_level1/nn_GRU.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,17 +29,17 @@ public:
return "nn.GRU";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// mod.dump(true, true, true);

// graph->dump();

const torch::jit::Node* gru = find_node_by_kind(graph, "aten::gru");
const TorchNodeProxy* gru = graph.find_node_by_kind("aten::gru");

const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct");
if (return_tuple && return_tuple->inputs().size() == 2 && gru->outputs().size() == 2
&& return_tuple->inputs()[0] == gru->outputs()[1] && return_tuple->inputs()[1] == gru->outputs()[0])
const TorchNodeProxy* return_tuple = graph.find_node_by_kind("prim::TupleConstruct");
if (return_tuple && return_tuple->input_count() == 2 && gru->output_count() == 2
&& return_tuple->input(0) == gru->output(1) && return_tuple->input(1) == gru->output(0))
{
// mark the swapped output tuple
// we would restore the fine order in pass_level3/fuse_rnn_unpack
@@ -54,7 +52,7 @@ public:
// fprintf(stderr, "arg %s\n", aa.name().c_str());
// }

const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor();
const TorchTensorProxy& weight_ih_l0 = mod.attr("weight_ih_l0");

op->params["input_size"] = weight_ih_l0.size(1);
op->params["hidden_size"] = weight_ih_l0.size(0) / 3;
@@ -72,16 +70,16 @@ public:
std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k);
std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k);

op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor();
op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor();
op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key);
op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key);

if (bias)
{
std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k);
std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k);

op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor();
op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor();
op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key);
op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key);
}

if (bidirectional)
@@ -89,16 +87,16 @@ public:
std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse";
std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse";

op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor();
op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor();
op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key);
op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key);

if (bias)
{
std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse";
std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse";

op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor();
op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor();
op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key);
op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key);
}
}
}


+ 5
- 7
tools/pnnx/src/pass_level1/nn_GroupNorm.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,11 +29,11 @@ public:
return "nn.GroupNorm";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// graph->dump();

const torch::jit::Node* gn = find_node_by_kind(graph, "aten::group_norm");
const TorchNodeProxy* gn = graph.find_node_by_kind("aten::group_norm");

// for (auto aa : gn->schema().arguments())
// {
@@ -48,12 +46,12 @@ public:

if (mod.hasattr("weight") && mod.hasattr("bias"))
{
const auto& weight = mod.attr("weight").toTensor();
const auto& weight = mod.attr("weight");

op->params["num_channels"] = weight.size(0);

op->attrs["weight"] = weight;
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}
else
{


+ 3
- 5
tools/pnnx/src/pass_level1/nn_Hardshrink.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.Hardshrink";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* hardshrink = find_node_by_kind(graph, "aten::hardshrink");
const TorchNodeProxy* hardshrink = graph.find_node_by_kind("aten::hardshrink");

op->params["lambd"] = hardshrink->namedInput("lambd");
}


+ 1
- 1
tools/pnnx/src/pass_level1/nn_Hardsigmoid.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 1
- 1
tools/pnnx/src/pass_level1/nn_Hardswish.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 3
- 5
tools/pnnx/src/pass_level1/nn_Hardtanh.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.Hardtanh";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* hardtanh = find_node_by_kind(graph, "aten::hardtanh");
const TorchNodeProxy* hardtanh = graph.find_node_by_kind("aten::hardtanh");

op->params["min_val"] = hardtanh->namedInput("min_val");
op->params["max_val"] = hardtanh->namedInput("max_val");


+ 7
- 9
tools/pnnx/src/pass_level1/nn_InstanceNorm1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,11 +29,11 @@ public:
return "nn.InstanceNorm1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// graph->dump();

const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm");
const TorchNodeProxy* in = graph.find_node_by_kind("aten::instance_norm");

// for (auto aa : in->schema().arguments())
// {
@@ -48,22 +46,22 @@ public:

if (mod.hasattr("weight") && mod.hasattr("bias"))
{
const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["num_features"] = weight.size(0);

op->attrs["weight"] = weight;
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}

if (mod.hasattr("running_mean") && mod.hasattr("running_var"))
{
const auto& running_mean = mod.attr("running_mean").toTensor();
const TorchTensorProxy& running_mean = mod.attr("running_mean");

op->params["num_features"] = running_mean.size(0);

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = mod.attr("running_var").toTensor();
op->attrs["running_var"] = mod.attr("running_var");
}

// take num_features from input shape


+ 7
- 9
tools/pnnx/src/pass_level1/nn_InstanceNorm2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,11 +29,11 @@ public:
return "nn.InstanceNorm2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// graph->dump();

const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm");
const TorchNodeProxy* in = graph.find_node_by_kind("aten::instance_norm");

// for (auto aa : in->schema().arguments())
// {
@@ -48,22 +46,22 @@ public:

if (mod.hasattr("weight") && mod.hasattr("bias"))
{
const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["num_features"] = weight.size(0);

op->attrs["weight"] = weight;
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}

if (mod.hasattr("running_mean") && mod.hasattr("running_var"))
{
const auto& running_mean = mod.attr("running_mean").toTensor();
const TorchTensorProxy& running_mean = mod.attr("running_mean");

op->params["num_features"] = running_mean.size(0);

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = mod.attr("running_var").toTensor();
op->attrs["running_var"] = mod.attr("running_var");
}

// take num_features from input shape


+ 7
- 9
tools/pnnx/src/pass_level1/nn_InstanceNorm3d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,11 +29,11 @@ public:
return "nn.InstanceNorm3d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// graph->dump();

const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm");
const TorchNodeProxy* in = graph.find_node_by_kind("aten::instance_norm");

// for (auto aa : in->schema().arguments())
// {
@@ -48,22 +46,22 @@ public:

if (mod.hasattr("weight") && mod.hasattr("bias"))
{
const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["num_features"] = weight.size(0);

op->attrs["weight"] = weight;
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}

if (mod.hasattr("running_mean") && mod.hasattr("running_var"))
{
const auto& running_mean = mod.attr("running_mean").toTensor();
const TorchTensorProxy& running_mean = mod.attr("running_mean");

op->params["num_features"] = running_mean.size(0);

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = mod.attr("running_var").toTensor();
op->attrs["running_var"] = mod.attr("running_var");
}

// take num_features from input shape


+ 11
- 10
tools/pnnx/src/pass_level1/nn_LPPool1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,21 +29,24 @@ public:
return "nn.LPPool1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow");
op->params["norm_type"] = pow->inputs()[1];
const TorchNodeProxy* pow = graph.find_node_by_kind("aten::pow");
op->params["norm_type"] = pow->input(1);

const TorchNodeProxy* avg_pool1d = graph.find_node_by_kind("aten::avg_pool1d");

const torch::jit::Node* avg_pool1d = find_node_by_kind(graph, "aten::avg_pool1d");
const TorchNodeProxy* kernel_size = graph.find_producer_node_by_value(avg_pool1d->namedInput("kernel_size"));
const TorchNodeProxy* stride = graph.find_producer_node_by_value(avg_pool1d->namedInput("stride"));

op->params["kernel_size"] = avg_pool1d->namedInput("kernel_size")->node()->inputs()[0];
if (avg_pool1d->namedInput("stride")->node()->inputs().size() == 0)
op->params["kernel_size"] = kernel_size->input(0);
if (stride->input_count() == 0)
{
op->params["stride"] = op->params["kernel_size"];
}
else
{
op->params["stride"] = avg_pool1d->namedInput("stride")->node()->inputs()[0];
op->params["stride"] = stride->input(0);
}
op->params["ceil_mode"] = avg_pool1d->namedInput("ceil_mode");
}


+ 8
- 8
tools/pnnx/src/pass_level1/nn_LPPool2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,15 +29,17 @@ public:
return "nn.LPPool2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow");
op->params["norm_type"] = pow->inputs()[1];
const TorchNodeProxy* pow = graph.find_node_by_kind("aten::pow");
op->params["norm_type"] = pow->input(1);

const TorchNodeProxy* avg_pool2d = graph.find_node_by_kind("aten::avg_pool2d");

const torch::jit::Node* avg_pool2d = find_node_by_kind(graph, "aten::avg_pool2d");
const TorchNodeProxy* stride = graph.find_producer_node_by_value(avg_pool2d->namedInput("stride"));

op->params["kernel_size"] = avg_pool2d->namedInput("kernel_size");
if (avg_pool2d->namedInput("stride")->node()->inputs().size() == 0)
if (stride->input_count() == 0)
{
op->params["stride"] = op->params["kernel_size"];
}


+ 18
- 20
tools/pnnx/src/pass_level1/nn_LSTM.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,17 +29,17 @@ public:
return "nn.LSTM";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// mod.dump(true, true, true);
//
// graph->dump();

const torch::jit::Node* lstm = find_node_by_kind(graph, "aten::lstm");
const TorchNodeProxy* lstm = graph.find_node_by_kind("aten::lstm");

const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct");
if (return_tuple && return_tuple->inputs().size() == 3 && lstm->outputs().size() == 3
&& return_tuple->inputs()[0] == lstm->outputs()[1] && return_tuple->inputs()[1] == lstm->outputs()[2] && return_tuple->inputs()[2] == lstm->outputs()[0])
const TorchNodeProxy* return_tuple = graph.find_node_by_kind("prim::TupleConstruct");
if (return_tuple && return_tuple->input_count() == 3 && lstm->output_count() == 3
&& return_tuple->input(0) == lstm->output(1) && return_tuple->input(1) == lstm->output(2) && return_tuple->input(2) == lstm->output(0))
{
// mark the swapped output tuple
// the correct order is restored later in pass_level3/fuse_rnn_unpack
@@ -54,8 +52,8 @@ public:
// fprintf(stderr, "arg %s\n", aa.name().c_str());
// }

const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor();
const auto& weight_hh_l0 = mod.attr("weight_hh_l0").toTensor();
const TorchTensorProxy& weight_ih_l0 = mod.attr("weight_ih_l0");
const TorchTensorProxy& weight_hh_l0 = mod.attr("weight_hh_l0");

op->params["input_size"] = weight_ih_l0.size(1);
op->params["hidden_size"] = weight_ih_l0.size(0) / 4;
@@ -75,23 +73,23 @@ public:
std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k);
std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k);

op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor();
op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor();
op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key);
op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key);

if (bias)
{
std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k);
std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k);

op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor();
op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor();
op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key);
op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key);
}

if (proj_size > 0)
{
std::string weight_hr_lk_key = std::string("weight_hr_l") + std::to_string(k);

op->attrs[weight_hr_lk_key] = mod.attr(weight_hr_lk_key).toTensor();
op->attrs[weight_hr_lk_key] = mod.attr(weight_hr_lk_key);
}

if (bidirectional)
@@ -99,23 +97,23 @@ public:
std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse";
std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse";

op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor();
op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor();
op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key);
op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key);

if (bias)
{
std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse";
std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse";

op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor();
op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor();
op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key);
op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key);
}

if (proj_size > 0)
{
std::string weight_hr_lk_reverse_key = std::string("weight_hr_l") + std::to_string(k) + "_reverse";

op->attrs[weight_hr_lk_reverse_key] = mod.attr(weight_hr_lk_reverse_key).toTensor();
op->attrs[weight_hr_lk_reverse_key] = mod.attr(weight_hr_lk_reverse_key);
}
}
}


+ 5
- 7
tools/pnnx/src/pass_level1/nn_LayerNorm.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.LayerNorm";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
const torch::jit::Node* ln = find_node_by_kind(graph, "aten::layer_norm");
const TorchNodeProxy* ln = graph.find_node_by_kind("aten::layer_norm");

op->params["normalized_shape"] = ln->namedInput("normalized_shape");
op->params["eps"] = ln->namedInput("eps");
@@ -41,8 +39,8 @@ public:

if (mod.hasattr("weight") && mod.hasattr("bias"))
{
op->attrs["weight"] = mod.attr("weight").toTensor();
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["weight"] = mod.attr("weight");
op->attrs["bias"] = mod.attr("bias");
}
}
};


+ 4
- 6
tools/pnnx/src/pass_level1/nn_LeakyReLU.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.LeakyReLU";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* leaky_relu = find_node_by_kind(graph, "aten::leaky_relu");
const torch::jit::Node* leaky_relu_ = find_node_by_kind(graph, "aten::leaky_relu_");
const TorchNodeProxy* leaky_relu = graph.find_node_by_kind("aten::leaky_relu");
const TorchNodeProxy* leaky_relu_ = graph.find_node_by_kind("aten::leaky_relu_");

if (leaky_relu_)
{


+ 7
- 9
tools/pnnx/src/pass_level1/nn_Linear.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,20 +29,20 @@ public:
return "nn.Linear";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& /*graph*/, const TorchModuleProxy& mod) const
{
const torch::jit::Node* addmm = find_node_by_kind(graph, "aten::addmm");
// const TorchNodeProxy* addmm = graph.find_node_by_kind("aten::addmm");

const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["in_features"] = weight.size(1);
op->params["out_features"] = weight.size(0);
op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor();
op->params["bias"] = mod.hasattr("bias");

op->attrs["weight"] = weight;
if (mod.hasattr("bias") && mod.attr("bias").isTensor())
if (mod.hasattr("bias"))
{
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}
}
};


+ 12
- 13
tools/pnnx/src/pass_level1/nn_LocalResponseNorm.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,26 +29,27 @@ public:
return "nn.LocalResponseNorm";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* avg_pool = find_node_by_kind(graph, "aten::avg_pool2d");
const torch::jit::Node* avg_pool3d = find_node_by_kind(graph, "aten::avg_pool3d");
const TorchNodeProxy* avg_pool = graph.find_node_by_kind("aten::avg_pool2d");
const TorchNodeProxy* avg_pool3d = graph.find_node_by_kind("aten::avg_pool3d");

if (avg_pool3d)
{
avg_pool = avg_pool3d;
}

op->params["size"] = avg_pool->namedInput("kernel_size")->node()->inputs()[0];
const TorchNodeProxy* kernel_size = graph.find_producer_node_by_value(avg_pool->namedInput("kernel_size"));
op->params["size"] = kernel_size->input(0);

const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow");
op->params["beta"] = pow->inputs()[1];
const TorchNodeProxy* pow = graph.find_node_by_kind("aten::pow");
op->params["beta"] = pow->input(1);

const torch::jit::Node* add = pow->inputs()[0]->node();
op->params["k"] = add->inputs()[1];
const TorchNodeProxy* add = graph.find_producer_node_by_value(pow->input(0));
op->params["k"] = add->input(1);

const torch::jit::Node* mul = add->inputs()[0]->node();
op->params["alpha"] = mul->inputs()[1];
const TorchNodeProxy* mul = graph.find_producer_node_by_value(add->input(0));
op->params["alpha"] = mul->input(1);
}
};



+ 1
- 3
tools/pnnx/src/pass_level1/nn_LogSigmoid.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 3
- 5
tools/pnnx/src/pass_level1/nn_LogSoftmax.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.LogSoftmax";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* log_softmax = find_node_by_kind(graph, "aten::log_softmax");
const TorchNodeProxy* log_softmax = graph.find_node_by_kind("aten::log_softmax");

op->params["dim"] = log_softmax->namedInput("dim");
}


+ 4
- 6
tools/pnnx/src/pass_level1/nn_MaxPool1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.MaxPool1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* max_pool1d = find_node_by_kind(graph, "aten::max_pool1d");
const torch::jit::Node* max_pool1d_with_indices = find_node_by_kind(graph, "aten::max_pool1d_with_indices");
const TorchNodeProxy* max_pool1d = graph.find_node_by_kind("aten::max_pool1d");
const TorchNodeProxy* max_pool1d_with_indices = graph.find_node_by_kind("aten::max_pool1d_with_indices");

if (max_pool1d_with_indices)
{


+ 4
- 6
tools/pnnx/src/pass_level1/nn_MaxPool2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.MaxPool2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* max_pool2d = find_node_by_kind(graph, "aten::max_pool2d");
const torch::jit::Node* max_pool2d_with_indices = find_node_by_kind(graph, "aten::max_pool2d_with_indices");
const TorchNodeProxy* max_pool2d = graph.find_node_by_kind("aten::max_pool2d");
const TorchNodeProxy* max_pool2d_with_indices = graph.find_node_by_kind("aten::max_pool2d_with_indices");

if (max_pool2d_with_indices)
{


+ 4
- 6
tools/pnnx/src/pass_level1/nn_MaxPool3d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.MaxPool3d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* max_pool3d = find_node_by_kind(graph, "aten::max_pool3d");
const torch::jit::Node* max_pool3d_with_indices = find_node_by_kind(graph, "aten::max_pool3d_with_indices");
const TorchNodeProxy* max_pool3d = graph.find_node_by_kind("aten::max_pool3d");
const TorchNodeProxy* max_pool3d_with_indices = graph.find_node_by_kind("aten::max_pool3d_with_indices");

if (max_pool3d_with_indices)
{


+ 1
- 1
tools/pnnx/src/pass_level1/nn_Mish.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 40
- 35
tools/pnnx/src/pass_level1/nn_MultiheadAttention.cpp View File

@@ -12,11 +12,13 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include <torch/csrc/api/include/torch/torch.h>
// #include "pass_level1.h"
//
// #include <torch/csrc/api/include/torch/torch.h>
//
// #include "../utils.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -33,19 +35,19 @@ public:
return "nn.MultiheadAttention";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// mod.dump(false, false, false);
// graph->dump();

const torch::jit::Node* multi_head_attention = find_node_by_kind(graph, "aten::_native_multi_head_attention");
const TorchNodeProxy* multi_head_attention = graph.find_node_by_kind("aten::_native_multi_head_attention");
if (multi_head_attention)
{
op->params["num_heads"] = multi_head_attention->namedInput("num_head");
op->params["batch_first"] = true;
op->params["add_zero_attn"] = false;

if (multi_head_attention->hasNamedInput("mask") && multi_head_attention->namedInput("mask") == graph->inputs()[graph->inputs().size() - 1])
if (multi_head_attention->hasNamedInput("mask") && multi_head_attention->namedInput("mask") == graph.input(graph.input_count() - 1))
{
size_t input_count = op->inputs.size();
op->inputnames.resize(input_count);
@@ -54,25 +56,28 @@ public:
}
else
{
const torch::jit::Node* div_num_heads = find_node_by_kind(graph, "aten::div");
const torch::jit::Node* div_num_heads_18 = find_node_by_kind(graph, "aten::floor_divide");
const TorchNodeProxy* div_num_heads = graph.find_node_by_kind("aten::div");
const TorchNodeProxy* div_num_heads_18 = graph.find_node_by_kind("aten::floor_divide");
if (div_num_heads_18)
{
div_num_heads = div_num_heads_18;
}

op->params["num_heads"] = (int)div_num_heads->input(1)->node()->t(torch::jit::attr::value).item<int64_t>();
// const TorchNodeProxy* div_num_heads_input_1 = graph.find_producer_node_by_value(div_num_heads->input(1));

// op->params["num_heads"] = (int)div_num_heads_input_1->t(torch::jit::attr::value).item<int64_t>();
op->params["num_heads"] = div_num_heads->input(1);

const torch::jit::Node* transpose_batch_seq = find_node_by_kind(graph, "aten::transpose");
const TorchNodeProxy* transpose_batch_seq = graph.find_node_by_kind("aten::transpose");

int transpose_dim0 = transpose_batch_seq->input(1)->node()->i(torch::jit::attr::value);
int transpose_dim1 = transpose_batch_seq->input(2)->node()->i(torch::jit::attr::value);
if (transpose_dim0 == 1 && transpose_dim1 == 0)
Parameter transpose_dim0 = transpose_batch_seq->input(1);
Parameter transpose_dim1 = transpose_batch_seq->input(2);
if (transpose_dim0.i == 1 && transpose_dim1.i == 0)
{
op->params["batch_first"] = true;
}

const torch::jit::Node* add_zero_attn = find_node_by_kind(graph, "aten::zeros");
const TorchNodeProxy* add_zero_attn = graph.find_node_by_kind("aten::zeros");
if (add_zero_attn)
{
op->params["add_zero_attn"] = true;
@@ -82,10 +87,10 @@ public:
op->params["add_zero_attn"] = false;
}

const torch::jit::Node* scaled_dot_product_attention = find_node_by_kind(graph, "aten::scaled_dot_product_attention");
const TorchNodeProxy* scaled_dot_product_attention = graph.find_node_by_kind("aten::scaled_dot_product_attention");
if (scaled_dot_product_attention)
{
if (scaled_dot_product_attention->input(3)->type()->kind() != c10::TypeKind::NoneType)
if (!scaled_dot_product_attention->is_input_none(3))
{
size_t input_count = op->inputs.size();
op->inputnames.resize(input_count);
@@ -94,7 +99,7 @@ public:
}

// find the attention mask addition pattern used before torch-2.1
const torch::jit::Node* has_attn_mask = find_node_by_kind(graph, "aten::baddbmm");
const TorchNodeProxy* has_attn_mask = graph.find_node_by_kind("aten::baddbmm");
if (has_attn_mask)
{
size_t input_count = op->inputs.size();
@@ -106,14 +111,14 @@ public:
// attn = torch.bmm(Q, K)
// input0 = torch.add_(attn, attn_mask)
// attn0 = torch.softmax(input0, -1)
const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax");
const TorchNodeProxy* softmax = graph.find_node_by_kind("aten::softmax");
if (softmax)
{
const torch::jit::Node* add_ = softmax->input(0)->node();
if (add_ && add_->kind().toDisplayString() == std::string("aten::add_"))
const TorchNodeProxy* add_ = graph.find_producer_node_by_value(softmax->input(0));
if (add_ && add_->kind() == "aten::add_")
{
const torch::jit::Node* bmm = add_->input(0)->node();
if (bmm && bmm->kind().toDisplayString() == std::string("aten::bmm"))
const TorchNodeProxy* bmm = graph.find_producer_node_by_value(add_->input(0));
if (bmm && bmm->kind() == "aten::bmm")
{
size_t input_count = op->inputs.size();
op->inputnames.resize(input_count);
@@ -125,7 +130,7 @@ public:

if (mod.hasattr("in_proj_weight"))
{
const auto& in_proj_weight = mod.attr("in_proj_weight").toTensor();
const TorchTensorProxy& in_proj_weight = mod.attr("in_proj_weight");

op->params["embed_dim"] = in_proj_weight.size(1);
op->params["kdim"] = in_proj_weight.size(1);
@@ -134,9 +139,9 @@ public:
}
else
{
const auto& q_proj_weight = mod.attr("q_proj_weight").toTensor();
const auto& k_proj_weight = mod.attr("k_proj_weight").toTensor();
const auto& v_proj_weight = mod.attr("v_proj_weight").toTensor();
const TorchTensorProxy& q_proj_weight = mod.attr("q_proj_weight");
const TorchTensorProxy& k_proj_weight = mod.attr("k_proj_weight");
const TorchTensorProxy& v_proj_weight = mod.attr("v_proj_weight");

op->params["embed_dim"] = q_proj_weight.size(1);
op->params["kdim"] = k_proj_weight.size(1);
@@ -146,15 +151,15 @@ public:
op->attrs["v_proj_weight"] = v_proj_weight;
}

const auto& out_proj_weight = mod.attr("out_proj").toModule().attr("weight").toTensor();
const TorchTensorProxy& out_proj_weight = mod.attr("out_proj.weight");

op->attrs["out_proj.weight"] = out_proj_weight;

if (mod.hasattr("in_proj_bias") && mod.attr("out_proj").toModule().hasattr("bias"))
if (mod.hasattr("in_proj_bias") && mod.hasattr("out_proj.bias"))
{
// bias=True
const auto& in_proj_bias = mod.attr("in_proj_bias").toTensor();
const auto& out_proj_bias = mod.attr("out_proj").toModule().attr("bias").toTensor();
const TorchTensorProxy& in_proj_bias = mod.attr("in_proj_bias");
const TorchTensorProxy& out_proj_bias = mod.attr("out_proj.bias");

op->params["bias"] = true;
op->attrs["in_proj_bias"] = in_proj_bias;
@@ -166,9 +171,9 @@ public:

// in pytorch 1.8 the output projection bias is always present, even when bias is False
// this behavior changed in https://github.com/pytorch/pytorch/commit/58d1b3639bc07f9519de18e5a18e575f260c7eeb
if (mod.attr("out_proj").toModule().hasattr("bias"))
if (mod.hasattr("out_proj.bias"))
{
const auto& out_proj_bias = mod.attr("out_proj").toModule().attr("bias").toTensor();
const TorchTensorProxy& out_proj_bias = mod.attr("out_proj.bias");
op->attrs["out_proj.bias"] = out_proj_bias;
}
}
@@ -176,8 +181,8 @@ public:
if (mod.hasattr("bias_k") && mod.hasattr("bias_v"))
{
// add_bias_kv=True
const auto& bias_k = mod.attr("bias_k").toTensor();
const auto& bias_v = mod.attr("bias_v").toTensor();
const TorchTensorProxy& bias_k = mod.attr("bias_k");
const TorchTensorProxy& bias_v = mod.attr("bias_v");

op->params["add_bias_kv"] = true;
op->attrs["bias_k"] = bias_k;


+ 3
- 5
tools/pnnx/src/pass_level1/nn_PReLU.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.PReLU";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& /*graph*/, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& /*graph*/, const TorchModuleProxy& mod) const
{
const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

op->params["num_parameters"] = weight.size(0);



+ 3
- 5
tools/pnnx/src/pass_level1/nn_PixelShuffle.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.PixelShuffle";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pixel_shuffle = find_node_by_kind(graph, "aten::pixel_shuffle");
const TorchNodeProxy* pixel_shuffle = graph.find_node_by_kind("aten::pixel_shuffle");

op->params["upscale_factor"] = pixel_shuffle->namedInput("upscale_factor");
}


+ 3
- 5
tools/pnnx/src/pass_level1/nn_PixelUnshuffle.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.PixelUnshuffle";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pixel_unshuffle = find_node_by_kind(graph, "aten::pixel_unshuffle");
const TorchNodeProxy* pixel_unshuffle = graph.find_node_by_kind("aten::pixel_unshuffle");

op->params["downscale_factor"] = pixel_unshuffle->namedInput("downscale_factor");
}


+ 4
- 6
tools/pnnx/src/pass_level1/nn_RMSNorm.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.RMSNorm";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
const torch::jit::Node* rmsn = find_node_by_kind(graph, "aten::rms_norm");
const TorchNodeProxy* rmsn = graph.find_node_by_kind("aten::rms_norm");

op->params["normalized_shape"] = rmsn->namedInput("normalized_shape");
op->params["eps"] = rmsn->namedInput("eps");
@@ -41,7 +39,7 @@ public:

if (mod.hasattr("weight"))
{
op->attrs["weight"] = mod.attr("weight").toTensor();
op->attrs["weight"] = mod.attr("weight");
}
}
};


+ 16
- 18
tools/pnnx/src/pass_level1/nn_RNN.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,23 +29,23 @@ public:
return "nn.RNN";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
// mod.dump(true, true, true);

// graph->dump();

const torch::jit::Node* rnn = find_node_by_kind(graph, "aten::rnn_tanh");
const torch::jit::Node* rnn_relu = find_node_by_kind(graph, "aten::rnn_relu");
const TorchNodeProxy* rnn = graph.find_node_by_kind("aten::rnn_tanh");
const TorchNodeProxy* rnn_relu = graph.find_node_by_kind("aten::rnn_relu");

if (rnn_relu)
{
rnn = rnn_relu;
}

const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct");
if (return_tuple && return_tuple->inputs().size() == 2 && rnn->outputs().size() == 2
&& return_tuple->inputs()[0] == rnn->outputs()[1] && return_tuple->inputs()[1] == rnn->outputs()[0])
const TorchNodeProxy* return_tuple = graph.find_node_by_kind("prim::TupleConstruct");
if (return_tuple && return_tuple->input_count() == 2 && rnn->output_count() == 2
&& return_tuple->input(0) == rnn->output(1) && return_tuple->input(1) == rnn->output(0))
{
// mark the swapped output tuple
// the correct order is restored later in pass_level3/fuse_rnn_unpack
@@ -60,7 +58,7 @@ public:
// fprintf(stderr, "arg %s\n", aa.name().c_str());
// }

const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor();
const TorchTensorProxy& weight_ih_l0 = mod.attr("weight_ih_l0");

op->params["input_size"] = weight_ih_l0.size(1);
op->params["hidden_size"] = weight_ih_l0.size(0);
@@ -79,16 +77,16 @@ public:
std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k);
std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k);

op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor();
op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor();
op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key);
op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key);

if (bias)
{
std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k);
std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k);

op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor();
op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor();
op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key);
op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key);
}

if (bidirectional)
@@ -96,16 +94,16 @@ public:
std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse";
std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse";

op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor();
op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor();
op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key);
op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key);

if (bias)
{
std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse";
std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse";

op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor();
op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor();
op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key);
op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key);
}
}
}


+ 3
- 5
tools/pnnx/src/pass_level1/nn_RReLU.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.RReLU";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* rrelu = find_node_by_kind(graph, "aten::rrelu");
const TorchNodeProxy* rrelu = graph.find_node_by_kind("aten::rrelu");

op->params["lower"] = rrelu->namedInput("lower");
op->params["upper"] = rrelu->namedInput("upper");


+ 1
- 1
tools/pnnx/src/pass_level1/nn_ReLU.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 1
- 1
tools/pnnx/src/pass_level1/nn_ReLU6.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 4
- 6
tools/pnnx/src/pass_level1/nn_ReflectionPad1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.ReflectionPad1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* reflection_pad1d = find_node_by_kind(graph, "aten::reflection_pad1d");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* reflection_pad1d = graph.find_node_by_kind("aten::reflection_pad1d");

if (pad)
{


+ 4
- 6
tools/pnnx/src/pass_level1/nn_ReflectionPad2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.ReflectionPad2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* reflection_pad2d = find_node_by_kind(graph, "aten::reflection_pad2d");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* reflection_pad2d = graph.find_node_by_kind("aten::reflection_pad2d");

if (pad)
{


+ 4
- 6
tools/pnnx/src/pass_level1/nn_ReplicationPad1d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.ReplicationPad1d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* replication_pad1d = find_node_by_kind(graph, "aten::replication_pad1d");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* replication_pad1d = graph.find_node_by_kind("aten::replication_pad1d");

if (pad)
{


+ 4
- 6
tools/pnnx/src/pass_level1/nn_ReplicationPad2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.ReplicationPad2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* replication_pad2d = find_node_by_kind(graph, "aten::replication_pad2d");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* replication_pad2d = graph.find_node_by_kind("aten::replication_pad2d");

if (pad)
{


+ 4
- 6
tools/pnnx/src/pass_level1/nn_ReplicationPad3d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.ReplicationPad3d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* replication_pad3d = graph.find_node_by_kind("aten::replication_pad3d");

if (pad)
{


+ 1
- 3
tools/pnnx/src/pass_level1/nn_SELU.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 1
- 1
tools/pnnx/src/pass_level1/nn_SiLU.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 1
- 1
tools/pnnx/src/pass_level1/nn_Sigmoid.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 3
- 5
tools/pnnx/src/pass_level1/nn_Softmax.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.Softmax";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax");
const TorchNodeProxy* softmax = graph.find_node_by_kind("aten::softmax");

op->params["dim"] = softmax->namedInput("dim");
}


+ 1
- 3
tools/pnnx/src/pass_level1/nn_Softmax2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 3
- 5
tools/pnnx/src/pass_level1/nn_Softmin.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.Softmin";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax");
const TorchNodeProxy* softmax = graph.find_node_by_kind("aten::softmax");

op->params["dim"] = softmax->namedInput("dim");
}


+ 3
- 5
tools/pnnx/src/pass_level1/nn_Softplus.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.Softplus";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* softplus = find_node_by_kind(graph, "aten::softplus");
const TorchNodeProxy* softplus = graph.find_node_by_kind("aten::softplus");

op->params["beta"] = softplus->namedInput("beta");
op->params["threshold"] = softplus->namedInput("threshold");


+ 3
- 5
tools/pnnx/src/pass_level1/nn_Softshrink.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.Softshrink";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* softshrink = find_node_by_kind(graph, "aten::softshrink");
const TorchNodeProxy* softshrink = graph.find_node_by_kind("aten::softshrink");

op->params["lambd"] = softshrink->namedInput("lambd");
}


+ 1
- 1
tools/pnnx/src/pass_level1/nn_Softsign.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 1
- 1
tools/pnnx/src/pass_level1/nn_Tanh.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 1
- 1
tools/pnnx/src/pass_level1/nn_Tanhshrink.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

namespace pnnx {



+ 3
- 5
tools/pnnx/src/pass_level1/nn_Threshold.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.Threshold";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* threshold = find_node_by_kind(graph, "aten::threshold");
const TorchNodeProxy* threshold = graph.find_node_by_kind("aten::threshold");

op->params["threshold"] = threshold->namedInput("threshold");
op->params["value"] = threshold->namedInput("value");


+ 3
- 5
tools/pnnx/src/pass_level1/nn_Unfold.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.Unfold";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* im2col = find_node_by_kind(graph, "aten::im2col");
const TorchNodeProxy* im2col = graph.find_node_by_kind("aten::im2col");

op->params["kernel_size"] = im2col->namedInput("kernel_size");
op->params["stride"] = im2col->namedInput("stride");


+ 21
- 18
tools/pnnx/src/pass_level1/nn_Upsample.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,19 +29,19 @@ public:
return "nn.Upsample";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* upsample_nearest1d = find_node_by_kind(graph, "aten::upsample_nearest1d");
const torch::jit::Node* upsample_linear1d = find_node_by_kind(graph, "aten::upsample_linear1d");
const TorchNodeProxy* upsample_nearest1d = graph.find_node_by_kind("aten::upsample_nearest1d");
const TorchNodeProxy* upsample_linear1d = graph.find_node_by_kind("aten::upsample_linear1d");

const torch::jit::Node* upsample_nearest2d = find_node_by_kind(graph, "aten::upsample_nearest2d");
const torch::jit::Node* upsample_bilinear2d = find_node_by_kind(graph, "aten::upsample_bilinear2d");
const torch::jit::Node* upsample_bicubic2d = find_node_by_kind(graph, "aten::upsample_bicubic2d");
const TorchNodeProxy* upsample_nearest2d = graph.find_node_by_kind("aten::upsample_nearest2d");
const TorchNodeProxy* upsample_bilinear2d = graph.find_node_by_kind("aten::upsample_bilinear2d");
const TorchNodeProxy* upsample_bicubic2d = graph.find_node_by_kind("aten::upsample_bicubic2d");

const torch::jit::Node* upsample_nearest3d = find_node_by_kind(graph, "aten::upsample_nearest3d");
const torch::jit::Node* upsample_trilinear3d = find_node_by_kind(graph, "aten::upsample_trilinear3d");
const TorchNodeProxy* upsample_nearest3d = graph.find_node_by_kind("aten::upsample_nearest3d");
const TorchNodeProxy* upsample_trilinear3d = graph.find_node_by_kind("aten::upsample_trilinear3d");

const torch::jit::Node* upsample = 0;
const TorchNodeProxy* upsample = 0;
if (upsample_nearest1d)
{
upsample = upsample_nearest1d;
@@ -136,12 +134,17 @@ public:
std::vector<float> scale_factor;
try
{
const torch::jit::Node* size_list = find_node_by_kind(graph, "prim::ListConstruct");
for (auto x : size_list->inputs())
const TorchNodeProxy* size_list = graph.find_node_by_kind("prim::ListConstruct");
const int size_list_input_count = size_list->input_count();
for (int i = 0; i < size_list_input_count; i++)
{
auto scale_tensor = x->node()->inputs()[0]->node()->inputs()[0]->node()->inputs()[0]->node()->inputs()[1]->node()->inputs()[0]->node()->inputs()[0]->node();
auto t = scale_tensor->t(torch::jit::attr::value);
float s = (float)t.item<double>();
const TorchNodeProxy* scale_tensor = graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(size_list->input(i))->input(0))->input(0))->input(0))->input(1))->input(0))->input(0));

// auto t = scale_tensor->t(torch::jit::attr::value);
// float s = (float)t.item<double>();
Parameter ps = scale_tensor->node;
float s = ps.f;

scale_factor.push_back(s);
}

@@ -150,7 +153,7 @@ public:
catch (...)
{
fprintf(stderr, "unhandled upsample recompute_scale_factor graph");
graph->dump();
graph.dump();
}
}
}


+ 3
- 5
tools/pnnx/src/pass_level1/nn_UpsamplingBilinear2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.UpsamplingBilinear2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* upsample = find_node_by_kind(graph, "aten::upsample_bilinear2d");
const TorchNodeProxy* upsample = graph.find_node_by_kind("aten::upsample_bilinear2d");

if (upsample->hasNamedInput("output_size"))
{


+ 3
- 5
tools/pnnx/src/pass_level1/nn_UpsamplingNearest2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,9 +29,9 @@ public:
return "nn.UpsamplingNearest2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* upsample = find_node_by_kind(graph, "aten::upsample_nearest2d");
const TorchNodeProxy* upsample = graph.find_node_by_kind("aten::upsample_nearest2d");

if (upsample->hasNamedInput("output_size"))
{


+ 4
- 6
tools/pnnx/src/pass_level1/nn_ZeroPad2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,10 +29,10 @@ public:
return "nn.ZeroPad2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd");
const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad");
const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd");

if (!pad)
{


+ 12
- 6
tools/pnnx/src/pass_level1/nn_quantized_Conv2d.cpp View File

@@ -12,9 +12,11 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

#include "../utils.h"
#include <torch/script.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/passes/quantization/helper.h>

namespace pnnx {

@@ -31,11 +33,13 @@ public:
return "nn.quantized.Conv2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const
{
const auto& mod = _mod.mod;

// graph->dump();

const torch::jit::Node* quantized_convolution = find_node_by_kind(graph, "quantized::conv2d");
const TorchNodeProxy* quantized_convolution = graph.find_node_by_kind("quantized::conv2d");

// for (auto aa : quantized_convolution->schema().arguments())
// {
@@ -113,11 +117,13 @@ public:
return "nn.intrinsic.quantized.ConvReLU2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const
{
const auto& mod = _mod.mod;

// graph->dump();

const torch::jit::Node* quantized_convolution = find_node_by_kind(graph, "quantized::conv2d_relu");
const TorchNodeProxy* quantized_convolution = graph.find_node_by_kind("quantized::conv2d_relu");

// for (auto aa : quantized_convolution->schema().arguments())
// {


+ 14
- 16
tools/pnnx/src/pass_level1/nn_quantized_DeQuantize.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,19 +29,19 @@ public:
return "nn.quantized.DeQuantize";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
{
// mod.dump(true, false, false);
// graph->dump();
const torch::jit::Node* dequantize = find_node_by_kind(graph, "aten::dequantize");
// for (auto aa : dequantize->schema().arguments())
// {
// fprintf(stderr, "arg %s\n", aa.name().c_str());
// }
}
// void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
// {
// // mod.dump(true, false, false);
//
// // graph->dump();
//
// const torch::jit::Node* dequantize = find_node_by_kind(graph, "aten::dequantize");
//
// // for (auto aa : dequantize->schema().arguments())
// // {
// // fprintf(stderr, "arg %s\n", aa.name().c_str());
// // }
// }
};

REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(DeQuantize)


+ 13
- 7
tools/pnnx/src/pass_level1/nn_quantized_Linear.cpp View File

@@ -12,9 +12,11 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"
#include "fuse_module_pass.h"

#include "../utils.h"
#include <torch/script.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/passes/quantization/helper.h>

namespace pnnx {

@@ -31,13 +33,15 @@ public:
return "nn.quantized.Linear";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const
{
const auto& mod = _mod.mod;

// mod.dump(true, false, false);

// graph->dump();

const torch::jit::Node* quantized_linear = find_node_by_kind(graph, "quantized::linear");
const TorchNodeProxy* quantized_linear = graph.find_node_by_kind("quantized::linear");

// for (auto aa : quantized_linear->schema().arguments())
// {
@@ -99,13 +103,15 @@ public:
return "nn.intrinsic.quantized.LinearReLU";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const
{
const auto& mod = _mod.mod;

// mod.dump(true, false, false);

graph->dump();
graph.dump();

const torch::jit::Node* quantized_linear = find_node_by_kind(graph, "quantized::linear_relu");
const TorchNodeProxy* quantized_linear = graph.find_node_by_kind("quantized::linear_relu");

// for (auto aa : quantized_linear->schema().arguments())
// {


+ 3
- 5
tools/pnnx/src/pass_level1/nn_quantized_Quantize.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,13 +29,13 @@ public:
return "nn.quantized.Quantize";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
// mod.dump(true, false, false);

// graph->dump();

const torch::jit::Node* quantize_per_tensor = find_node_by_kind(graph, "aten::quantize_per_tensor");
const TorchNodeProxy* quantize_per_tensor = graph.find_node_by_kind("aten::quantize_per_tensor");

// for (auto aa : quantize_per_tensor->schema().arguments())
// {


+ 5
- 7
tools/pnnx/src/pass_level1/torchvision_DeformConv2d.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,11 +29,11 @@ public:
return "torchvision.ops.DeformConv2d";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const
{
const torch::jit::Node* deform_conv2d = find_node_by_kind(graph, "torchvision::deform_conv2d");
const TorchNodeProxy* deform_conv2d = graph.find_node_by_kind("torchvision::deform_conv2d");

const auto& weight = mod.attr("weight").toTensor();
const TorchTensorProxy& weight = mod.attr("weight");

const Parameter stride_w = deform_conv2d->namedInput("stride_w");
const Parameter stride_h = deform_conv2d->namedInput("stride_h");
@@ -56,7 +54,7 @@ public:
op->attrs["weight"] = weight;
if (mod.hasattr("bias"))
{
op->attrs["bias"] = mod.attr("bias").toTensor();
op->attrs["bias"] = mod.attr("bias");
}
}
};


+ 4
- 6
tools/pnnx/src/pass_level1/torchvision_RoIAlign.cpp View File

@@ -12,9 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
#include "fuse_module_pass.h"

namespace pnnx {

@@ -31,11 +29,11 @@ public:
return "torchvision.ops.RoIAlign";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& /*mod*/) const
void write(Operator* op, const TorchGraphProxy& graph) const
{
const torch::jit::Node* roi_align = find_node_by_kind(graph, "torchvision::roi_align");
const TorchNodeProxy* roi_align = graph.find_node_by_kind("torchvision::roi_align");

if (roi_align->inputs()[0] == graph->inputs()[2] && roi_align->inputs()[1] == graph->inputs()[1])
if (roi_align->input(0) == graph.input(2) && roi_align->input(1) == graph.input(1))
{
fprintf(stderr, "roi_align inputs swapped detected !\n");
std::swap(op->inputs[0], op->inputs[1]);


+ 0
- 14
tools/pnnx/src/utils.h View File

@@ -15,22 +15,8 @@
#ifndef PNNX_UTILS_H
#define PNNX_UTILS_H

#if BUILD_TORCH2PNNX
#include <memory>
namespace torch {
namespace jit {
struct Graph;
struct Node;
} // namespace jit
} // namespace torch
#endif

namespace pnnx {

#if BUILD_TORCH2PNNX
const torch::jit::Node* find_node_by_kind(const std::shared_ptr<torch::jit::Graph>& graph, const std::string& kind);
#endif

unsigned short float32_to_float16(float value);

float float16_to_float32(unsigned short value);


Loading…
Cancel
Save