Browse Source

pnnx fuse more function to module (#4351)

* pnnx fuse more function to module

* rename some pass name

* fuse adjacent reshape, fuse pad conv2d

* fuse pad conv1d
tags/20221128
nihui GitHub 3 years ago
parent
commit
aed05aa851
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 3006 additions and 2075 deletions
  1. +11
    -2
      tools/pnnx/src/CMakeLists.txt
  2. +13
    -8
      tools/pnnx/src/pass_level2.cpp
  3. +2
    -0
      tools/pnnx/src/pass_level2.h
  4. +27
    -10
      tools/pnnx/src/pass_level5.cpp
  5. +2
    -2
      tools/pnnx/src/pass_level5/eliminate_noop_slice.cpp
  6. +1
    -1
      tools/pnnx/src/pass_level5/eliminate_noop_slice.h
  7. +2
    -2
      tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp
  8. +1
    -1
      tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.h
  9. +105
    -0
      tools/pnnx/src/pass_level5/fuse_adjacent_reshape.cpp
  10. +21
    -0
      tools/pnnx/src/pass_level5/fuse_adjacent_reshape.h
  11. +401
    -0
      tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp
  12. +21
    -0
      tools/pnnx/src/pass_level5/fuse_pad_conv1d.h
  13. +500
    -0
      tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp
  14. +21
    -0
      tools/pnnx/src/pass_level5/fuse_pad_conv2d.h
  15. +384
    -0
      tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp
  16. +21
    -0
      tools/pnnx/src/pass_level5/fuse_static_batchnorm.h
  17. +236
    -0
      tools/pnnx/src/pass_level5/fuse_static_conv.cpp
  18. +351
    -0
      tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp
  19. +21
    -0
      tools/pnnx/src/pass_level5/fuse_static_convtranspose.h
  20. +79
    -0
      tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp
  21. +21
    -0
      tools/pnnx/src/pass_level5/fuse_static_groupnorm.h
  22. +195
    -0
      tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp
  23. +21
    -0
      tools/pnnx/src/pass_level5/fuse_static_instancenorm.h
  24. +78
    -0
      tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp
  25. +21
    -0
      tools/pnnx/src/pass_level5/fuse_static_layernorm.h
  26. +195
    -0
      tools/pnnx/src/pass_level5/fuse_static_linear.cpp
  27. +21
    -0
      tools/pnnx/src/pass_level5/fuse_static_linear.h
  28. +0
    -248
      tools/pnnx/src/pass_ncnn/F_conv1d.cpp
  29. +0
    -264
      tools/pnnx/src/pass_ncnn/F_conv2d.cpp
  30. +0
    -280
      tools/pnnx/src/pass_ncnn/F_conv3d.cpp
  31. +0
    -326
      tools/pnnx/src/pass_ncnn/F_conv_transpose1d.cpp
  32. +0
    -354
      tools/pnnx/src/pass_ncnn/F_conv_transpose2d.cpp
  33. +0
    -378
      tools/pnnx/src/pass_ncnn/F_conv_transpose3d.cpp
  34. +0
    -49
      tools/pnnx/src/pass_ncnn/F_group_norm.cpp
  35. +0
    -55
      tools/pnnx/src/pass_ncnn/F_layer_norm.cpp
  36. +0
    -95
      tools/pnnx/src/pass_ncnn/F_linear.cpp
  37. +3
    -0
      tools/pnnx/tests/CMakeLists.txt
  38. +61
    -0
      tools/pnnx/tests/test_pnnx_fuse_adjacent_reshape.py
  39. +84
    -0
      tools/pnnx/tests/test_pnnx_fuse_pad_conv1d.py
  40. +86
    -0
      tools/pnnx/tests/test_pnnx_fuse_pad_conv2d.py

+ 11
- 2
tools/pnnx/src/CMakeLists.txt View File

@@ -304,10 +304,11 @@ set(pnnx_pass_level5_SRCS
pass_level5/eliminate_noop_expression.cpp
pass_level5/eliminate_noop_pad.cpp
pass_level5/eliminate_noop_upsample.cpp
pass_level5/eliminate_slice.cpp
pass_level5/eliminate_view_reshape.cpp
pass_level5/eliminate_noop_slice.cpp
pass_level5/eliminate_noop_view_reshape.cpp
pass_level5/eval_expression.cpp
pass_level5/fold_constants.cpp
pass_level5/fuse_adjacent_reshape.cpp
pass_level5/fuse_channel_shuffle.cpp
pass_level5/fuse_constant_expression.cpp
pass_level5/fuse_conv1d_batchnorm1d.cpp
@@ -316,11 +317,19 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_convtranspose2d_batchnorm2d.cpp
pass_level5/fuse_contiguous_view.cpp
pass_level5/fuse_linear_batchnorm1d.cpp
pass_level5/fuse_pad_conv1d.cpp
pass_level5/fuse_pad_conv2d.cpp
pass_level5/fuse_select_to_unbind.cpp
pass_level5/fuse_slice_copy.cpp
pass_level5/fuse_slice_indices.cpp
pass_level5/fuse_slice_to_tensor_split.cpp
pass_level5/fuse_static_batchnorm.cpp
pass_level5/fuse_static_conv.cpp
pass_level5/fuse_static_convtranspose.cpp
pass_level5/fuse_static_groupnorm.cpp
pass_level5/fuse_static_instancenorm.cpp
pass_level5/fuse_static_layernorm.cpp
pass_level5/fuse_static_linear.cpp
pass_level5/normalize_einsum_equation.cpp
pass_level5/unroll_rnn_op.cpp
)


+ 13
- 8
tools/pnnx/src/pass_level2.cpp View File

@@ -39,6 +39,11 @@ bool GraphRewriterPass::match(const std::map<std::string, Parameter>& captured_p
return match(captured_params);
}

bool GraphRewriterPass::match(const std::map<std::string, const Operator*>& /*matched_operators*/) const
{
return true;
}

void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
for (auto x : captured_params)
@@ -215,7 +220,7 @@ static bool match_operator(const Operator* a, const Operator* b, std::map<std::s
return true;
}

static bool match(const Operator* anchor, const Operator* pattern, std::unordered_map<std::string, const Operator*>& matched_operators, std::unordered_map<std::string, const Operand*>& matched_inputs, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
static bool match(const Operator* anchor, const Operator* pattern, std::map<std::string, const Operator*>& matched_operators, std::map<std::string, const Operand*>& matched_inputs, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
{
if (!match_operator(anchor, pattern, captured_params, captured_attrs))
return false;
@@ -290,9 +295,9 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
bool matched = true;

// lets match from output
std::unordered_map<std::string, const Operator*> matched_operators;
std::unordered_map<std::string, const Operand*> matched_inputs;
std::unordered_map<std::string, const Operand*> matched_outputs;
std::map<std::string, const Operator*> matched_operators;
std::map<std::string, const Operand*> matched_inputs;
std::map<std::string, const Operand*> matched_outputs;
std::map<std::string, Parameter> captured_params;
std::map<std::string, Attribute> captured_attrs;

@@ -311,8 +316,8 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
{
const Operator* anchor = graph.ops[j];

std::unordered_map<std::string, const Operator*> matched_operators2;
std::unordered_map<std::string, const Operand*> matched_inputs2;
std::map<std::string, const Operator*> matched_operators2;
std::map<std::string, const Operand*> matched_inputs2;
std::map<std::string, Parameter> captured_params2;
std::map<std::string, Attribute> captured_attrs2;
if (!match(anchor, pattern2, matched_operators2, matched_inputs2, captured_params2, captured_attrs2))
@@ -372,7 +377,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
break;
}

if (matched && !pass->match(captured_params, captured_attrs))
if (matched && (!pass->match(captured_params, captured_attrs) || !pass->match(matched_operators)))
{
matched_operators.clear();
matched_inputs.clear();
@@ -393,7 +398,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
// lets replace

// remove all operands inside matched graph
std::unordered_map<std::string, Operand*> operands_to_remove;
std::map<std::string, Operand*> operands_to_remove;
for (auto& _x : matched_operators)
{
Operator* x = (Operator*)_x.second;


+ 2
- 0
tools/pnnx/src/pass_level2.h View File

@@ -34,6 +34,8 @@ public:

virtual bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const;

virtual bool match(const std::map<std::string, const Operator*>& matched_operators) const;

virtual void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const;

virtual void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const;


+ 27
- 10
tools/pnnx/src/pass_level5.cpp View File

@@ -22,9 +22,10 @@
#include "pass_level5/eliminate_noop_expression.h"
#include "pass_level5/eliminate_noop_pad.h"
#include "pass_level5/eliminate_noop_upsample.h"
#include "pass_level5/eliminate_slice.h"
#include "pass_level5/eliminate_view_reshape.h"
#include "pass_level5/eliminate_noop_slice.h"
#include "pass_level5/eliminate_noop_view_reshape.h"
#include "pass_level5/eval_expression.h"
#include "pass_level5/fuse_adjacent_reshape.h"
#include "pass_level5/fuse_channel_shuffle.h"
#include "pass_level5/fuse_constant_expression.h"
#include "pass_level5/fuse_conv1d_batchnorm1d.h"
@@ -33,11 +34,19 @@
#include "pass_level5/fuse_convtranspose2d_batchnorm2d.h"
#include "pass_level5/fuse_contiguous_view.h"
#include "pass_level5/fuse_linear_batchnorm1d.h"
#include "pass_level5/fuse_pad_conv1d.h"
#include "pass_level5/fuse_pad_conv2d.h"
#include "pass_level5/fuse_select_to_unbind.h"
#include "pass_level5/fuse_slice_copy.h"
#include "pass_level5/fuse_slice_indices.h"
#include "pass_level5/fuse_slice_to_tensor_split.h"
#include "pass_level5/fuse_static_batchnorm.h"
#include "pass_level5/fuse_static_conv.h"
#include "pass_level5/fuse_static_convtranspose.h"
#include "pass_level5/fuse_static_groupnorm.h"
#include "pass_level5/fuse_static_instancenorm.h"
#include "pass_level5/fuse_static_layernorm.h"
#include "pass_level5/fuse_static_linear.h"
#include "pass_level5/normalize_einsum_equation.h"
#include "pass_level4/dead_code_elimination.h"
#include "pass_level4/canonicalize.h"
@@ -51,9 +60,11 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

fuse_constant_expression(g);

fold_constants(g, foldable_constants, foldable_constants_zippath);

eliminate_noop_expression(g);

eliminate_slice(g);
eliminate_noop_slice(g);

fuse_slice_indices(g);

@@ -69,18 +80,24 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

fuse_slice_copy(g);

fuse_static_batchnorm(g);
fuse_static_groupnorm(g);
fuse_static_instancenorm(g);
fuse_static_layernorm(g);

fuse_static_conv(g);
fuse_static_convtranspose(g);
fuse_static_linear(g);

fuse_conv1d_batchnorm1d(g);

fuse_conv2d_batchnorm2d(g);

fuse_convtranspose1d_batchnorm1d(g);

fuse_convtranspose2d_batchnorm2d(g);

fuse_linear_batchnorm1d(g);

fuse_pad_conv1d(g);
fuse_pad_conv2d(g);

eliminate_noop_pad(g);

eliminate_noop_cat(g);
@@ -91,11 +108,11 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

fuse_contiguous_view(g);

eliminate_view_reshape(g);
fuse_adjacent_reshape(g);

fuse_channel_shuffle(g);
eliminate_noop_view_reshape(g);

fold_constants(g, foldable_constants, foldable_constants_zippath);
fuse_channel_shuffle(g);

fuse_index_expression(g);



tools/pnnx/src/pass_level5/eliminate_slice.cpp → tools/pnnx/src/pass_level5/eliminate_noop_slice.cpp View File

@@ -12,7 +12,7 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "eliminate_slice.h"
#include "eliminate_noop_slice.h"

#include <limits.h>
#include <algorithm>
@@ -20,7 +20,7 @@

namespace pnnx {

void eliminate_slice(Graph& graph)
void eliminate_noop_slice(Graph& graph)
{
while (1)
{

tools/pnnx/src/pass_level5/eliminate_slice.h → tools/pnnx/src/pass_level5/eliminate_noop_slice.h View File

@@ -16,6 +16,6 @@

namespace pnnx {

void eliminate_slice(Graph& graph);
void eliminate_noop_slice(Graph& graph);

} // namespace pnnx

tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp → tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp View File

@@ -12,14 +12,14 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "eliminate_view_reshape.h"
#include "eliminate_noop_view_reshape.h"

#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void eliminate_view_reshape(Graph& graph)
void eliminate_noop_view_reshape(Graph& graph)
{
while (1)
{

tools/pnnx/src/pass_level5/eliminate_view_reshape.h → tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.h View File

@@ -16,6 +16,6 @@

namespace pnnx {

void eliminate_view_reshape(Graph& graph);
void eliminate_noop_view_reshape(Graph& graph);

} // namespace pnnx

+ 105
- 0
tools/pnnx/src/pass_level5/fuse_adjacent_reshape.cpp View File

@@ -0,0 +1,105 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_adjacent_reshape.h"

#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void fuse_adjacent_reshape(Graph& graph)
{
while (1)
{
bool matched = false;

for (int i = (int)graph.ops.size() - 1; i > 0; i--)
{
Operator* op = graph.ops[i];

// look for Tensor.view / Tensor.reshape / torch.squeeze / torch.unsqueeze chain
if (op->type != "Tensor.view" && op->type != "Tensor.reshape" && op->type != "torch.squeeze" && op->type != "torch.unsqueeze")
continue;

if ((op->type == "torch.squeeze" || op->type == "torch.unsqueeze") && op->outputs[0]->shape.empty())
continue;

std::vector<Operator*> reshapes_to_delete;
const Operand* in0 = op->inputs[0];
while (in0->consumers.size() == 1 && (in0->producer->type == "Tensor.view" || in0->producer->type == "Tensor.reshape" || in0->producer->type == "torch.squeeze" || in0->producer->type == "torch.unsqueeze"))
{
reshapes_to_delete.push_back(in0->producer);
in0 = in0->producer->inputs[0];
}

if (reshapes_to_delete.empty())
continue;

// keep the last reshape only
matched = true;

op->type = "Tensor.reshape";

if (!op->outputs[0]->shape.empty())
{
op->params.clear();
op->params["shape"] = op->outputs[0]->shape;
}

for (auto& op0 : reshapes_to_delete)
{
for (auto& x : op0->inputs)
{
x->remove_consumer(op0);
}

Operand* op0_in = op0->inputs[0];
Operand* op0_out = op0->outputs[0];

for (auto& x : op0_out->consumers)
{
for (size_t j = 0; j < x->inputs.size(); j++)
{
if (x->inputs[j] == op0_out)
x->inputs[j] = op0_in;
}

op0_in->consumers.push_back(x);
}

op0_in->name = op0_out->name;

op0_out->producer = 0;
op0_out->consumers.clear();

graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op0_out));
delete op0_out;

op0->inputs.clear();
op0->outputs.clear();

graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op0));
delete op0;
}

break;
}

if (!matched)
break;
}
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_adjacent_reshape.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_adjacent_reshape(Graph& graph);

} // namespace pnnx

+ 401
- 0
tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp View File

@@ -0,0 +1,401 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_pad_conv1d.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_pad_conv1d_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
F.pad op_pad 1 1 input a mode=constant pad=%pad value=%value
nn.Conv1d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv1d";
}

const char* name_str() const
{
return "conv1d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// constant-0 + zeros
float pad_value = 0.f;
if (captured_params.at("value").type == 2)
pad_value = captured_params.at("value").i;
if (captured_params.at("value").type == 3)
pad_value = captured_params.at("value").f;

if (pad_value != 0.f)
return false;

const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 2)
return false;

if (pad.size() == 2 && pad[0] != pad[1])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<int>& pad = captured_params.at("pad").ai;
std::vector<int> padding = captured_params.at("padding").ai;
padding[0] += pad[0];

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = "zeros";
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

class fuse_pad_conv1d_pass_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
F.pad op_pad 1 1 input a mode=%mode pad=%pad
nn.Conv1d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv1d";
}

const char* name_str() const
{
return "conv1d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// reflect/replicate + nopad
if (captured_params.at("mode").s != "reflect" && captured_params.at("mode").s != "replicate")
return false;

const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 2)
return false;

if (pad.size() == 2 && pad[0] != pad[1])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<int>& pad = captured_params.at("pad").ai;
std::vector<int> padding(1);
padding[0] = pad[0];

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = captured_params.at("mode");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

class fuse_pad_conv1d_pass_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.ConstantPad1d op_pad 1 1 input a padding=%pad value=%value
nn.Conv1d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv1d";
}

const char* name_str() const
{
return "conv1d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// constant-0 + zeros
float pad_value = 0.f;
if (captured_params.at("value").type == 2)
pad_value = captured_params.at("value").i;
if (captured_params.at("value").type == 3)
pad_value = captured_params.at("value").f;

if (pad_value != 0.f)
return false;

const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 2)
return false;

if (pad[0] != pad[1])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
std::vector<int> padding = captured_params.at("padding").ai;
const std::vector<int>& pad = captured_params.at("pad").ai;
padding[0] += pad[0];

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = "zeros";
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

class fuse_pad_conv1d_pass_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.ReplicationPad1d op_pad 1 1 input a padding=%pad
nn.Conv1d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv1d";
}

const char* name_str() const
{
return "conv1d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// replicate + nopad
const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 2)
return false;

if (pad[0] != pad[1])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
std::vector<int> padding(1);
const std::vector<int>& pad = captured_params.at("pad").ai;
padding[0] = pad[0];

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = "replicate";
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

class fuse_pad_conv1d_pass_4 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.ReflectionPad1d op_pad 1 1 input a padding=%pad
nn.Conv1d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv1d";
}

const char* name_str() const
{
return "conv1d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// reflect + nopad
const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 2)
return false;

if (pad[0] != pad[1])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
std::vector<int> padding(1);
const std::vector<int>& pad = captured_params.at("pad").ai;
padding[0] = pad[0];

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = "reflect";
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

void fuse_pad_conv1d(Graph& graph)
{
fuse_pad_conv1d_pass a;
fuse_pad_conv1d_pass_1 b;
fuse_pad_conv1d_pass_2 c;
fuse_pad_conv1d_pass_3 d;
fuse_pad_conv1d_pass_4 e;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
pnnx_graph_rewrite(graph, &c, opindex);
pnnx_graph_rewrite(graph, &d, opindex);
pnnx_graph_rewrite(graph, &e, opindex);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_pad_conv1d.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_pad_conv1d(Graph& graph);

} // namespace pnnx

+ 500
- 0
tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp View File

@@ -0,0 +1,500 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_pad_conv2d.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_pad_conv2d_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
F.pad op_pad 1 1 input a mode=constant pad=%pad value=%value
nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv2d";
}

const char* name_str() const
{
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// constant-0 + zeros
float pad_value = 0.f;
if (captured_params.at("value").type == 2)
pad_value = captured_params.at("value").i;
if (captured_params.at("value").type == 3)
pad_value = captured_params.at("value").f;

if (pad_value != 0.f)
return false;

const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 2 && pad.size() != 4)
return false;

if (pad.size() == 2 && pad[0] != pad[1])
return false;

if (pad.size() == 4 && (pad[0] != pad[1] || pad[2] != pad[3]))
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<int>& pad = captured_params.at("pad").ai;
std::vector<int> padding = captured_params.at("padding").ai;

if (pad.size() == 2)
{
padding[1] += pad[0];
}
else if (pad.size() == 4)
{
padding[0] += pad[2];
padding[1] += pad[0];
}

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = "zeros";
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

class fuse_pad_conv2d_pass_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
F.pad op_pad 1 1 input a mode=%mode pad=%pad
nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv2d";
}

const char* name_str() const
{
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// reflect/replicate + nopad
if (captured_params.at("mode").s != "reflect" && captured_params.at("mode").s != "replicate")
return false;

const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 2 && pad.size() != 4)
return false;

if (pad.size() == 2 && pad[0] != pad[1])
return false;

if (pad.size() == 4 && (pad[0] != pad[1] || pad[2] != pad[3]))
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<int>& pad = captured_params.at("pad").ai;
std::vector<int> padding(2);

if (pad.size() == 2)
{
padding[0] = 0;
padding[1] = pad[0];
}
else if (pad.size() == 4)
{
padding[0] = pad[2];
padding[1] = pad[0];
}

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = captured_params.at("mode");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

class fuse_pad_conv2d_pass_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.ConstantPad2d op_pad 1 1 input a padding=%pad value=%value
nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv2d";
}

const char* name_str() const
{
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// constant-0 + zeros
float pad_value = 0.f;
if (captured_params.at("value").type == 2)
pad_value = captured_params.at("value").i;
if (captured_params.at("value").type == 3)
pad_value = captured_params.at("value").f;

if (pad_value != 0.f)
return false;

const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 4)
return false;

if (pad[0] != pad[1] || pad[2] != pad[3])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
std::vector<int> padding = captured_params.at("padding").ai;
const std::vector<int>& pad = captured_params.at("pad").ai;
padding[0] += pad[2];
padding[1] += pad[0];

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = "zeros";
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

class fuse_pad_conv2d_pass_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.ZeroPad2d op_pad 1 1 input a padding=%pad
nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv2d";
}

const char* name_str() const
{
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// constant-0 + zeros
const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 4)
return false;

if (pad[0] != pad[1] || pad[2] != pad[3])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
std::vector<int> padding = captured_params.at("padding").ai;
const std::vector<int>& pad = captured_params.at("pad").ai;
padding[0] += pad[2];
padding[1] += pad[0];

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = "zeros";
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

class fuse_pad_conv2d_pass_4 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.ReplicationPad2d op_pad 1 1 input a padding=%pad
nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv2d";
}

const char* name_str() const
{
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// replicate + nopad
const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 4)
return false;

if (pad[0] != pad[1] || pad[2] != pad[3])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
std::vector<int> padding(2);
const std::vector<int>& pad = captured_params.at("pad").ai;
padding[0] = pad[2];
padding[1] = pad[0];

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = "replicate";
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

class fuse_pad_conv2d_pass_5 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.ReflectionPad2d op_pad 1 1 input a padding=%pad
nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv2d";
}

const char* name_str() const
{
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
{
// reflect + nopad
const std::vector<int>& pad = captured_params.at("pad").ai;
for (int x : pad)
{
if (x < 0)
return false;
}

if (pad.size() != 4)
return false;

if (pad[0] != pad[1] || pad[2] != pad[3])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
std::vector<int> padding(2);
const std::vector<int>& pad = captured_params.at("pad").ai;
padding[0] = pad[2];
padding[1] = pad[0];

op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = "reflect";
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = padding;
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = captured_params.at("bias");

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (captured_params.at("bias").b)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
}
};

void fuse_pad_conv2d(Graph& graph)
{
fuse_pad_conv2d_pass a;
fuse_pad_conv2d_pass_1 b;
fuse_pad_conv2d_pass_2 c;
fuse_pad_conv2d_pass_3 d;
fuse_pad_conv2d_pass_4 e;
fuse_pad_conv2d_pass_5 f;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
pnnx_graph_rewrite(graph, &c, opindex);
pnnx_graph_rewrite(graph, &d, opindex);
pnnx_graph_rewrite(graph, &e, opindex);
pnnx_graph_rewrite(graph, &f, opindex);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_pad_conv2d.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_pad_conv2d(Graph& graph);

} // namespace pnnx

+ 384
- 0
tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp View File

@@ -0,0 +1,384 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_static_batchnorm.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_static_Fbatchnorm_pass_1d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.BatchNorm1d";
}

const char* name_str() const
{
return "batchnorm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 2 || input_rank == 3;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute running_mean;
Attribute running_var;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = false;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
}
};

class fuse_static_Fbatchnorm_pass_1d_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.BatchNorm1d";
}

const char* name_str() const
{
return "batchnorm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 2 || input_rank == 3;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute running_mean;
Attribute running_var;
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Fbatchnorm_pass_2d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.BatchNorm2d";
}

const char* name_str() const
{
return "batchnorm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 4;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute running_mean;
Attribute running_var;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = false;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
}
};

class fuse_static_Fbatchnorm_pass_2d_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.BatchNorm2d";
}

const char* name_str() const
{
return "batchnorm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 4;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute running_mean;
Attribute running_var;
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Fbatchnorm_pass_3d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.BatchNorm3d";
}

const char* name_str() const
{
return "batchnorm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 5;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute running_mean;
Attribute running_var;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = false;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
}
};

class fuse_static_Fbatchnorm_pass_3d_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.BatchNorm3d";
}

const char* name_str() const
{
return "batchnorm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 5;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute running_mean;
Attribute running_var;
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

void fuse_static_batchnorm(Graph& graph)
{
fuse_static_Fbatchnorm_pass_1d a;
fuse_static_Fbatchnorm_pass_2d b;
fuse_static_Fbatchnorm_pass_3d c;
fuse_static_Fbatchnorm_pass_1d_1 a1;
fuse_static_Fbatchnorm_pass_2d_1 b1;
fuse_static_Fbatchnorm_pass_3d_1 c1;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
pnnx_graph_rewrite(graph, &c, opindex);
pnnx_graph_rewrite(graph, &a1, opindex);
pnnx_graph_rewrite(graph, &b1, opindex);
pnnx_graph_rewrite(graph, &c1, opindex);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_static_batchnorm.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_static_batchnorm(Graph& graph);

} // namespace pnnx

+ 236
- 0
tools/pnnx/src/pass_level5/fuse_static_conv.cpp View File

@@ -120,6 +120,82 @@ pnnx.Output output 1 0 out
}
};

class fuse_static_Fconv1d_pass_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv1d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Expression op_1 2 1 a bias out expr=%expr
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv1d";
}

const char* name_str() const
{
return "conv1d";
}

bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::string& expr = captured_params.at("expr").s;
if (expr != "add(@0,@1)")
return false;

Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

int out_channels = weight.shape[0];
if (bias.shape != std::vector<int>{1, out_channels, 1})
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Fconv2d_pass : public GraphRewriterPass
{
public:
@@ -219,6 +295,82 @@ pnnx.Output output 1 0 out
}
};

class fuse_static_Fconv2d_pass_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv2d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Expression op_1 2 1 a bias out expr=%expr
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv2d";
}

const char* name_str() const
{
return "conv2d";
}

bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::string& expr = captured_params.at("expr").s;
if (expr != "add(@0,@1)")
return false;

Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

int out_channels = weight.shape[0];
if (bias.shape != std::vector<int>{1, out_channels, 1, 1})
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Fconv3d_pass : public GraphRewriterPass
{
public:
@@ -318,8 +470,88 @@ pnnx.Output output 1 0 out
}
};

class fuse_static_Fconv3d_pass_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv3d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Expression op_1 2 1 a bias out expr=%expr
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv3d";
}

const char* name_str() const
{
return "conv3d";
}

bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::string& expr = captured_params.at("expr").s;
if (expr != "add(@0,@1)")
return false;

Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

int out_channels = weight.shape[0];
if (bias.shape != std::vector<int>{1, out_channels, 1, 1, 1})
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3], weight.shape[4]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

void fuse_static_conv(Graph& graph)
{
fuse_static_Fconv1d_pass_3 a3;
fuse_static_Fconv2d_pass_3 a4;
fuse_static_Fconv3d_pass_3 a5;

fuse_static_Fconv1d_pass a;
fuse_static_Fconv1d_pass_2 b;
fuse_static_Fconv2d_pass c;
@@ -328,6 +560,10 @@ void fuse_static_conv(Graph& graph)
fuse_static_Fconv3d_pass_2 f;
int opindex = 0;

pnnx_graph_rewrite(graph, &a3, opindex);
pnnx_graph_rewrite(graph, &a4, opindex);
pnnx_graph_rewrite(graph, &a5, opindex);

pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
pnnx_graph_rewrite(graph, &c, opindex);


+ 351
- 0
tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp View File

@@ -0,0 +1,351 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_static_convtranspose.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_static_Fconvtranspose1d_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv_transpose1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.ConvTranspose1d";
}

const char* name_str() const
{
return "conv_transpose1d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = false;

op->attrs["weight"] = weight;
}
};

class fuse_static_Fconvtranspose1d_pass_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv_transpose1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.ConvTranspose1d";
}

const char* name_str() const
{
return "conv_transpose1d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Fconvtranspose2d_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv_transpose2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.ConvTranspose2d";
}

const char* name_str() const
{
return "conv_transpose2d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = false;

op->attrs["weight"] = weight;
}
};

class fuse_static_Fconvtranspose2d_pass_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv_transpose2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.ConvTranspose2d";
}

const char* name_str() const
{
return "conv_transpose2d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Fconvtranspose3d_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv_transpose3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.ConvTranspose3d";
}

const char* name_str() const
{
return "conv_transpose3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3], weight.shape[4]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = false;

op->attrs["weight"] = weight;
}
};

class fuse_static_Fconvtranspose3d_pass_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv_transpose3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.ConvTranspose3d";
}

const char* name_str() const
{
return "conv_transpose3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3], weight.shape[4]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

void fuse_static_convtranspose(Graph& graph)
{
fuse_static_Fconvtranspose1d_pass a;
fuse_static_Fconvtranspose1d_pass_2 b;
fuse_static_Fconvtranspose2d_pass c;
fuse_static_Fconvtranspose2d_pass_2 d;
fuse_static_Fconvtranspose3d_pass e;
fuse_static_Fconvtranspose3d_pass_2 f;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
pnnx_graph_rewrite(graph, &c, opindex);
pnnx_graph_rewrite(graph, &d, opindex);
pnnx_graph_rewrite(graph, &e, opindex);
pnnx_graph_rewrite(graph, &f, opindex);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_static_convtranspose.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_static_convtranspose(Graph& graph);

} // namespace pnnx

+ 79
- 0
tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp View File

@@ -0,0 +1,79 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_static_groupnorm.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_static_Fgroupnorm_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.group_norm op_0 3 1 input weight bias out num_groups=%num_groups eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.GroupNorm";
}

const char* name_str() const
{
return "group_norm";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_channels"] = weight.shape[0];
op->params["num_groups"] = captured_params.at("num_groups");
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

void fuse_static_groupnorm(Graph& graph)
{
fuse_static_Fgroupnorm_pass a;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_static_groupnorm.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_static_groupnorm(Graph& graph);

} // namespace pnnx

+ 195
- 0
tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp View File

@@ -0,0 +1,195 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_static_instancenorm.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_static_Finstancenorm_pass_1d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.InstanceNorm1d";
}

const char* name_str() const
{
return "instance_norm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 2 || input_rank == 3;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = weight.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;
op->params["track_running_stats"] = false;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Finstancenorm_pass_2d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.InstanceNorm1d";
}

const char* name_str() const
{
return "instance_norm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 4;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = weight.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;
op->params["track_running_stats"] = false;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Finstancenorm_pass_3d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.InstanceNorm1d";
}

const char* name_str() const
{
return "instance_norm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 5;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = weight.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;
op->params["track_running_stats"] = false;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

void fuse_static_instancenorm(Graph& graph)
{
fuse_static_Finstancenorm_pass_1d a;
fuse_static_Finstancenorm_pass_2d b;
fuse_static_Finstancenorm_pass_3d c;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
pnnx_graph_rewrite(graph, &c, opindex);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_static_instancenorm.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_static_instancenorm(Graph& graph);

} // namespace pnnx

+ 78
- 0
tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp View File

@@ -0,0 +1,78 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_static_layernorm.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_static_Flayernorm_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.layer_norm op_0 3 1 input weight bias out normalized_shape=%normalized_shape eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.LayerNorm";
}

const char* name_str() const
{
return "layer_norm";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["normalized_shape"] = captured_params.at("normalized_shape");
op->params["eps"] = captured_params.at("eps");
op->params["elementwise_affine"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

void fuse_static_layernorm(Graph& graph)
{
fuse_static_Flayernorm_pass a;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_static_layernorm.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_static_layernorm(Graph& graph);

} // namespace pnnx

+ 195
- 0
tools/pnnx/src/pass_level5/fuse_static_linear.cpp View File

@@ -0,0 +1,195 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_static_linear.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_static_Flinear_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.linear op_0 2 1 input weight out bias=None
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Linear";
}

const char* name_str() const
{
return "linear";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["in_features"] = weight.shape[1];
op->params["out_features"] = weight.shape[0];
op->params["bias"] = false;

op->attrs["weight"] = weight;
}
};

class fuse_static_Flinear_pass_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.linear op_0 3 1 input weight bias out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Linear";
}

const char* name_str() const
{
return "linear";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_features"] = weight.shape[1];
op->params["out_features"] = weight.shape[0];
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Flinear_pass_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.linear op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Expression op_1 2 1 a bias out expr=%expr
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Linear";
}

const char* name_str() const
{
return "linear";
}

bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::string& expr = captured_params.at("expr").s;
if (expr != "add(@0,@1)")
return false;

Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

int out_channels = weight.shape[0];
if (bias.shape != std::vector<int>{1, out_channels, 1})
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_features"] = weight.shape[1];
op->params["out_features"] = weight.shape[0];
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

void fuse_static_linear(Graph& graph)
{
fuse_static_Flinear_pass_3 a3;

fuse_static_Flinear_pass a;
fuse_static_Flinear_pass_2 b;
int opindex = 0;

pnnx_graph_rewrite(graph, &a3, opindex);

pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_static_linear.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_static_linear(Graph& graph);

} // namespace pnnx

+ 0
- 248
tools/pnnx/src/pass_ncnn/F_conv1d.cpp View File

@@ -18,254 +18,6 @@ namespace pnnx {

namespace ncnn {

class F_conv1d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Convolution1D";
}

const char* name_str() const
{
return "conv1d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv1d, 20)

class F_conv1d_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Convolution1D";
}

const char* name_str() const
{
return "conv1d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv1d_1, 20)

class F_conv1d_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "ConvolutionDepthWise1D";
}

const char* name_str() const
{
return "convdw1d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = captured_params.at("groups");

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv1d_2, 21)

class F_conv1d_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "ConvolutionDepthWise1D";
}

const char* name_str() const
{
return "convdw1d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = captured_params.at("groups");

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv1d_3, 21)

class F_conv1d_4 : public GraphRewriterPass
{
public:


+ 0
- 264
tools/pnnx/src/pass_ncnn/F_conv2d.cpp View File

@@ -18,270 +18,6 @@ namespace pnnx {

namespace ncnn {

class F_conv2d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Convolution";
}

const char* name_str() const
{
return "conv2d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[3];
op->params["11"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[1];
op->params["12"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[1];
op->params["13"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[1];
op->params["14"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d, 20)

class F_conv2d_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Convolution";
}

const char* name_str() const
{
return "conv2d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[3];
op->params["11"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[1];
op->params["12"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[1];
op->params["13"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[1];
op->params["14"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d_1, 20)

class F_conv2d_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "ConvolutionDepthWise";
}

const char* name_str() const
{
return "convdw2d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[3];
op->params["11"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[1];
op->params["12"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[1];
op->params["13"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[1];
op->params["14"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = captured_params.at("groups");

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d_2, 21)

class F_conv2d_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "ConvolutionDepthWise";
}

const char* name_str() const
{
return "convdw2d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[3];
op->params["11"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[1];
op->params["12"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[1];
op->params["13"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[1];
op->params["14"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = captured_params.at("groups");

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d_3, 21)

class F_conv2d_4 : public GraphRewriterPass
{
public:


+ 0
- 280
tools/pnnx/src/pass_ncnn/F_conv3d.cpp View File

@@ -18,286 +18,6 @@ namespace pnnx {

namespace ncnn {

class F_conv3d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Convolution3D";
}

const char* name_str() const
{
return "conv3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[4];
op->params["11"] = weight.shape[3];
op->params["21"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[2];
op->params["12"] = captured_params.at("dilation").ai[1];
op->params["22"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[2];
op->params["13"] = captured_params.at("stride").ai[1];
op->params["23"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[2];
op->params["14"] = captured_params.at("padding").ai[1];
op->params["24"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv3d, 20)

class F_conv3d_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Convolution3D";
}

const char* name_str() const
{
return "conv3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[4];
op->params["11"] = weight.shape[3];
op->params["21"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[2];
op->params["12"] = captured_params.at("dilation").ai[1];
op->params["22"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[2];
op->params["13"] = captured_params.at("stride").ai[1];
op->params["23"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[2];
op->params["14"] = captured_params.at("padding").ai[1];
op->params["24"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv3d_1, 20)

class F_conv3d_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "ConvolutionDepthWise3D";
}

const char* name_str() const
{
return "convdw3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[4];
op->params["11"] = weight.shape[3];
op->params["21"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[2];
op->params["12"] = captured_params.at("dilation").ai[1];
op->params["22"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[2];
op->params["13"] = captured_params.at("stride").ai[1];
op->params["23"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[2];
op->params["14"] = captured_params.at("padding").ai[1];
op->params["24"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = captured_params.at("groups");

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv3d_2, 21)

class F_conv3d_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "ConvolutionDepthWise3D";
}

const char* name_str() const
{
return "convdw3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = weight.shape[4];
op->params["11"] = weight.shape[3];
op->params["21"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[2];
op->params["12"] = captured_params.at("dilation").ai[1];
op->params["22"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[2];
op->params["13"] = captured_params.at("stride").ai[1];
op->params["23"] = captured_params.at("stride").ai[0];
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
op->params["4"] = -233;
else if (captured_params.at("padding").s == "valid")
op->params["4"] = 0;
}
else
{
op->params["4"] = captured_params.at("padding").ai[2];
op->params["14"] = captured_params.at("padding").ai[1];
op->params["24"] = captured_params.at("padding").ai[0];
}
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = captured_params.at("groups");

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv3d_3, 21)

} // namespace ncnn

} // namespace pnnx

+ 0
- 326
tools/pnnx/src/pass_ncnn/F_conv_transpose1d.cpp View File

@@ -18,332 +18,6 @@ namespace pnnx {

namespace ncnn {

class F_conv_transpose1d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv_transpose1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Deconvolution1D";
}

const char* name_str() const
{
return "conv_transpose1d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[1];
op->params["1"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

// transpose inch-outch-kw to outch-inch-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1];
const int kw = weight.shape[2];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch * inch * kw);
float* w2 = (float*)new_weight.data();

// reorder weight from inch-outch to outch-inch
for (int i = 0; i < outch; i++)
{
for (int j = 0; j < inch; j++)
{
for (int k = 0; k < kw; k++)
{
w2[(i * inch + j) * kw + k] = w[(j * outch + i) * kw + k];
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch, inch, kw}, new_weight);
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose1d, 20)

class F_conv_transpose1d_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv_transpose1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Deconvolution1D";
}

const char* name_str() const
{
return "conv_transpose1d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = weight.shape[1];
op->params["1"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

// transpose inch-outch-kw to outch-inch-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1];
const int kw = weight.shape[2];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch * inch * kw);
float* w2 = (float*)new_weight.data();

// reorder weight from inch-outch to outch-inch
for (int i = 0; i < outch; i++)
{
for (int j = 0; j < inch; j++)
{
for (int k = 0; k < kw; k++)
{
w2[(i * inch + j) * kw + k] = w[(j * outch + i) * kw + k];
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch, inch, kw}, new_weight);
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose1d_1, 20)

class F_conv_transpose1d_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv_transpose1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "DeconvolutionDepthWise1D";
}

const char* name_str() const
{
return "deconvdw1d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["0"] = weight.shape[1] * groups;
op->params["1"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = groups;

// transpose group-inch/group-outch/group-kw to group-outch/group-inch/group-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1] * groups;
const int kw = weight.shape[2];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch / groups * inch * kw);
float* w2 = (float*)new_weight.data();
const int outch_g = outch / groups;
const int inch_g = inch / groups;

for (int g = 0; g < groups; g++)
{
// reorder weight from inch-outch to outch-inch
float* wg2 = w2 + g * outch_g * inch_g * kw;
const float* wg = w + g * inch_g * outch_g * kw;
for (int i = 0; i < outch_g; i++)
{
for (int j = 0; j < inch_g; j++)
{
for (int k = 0; k < kw; k++)
{
wg2[(i * inch_g + j) * kw + k] = wg[(j * outch_g + i) * kw + k];
}
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch / groups, inch, kw}, new_weight);
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose1d_2, 21)

class F_conv_transpose1d_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv_transpose1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "DeconvolutionDepthWise1D";
}

const char* name_str() const
{
return "deconvdw1d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["0"] = weight.shape[1] * groups;
op->params["1"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = groups;

// transpose group-inch/group-outch/group-kw to group-outch/group-inch/group-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1] * groups;
const int kw = weight.shape[2];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch / groups * inch * kw);
float* w2 = (float*)new_weight.data();
const int outch_g = outch / groups;
const int inch_g = inch / groups;

for (int g = 0; g < groups; g++)
{
// reorder weight from inch-outch to outch-inch
float* wg2 = w2 + g * outch_g * inch_g * kw;
const float* wg = w + g * inch_g * outch_g * kw;
for (int i = 0; i < outch_g; i++)
{
for (int j = 0; j < inch_g; j++)
{
for (int k = 0; k < kw; k++)
{
wg2[(i * inch_g + j) * kw + k] = wg[(j * outch_g + i) * kw + k];
}
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch / groups, inch, kw}, new_weight);
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose1d_3, 21)

} // namespace ncnn

} // namespace pnnx

+ 0
- 354
tools/pnnx/src/pass_ncnn/F_conv_transpose2d.cpp View File

@@ -18,360 +18,6 @@ namespace pnnx {

namespace ncnn {

class F_conv_transpose2d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv_transpose2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Deconvolution";
}

const char* name_str() const
{
return "conv_transpose2d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[1];
op->params["1"] = weight.shape[3];
op->params["11"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[1];
op->params["12"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[1];
op->params["13"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[1];
op->params["14"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[1];
op->params["19"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

// transpose inch-outch-kh-kw to outch-inch-kh-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1];
const int kh = weight.shape[2];
const int kw = weight.shape[3];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch * inch * kh * kw);
float* w2 = (float*)new_weight.data();
const int maxk = kh * kw;

// reorder weight from inch-outch to outch-inch
for (int i = 0; i < outch; i++)
{
for (int j = 0; j < inch; j++)
{
for (int k = 0; k < maxk; k++)
{
w2[(i * inch + j) * maxk + k] = w[(j * outch + i) * maxk + k];
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch, inch, kh, kw}, new_weight);
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose2d, 20)

class F_conv_transpose2d_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv_transpose2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Deconvolution";
}

const char* name_str() const
{
return "conv_transpose2d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = weight.shape[1];
op->params["1"] = weight.shape[3];
op->params["11"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[1];
op->params["12"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[1];
op->params["13"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[1];
op->params["14"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[1];
op->params["19"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

// transpose inch-outch-kh-kw to outch-inch-kh-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1];
const int kh = weight.shape[2];
const int kw = weight.shape[3];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch * inch * kh * kw);
float* w2 = (float*)new_weight.data();
const int maxk = kh * kw;

// reorder weight from inch-outch to outch-inch
for (int i = 0; i < outch; i++)
{
for (int j = 0; j < inch; j++)
{
for (int k = 0; k < maxk; k++)
{
w2[(i * inch + j) * maxk + k] = w[(j * outch + i) * maxk + k];
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch, inch, kh, kw}, new_weight);
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose2d_1, 20)

class F_conv_transpose2d_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv_transpose2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "DeconvolutionDepthWise";
}

const char* name_str() const
{
return "deconvdw2d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["0"] = weight.shape[1] * groups;
op->params["1"] = weight.shape[3];
op->params["11"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[1];
op->params["12"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[1];
op->params["13"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[1];
op->params["14"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[1];
op->params["19"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = groups;

// transpose group-inch/group-outch/group-kh-kw to group-outch/group-inch/group-kh-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1] * groups;
const int kh = weight.shape[2];
const int kw = weight.shape[3];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch / groups * inch * kh * kw);
float* w2 = (float*)new_weight.data();
const int outch_g = outch / groups;
const int inch_g = inch / groups;
const int maxk = kh * kw;

for (int g = 0; g < groups; g++)
{
// reorder weight from inch-outch to outch-inch
float* wg2 = w2 + g * outch_g * inch_g * maxk;
const float* wg = w + g * inch_g * outch_g * maxk;
for (int i = 0; i < outch_g; i++)
{
for (int j = 0; j < inch_g; j++)
{
for (int k = 0; k < maxk; k++)
{
wg2[(i * inch_g + j) * maxk + k] = wg[(j * outch_g + i) * maxk + k];
}
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch / groups, inch, kh, kw}, new_weight);
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose2d_2, 21)

class F_conv_transpose2d_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv_transpose2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "DeconvolutionDepthWise";
}

const char* name_str() const
{
return "deconvdw2d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["0"] = weight.shape[1] * groups;
op->params["1"] = weight.shape[3];
op->params["11"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[1];
op->params["12"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[1];
op->params["13"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[1];
op->params["14"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[1];
op->params["19"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = groups;

// transpose group-inch/group-outch/group-kh-kw to group-outch/group-inch/group-kh-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1] * groups;
const int kh = weight.shape[2];
const int kw = weight.shape[3];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch / groups * inch * kh * kw);
float* w2 = (float*)new_weight.data();
const int outch_g = outch / groups;
const int inch_g = inch / groups;
const int maxk = kh * kw;

for (int g = 0; g < groups; g++)
{
// reorder weight from inch-outch to outch-inch
float* wg2 = w2 + g * outch_g * inch_g * maxk;
const float* wg = w + g * inch_g * outch_g * maxk;
for (int i = 0; i < outch_g; i++)
{
for (int j = 0; j < inch_g; j++)
{
for (int k = 0; k < maxk; k++)
{
wg2[(i * inch_g + j) * maxk + k] = wg[(j * outch_g + i) * maxk + k];
}
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch / groups, inch, kh, kw}, new_weight);
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose2d_3, 21)

} // namespace ncnn

} // namespace pnnx

+ 0
- 378
tools/pnnx/src/pass_ncnn/F_conv_transpose3d.cpp View File

@@ -18,384 +18,6 @@ namespace pnnx {

namespace ncnn {

class F_conv_transpose3d : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv_transpose3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Deconvolution3D";
}

const char* name_str() const
{
return "conv_transpose3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[1];
op->params["1"] = weight.shape[4];
op->params["11"] = weight.shape[3];
op->params["21"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[2];
op->params["12"] = captured_params.at("dilation").ai[1];
op->params["22"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[2];
op->params["13"] = captured_params.at("stride").ai[1];
op->params["23"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[2];
op->params["14"] = captured_params.at("padding").ai[1];
op->params["24"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[2];
op->params["19"] = captured_params.at("output_padding").ai[1];
op->params["20"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

// transpose inch-outch-kd-kh-kw to outch-inch-kd-kh-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1];
const int kd = weight.shape[2];
const int kh = weight.shape[3];
const int kw = weight.shape[4];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch * inch * kd * kh * kw);
float* w2 = (float*)new_weight.data();
const int maxk = kd * kh * kw;

// reorder weight from inch-outch to outch-inch
for (int i = 0; i < outch; i++)
{
for (int j = 0; j < inch; j++)
{
for (int k = 0; k < maxk; k++)
{
w2[(i * inch + j) * maxk + k] = w[(j * outch + i) * maxk + k];
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch, inch, kd, kh, kw}, new_weight);
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose3d, 20)

class F_conv_transpose3d_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv_transpose3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Deconvolution3D";
}

const char* name_str() const
{
return "conv_transpose3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = weight.shape[1];
op->params["1"] = weight.shape[4];
op->params["11"] = weight.shape[3];
op->params["21"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[2];
op->params["12"] = captured_params.at("dilation").ai[1];
op->params["22"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[2];
op->params["13"] = captured_params.at("stride").ai[1];
op->params["23"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[2];
op->params["14"] = captured_params.at("padding").ai[1];
op->params["24"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[2];
op->params["19"] = captured_params.at("output_padding").ai[1];
op->params["20"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));

// transpose inch-outch-kd-kh-kw to outch-inch-kd-kh-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1];
const int kd = weight.shape[2];
const int kh = weight.shape[3];
const int kw = weight.shape[4];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch * inch * kd * kh * kw);
float* w2 = (float*)new_weight.data();
const int maxk = kd * kh * kw;

// reorder weight from inch-outch to outch-inch
for (int i = 0; i < outch; i++)
{
for (int j = 0; j < inch; j++)
{
for (int k = 0; k < maxk; k++)
{
w2[(i * inch + j) * maxk + k] = w[(j * outch + i) * maxk + k];
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch, inch, kd, kh, kw}, new_weight);
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose3d_1, 20)

class F_conv_transpose3d_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.conv_transpose3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "DeconvolutionDepthWise3D";
}

const char* name_str() const
{
return "deconvdw3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["0"] = weight.shape[1] * groups;
op->params["1"] = weight.shape[4];
op->params["11"] = weight.shape[3];
op->params["21"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[2];
op->params["12"] = captured_params.at("dilation").ai[1];
op->params["22"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[2];
op->params["13"] = captured_params.at("stride").ai[1];
op->params["23"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[2];
op->params["14"] = captured_params.at("padding").ai[1];
op->params["24"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[2];
op->params["19"] = captured_params.at("output_padding").ai[1];
op->params["20"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 0;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = groups;

// transpose group-inch/group-outch/group-kd-kh-kw to group-outch/group-inch/group-kd-kh-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1] * groups;
const int kd = weight.shape[2];
const int kh = weight.shape[3];
const int kw = weight.shape[4];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch / groups * inch * kd * kh * kw);
float* w2 = (float*)new_weight.data();
const int outch_g = outch / groups;
const int inch_g = inch / groups;
const int maxk = kd * kh * kw;

for (int g = 0; g < groups; g++)
{
// reorder weight from inch-outch to outch-inch
float* wg2 = w2 + g * outch_g * inch_g * maxk;
const float* wg = w + g * inch_g * outch_g * maxk;
for (int i = 0; i < outch_g; i++)
{
for (int j = 0; j < inch_g; j++)
{
for (int k = 0; k < maxk; k++)
{
wg2[(i * inch_g + j) * maxk + k] = wg[(j * outch_g + i) * maxk + k];
}
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch / groups, inch, kd, kh, kw}, new_weight);
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose3d_2, 21)

class F_conv_transpose3d_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.conv_transpose3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "DeconvolutionDepthWise3D";
}

const char* name_str() const
{
return "deconvdw3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

const int groups = captured_params.at("groups").i;

op->params["0"] = weight.shape[1] * groups;
op->params["1"] = weight.shape[4];
op->params["11"] = weight.shape[3];
op->params["21"] = weight.shape[2];
op->params["2"] = captured_params.at("dilation").ai[2];
op->params["12"] = captured_params.at("dilation").ai[1];
op->params["22"] = captured_params.at("dilation").ai[0];
op->params["3"] = captured_params.at("stride").ai[2];
op->params["13"] = captured_params.at("stride").ai[1];
op->params["23"] = captured_params.at("stride").ai[0];
op->params["4"] = captured_params.at("padding").ai[2];
op->params["14"] = captured_params.at("padding").ai[1];
op->params["24"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[2];
op->params["19"] = captured_params.at("output_padding").ai[1];
op->params["20"] = captured_params.at("output_padding").ai[0];
op->params["5"] = 1;
op->params["6"] = (int)(weight.data.size() / sizeof(float));
op->params["7"] = groups;

// transpose group-inch/group-outch/group-kd-kh-kw to group-outch/group-inch/group-kd-kh-kw
const int inch = weight.shape[0];
const int outch = weight.shape[1] * groups;
const int kd = weight.shape[2];
const int kh = weight.shape[3];
const int kw = weight.shape[4];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch / groups * inch * kd * kh * kw);
float* w2 = (float*)new_weight.data();
const int outch_g = outch / groups;
const int inch_g = inch / groups;
const int maxk = kd * kh * kw;

for (int g = 0; g < groups; g++)
{
// reorder weight from inch-outch to outch-inch
float* wg2 = w2 + g * outch_g * inch_g * maxk;
const float* wg = w + g * inch_g * outch_g * maxk;
for (int i = 0; i < outch_g; i++)
{
for (int j = 0; j < inch_g; j++)
{
for (int k = 0; k < maxk; k++)
{
wg2[(i * inch_g + j) * maxk + k] = wg[(j * outch_g + i) * maxk + k];
}
}
}
}
}

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch / groups, inch, kd, kh, kw}, new_weight);
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose3d_3, 21)

} // namespace ncnn

} // namespace pnnx

+ 0
- 49
tools/pnnx/src/pass_ncnn/F_group_norm.cpp View File

@@ -60,55 +60,6 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_group_norm, 20)

class F_group_norm_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.group_norm op_0 3 1 input weight bias out num_groups=%num_groups eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "GroupNorm";
}

const char* name_str() const
{
return "gn";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = captured_params.at("num_groups");
op->params["1"] = weight.shape[0];
op->params["2"] = captured_params.at("eps");
op->params["3"] = 1;

op->attrs["0"] = weight;
op->attrs["1"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_group_norm_1, 20)

} // namespace ncnn

} // namespace pnnx

+ 0
- 55
tools/pnnx/src/pass_ncnn/F_layer_norm.cpp View File

@@ -58,61 +58,6 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_layer_norm, 20)

class F_layer_norm_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.layer_norm op_0 3 1 input weight bias out normalized_shape=%normalized_shape eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "LayerNorm";
}

const char* name_str() const
{
return "ln";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

const std::vector<int>& normalized_shape = captured_params.at("normalized_shape").ai;
int affine_size = normalized_shape[0];
for (size_t i = 1; i < normalized_shape.size(); i++)
{
affine_size *= normalized_shape[i];
}

op->params["0"] = affine_size;
op->params["1"] = captured_params.at("eps");
op->params["2"] = 1;

op->attrs["0"] = weight;
op->attrs["1"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_layer_norm_1, 20)

} // namespace ncnn

} // namespace pnnx

+ 0
- 95
tools/pnnx/src/pass_ncnn/F_linear.cpp View File

@@ -18,101 +18,6 @@ namespace pnnx {

namespace ncnn {

class F_linear : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.linear op_0 2 1 input weight out bias=None
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "InnerProduct";
}

const char* name_str() const
{
return "linear";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = 0;
op->params["2"] = (int)(weight.data.size() / sizeof(float));

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_linear, 20)

class F_linear_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.linear op_0 3 1 input weight bias out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "InnerProduct";
}

const char* name_str() const
{
return "linear";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["0"] = weight.shape[0];
op->params["1"] = 1;
op->params["2"] = (int)(weight.data.size() / sizeof(float));

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_linear_1, 20)

} // namespace ncnn

} // namespace pnnx

+ 3
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -292,6 +292,9 @@ pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d)
pnnx_add_test(pnnx_fuse_linear_batchnorm1d)
pnnx_add_test(pnnx_fuse_select_to_unbind)
pnnx_add_test(pnnx_fuse_slice_to_tensor_split)
pnnx_add_test(pnnx_fuse_adjacent_reshape)
pnnx_add_test(pnnx_fuse_pad_conv1d)
pnnx_add_test(pnnx_fuse_pad_conv2d)

if(Torch_VERSION VERSION_GREATER_EQUAL "1.9")
pnnx_add_test(F_mish)


+ 61
- 0
tools/pnnx/tests/test_pnnx_fuse_adjacent_reshape.py View File

@@ -0,0 +1,61 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x = x.view(1, 1, 8).reshape(2, -1)
y = y.reshape(-1, x.size(0)).unsqueeze(1)
z = z.unsqueeze(0).unsqueeze(2).view(-1)
return x, y, z

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(8)
y = torch.rand(9, 10)
z = torch.rand(8, 9, 10)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_pnnx_fuse_adjacent_reshape.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_pnnx_fuse_adjacent_reshape.pt inputshape=[8],[9,10],[8,9,10]")

# pnnx inference
import test_pnnx_fuse_adjacent_reshape_pnnx
b = test_pnnx_fuse_adjacent_reshape_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 84
- 0
tools/pnnx/tests/test_pnnx_fuse_pad_conv1d.py View File

@@ -0,0 +1,84 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.pad_0 = nn.ConstantPad1d(2, 0.0)
self.pad_1 = nn.ReflectionPad1d(4)
self.pad_2 = nn.ReplicationPad1d(3)

self.conv_0 = nn.Conv1d(in_channels=12, out_channels=14, kernel_size=3)
self.conv_1 = nn.Conv1d(in_channels=14, out_channels=14, kernel_size=1)
self.conv_2 = nn.Conv1d(in_channels=14, out_channels=14, kernel_size=2)
self.conv_3 = nn.Conv1d(in_channels=14, out_channels=12, kernel_size=3, padding=(1,))

def forward(self, x):
x = self.pad_0(x)
x = F.pad(x, pad=(1,1))
x = self.conv_0(x)

x = self.pad_1(x)
x = self.conv_1(x)

x = F.pad(x, pad=(3,3), mode='reflect')
x = self.conv_1(x)

x = self.pad_2(x)
x = self.conv_2(x)

x = F.pad(x, pad=(1,1), mode='replicate')
x = self.conv_2(x)

x = F.pad(x, pad=(2,2))
x = self.conv_3(x)

return x

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 12, 13)

a = net(x)

# export torchscript
mod = torch.jit.trace(net, x)
mod.save("test_pnnx_pnnx_fuse_pad_conv1d.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_pnnx_pnnx_fuse_pad_conv1d.pt inputshape=[1,12,13]")

# pnnx inference
import test_pnnx_pnnx_fuse_pad_conv1d_pnnx
b = test_pnnx_pnnx_fuse_pad_conv1d_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 86
- 0
tools/pnnx/tests/test_pnnx_fuse_pad_conv2d.py View File

@@ -0,0 +1,86 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.pad_0 = nn.ConstantPad2d(2, 0.0)
self.pad_1 = nn.ReflectionPad2d(4)
self.pad_2 = nn.ReplicationPad2d(3)
self.pad_3 = nn.ZeroPad2d((1,1,0,0))

self.conv_0 = nn.Conv2d(in_channels=12, out_channels=14, kernel_size=3)
self.conv_1 = nn.Conv2d(in_channels=14, out_channels=14, kernel_size=1)
self.conv_2 = nn.Conv2d(in_channels=14, out_channels=14, kernel_size=2)
self.conv_3 = nn.Conv2d(in_channels=14, out_channels=12, kernel_size=3, padding=(1,1))

def forward(self, x):
x = self.pad_0(x)
x = F.pad(x, pad=(1,1))
x = self.conv_0(x)

x = self.pad_1(x)
x = self.conv_1(x)

x = F.pad(x, pad=(3,3,2,2), mode='reflect')
x = self.conv_1(x)

x = self.pad_2(x)
x = self.conv_2(x)

x = F.pad(x, pad=(1,1,1,1), mode='replicate')
x = self.conv_2(x)

x = self.pad_3(x)
x = F.pad(x, pad=(2,2,0,0))
x = self.conv_3(x)

return x

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 12, 13, 13)

a = net(x)

# export torchscript
mod = torch.jit.trace(net, x)
mod.save("test_pnnx_pnnx_fuse_pad_conv2d.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_pnnx_pnnx_fuse_pad_conv2d.pt inputshape=[1,12,13,13]")

# pnnx inference
import test_pnnx_pnnx_fuse_pad_conv2d_pnnx
b = test_pnnx_pnnx_fuse_pad_conv2d_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save