Browse Source

pnnx fuse multiheadattention (#4544)

* torch baddbmm

* always convert to fp32 for shape inference

* silence info on nonetype and devicetype
tags/20230517
nihui GitHub 3 years ago
parent
commit
ae4f630467
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1467 additions and 21 deletions
  1. +1
    -1
      .ci/pnnx.yml
  2. +3
    -0
      tools/pnnx/src/CMakeLists.txt
  3. +16
    -3
      tools/pnnx/src/ir.cpp
  4. +45
    -0
      tools/pnnx/src/pass_level0/convert_half_to_float.cpp
  5. +21
    -0
      tools/pnnx/src/pass_level0/convert_half_to_float.h
  6. +15
    -1
      tools/pnnx/src/pass_level0/reset_device.cpp
  7. +3
    -0
      tools/pnnx/src/pass_level0/shape_inference.cpp
  8. +8
    -1
      tools/pnnx/src/pass_level2.cpp
  9. +51
    -0
      tools/pnnx/src/pass_level2/torch_arange.cpp
  10. +44
    -0
      tools/pnnx/src/pass_level2/torch_baddbmm.cpp
  11. +1
    -1
      tools/pnnx/src/pass_level3/eliminate_noop_math.cpp
  12. +2
    -0
      tools/pnnx/src/pass_level5.cpp
  13. +2
    -2
      tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp
  14. +996
    -0
      tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
  15. +21
    -0
      tools/pnnx/src/pass_level5/fuse_multiheadattention.h
  16. +5
    -5
      tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp
  17. +6
    -6
      tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp
  18. +5
    -0
      tools/pnnx/src/pass_level5/normalize_einsum_equation.cpp
  19. +1
    -1
      tools/pnnx/src/pass_ncnn/fuse_convert_shufflechannel_slice.cpp
  20. +1
    -0
      tools/pnnx/tests/CMakeLists.txt
  21. +220
    -0
      tools/pnnx/tests/test_pnnx_fuse_multiheadattention.py

+ 1
- 1
.ci/pnnx.yml View File

@@ -62,7 +62,7 @@ jobs:
python3 -m pip install --upgrade pip
apt-get remove -y python3-setuptools
pip3 install -U setuptools
pip3 install -U pytest wheel twine requests
pip3 install -U pytest wheel twine requests einops

- name: setup pytorch
run: |


+ 3
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -3,6 +3,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})

set(pnnx_pass_level0_SRCS
pass_level0/constant_unpooling.cpp
pass_level0/convert_half_to_float.cpp
pass_level0/inline_block.cpp
pass_level0/reset_device.cpp
pass_level0/flatten_input.cpp
@@ -194,6 +195,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_arange.cpp
pass_level2/torch_argmax.cpp
pass_level2/torch_argmin.cpp
pass_level2/torch_baddbmm.cpp
pass_level2/torch_bmm.cpp
pass_level2/torch_bitwise_not.cpp
pass_level2/torch_bitwise_and.cpp
@@ -324,6 +326,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_pad_conv1d.cpp
pass_level5/fuse_pad_conv2d.cpp
pass_level5/fuse_pixel_unshuffle.cpp
pass_level5/fuse_multiheadattention.cpp
pass_level5/fuse_select_to_unbind.cpp
pass_level5/fuse_slice_copy.cpp
pass_level5/fuse_slice_indices.cpp


+ 16
- 3
tools/pnnx/src/ir.cpp View File

@@ -160,9 +160,16 @@ Parameter::Parameter(const torch::jit::Node* value_node)

if (value_node->kind() == c10::prim::Constant)
{
if (value_node->output()->type()->kind() == c10::TypeKind::NoneType)
{
type = 0;
return;
}

if (!value_node->hasAttribute(torch::jit::attr::value))
{
fprintf(stderr, "no attribute value\n");
value_node->dump();
return;
}

@@ -200,6 +207,12 @@ Parameter::Parameter(const torch::jit::Node* value_node)
s = value_node->s(torch::jit::attr::value);
break;
}
case c10::TypeKind::DeviceObjType:
{
type = 4;
s = value_node->s(torch::jit::attr::value);
break;
}
case c10::TypeKind::TensorType:
{
at::Tensor t = value_node->t(torch::jit::attr::value);
@@ -246,7 +259,7 @@ Parameter::Parameter(const torch::jit::Node* value_node)
}
default:
{
fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString());
fprintf(stderr, "unknown Parameter value kind %s\n", c10::typeKindToString(value_node->output()->type()->kind()));
break;
}
}
@@ -284,14 +297,14 @@ Parameter::Parameter(const torch::jit::Node* value_node)
}
default:
{
fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString());
fprintf(stderr, "unknown Parameter value list element kind %s\n", c10::typeKindToString(value_node->output()->type()->cast<c10::ListType>()->getElementType()->kind()));
break;
}
}
}
else
{
fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString());
fprintf(stderr, "unknown Parameter value_node kind %s\n", value_node->kind().toDisplayString());
}
}



+ 45
- 0
tools/pnnx/src/pass_level0/convert_half_to_float.cpp View File

@@ -0,0 +1,45 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "convert_half_to_float.h"

namespace pnnx {

// Replace every fp16 tensor attribute of `mod` and of all its descendant
// modules with an fp32 copy.
//
// Child modules are handled first through explicit recursion; afterwards the
// attributes of `mod` are enumerated with named_attributes(false) and each
// attribute whose type kind is TensorType and whose scalar type is Half is
// overwritten via setattr with the result of toType(Float).
// Attributes of any other type or dtype are left untouched.
void convert_half_to_float(torch::jit::Module& mod)
{
    for (auto child : mod.children())
        convert_half_to_float(child);

    for (auto entry : mod.named_attributes(false))
    {
        auto value = entry.value;

        // only tensor attributes can carry a dtype
        if (value.type()->kind() != c10::TypeKind::TensorType)
            continue;

        auto tensor = value.toTensor();
        if (tensor.scalar_type() != c10::ScalarType::Half)
            continue;

        const std::string& attr_name = entry.name;
        at::Tensor tensor_fp32 = tensor.toType(c10::ScalarType::Float);
        mod.setattr(attr_name, tensor_fp32);
    }
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level0/convert_half_to_float.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include <torch/script.h>

namespace pnnx {

// Convert every half-precision (fp16) tensor attribute of mod, and of its
// child modules recursively, to float (fp32) in place.
void convert_half_to_float(torch::jit::Module& mod);

} // namespace pnnx

+ 15
- 1
tools/pnnx/src/pass_level0/reset_device.cpp View File

@@ -20,8 +20,22 @@ void reset_device(std::shared_ptr<torch::jit::Graph>& graph, const std::string&
{
for (torch::jit::Node* n : graph->nodes())
{
if (n->kind().toDisplayString() == std::string("aten::to"))
if (n->kind().is_aten())
{
if (n->hasNamedInput("dtype"))
{
torch::jit::Node* dtype_node = n->namedInput("dtype")->node();

if (dtype_node->hasAttribute(torch::jit::attr::value))
{
// change dtype=half to dtype=float
if (dtype_node->i(torch::jit::attr::value) == 5)
{
dtype_node->i_(torch::jit::attr::value, 6);
}
}
}

if (n->hasNamedInput("device"))
{
torch::jit::Node* device_node = n->namedInput("device")->node();


+ 3
- 0
tools/pnnx/src/pass_level0/shape_inference.cpp View File

@@ -17,6 +17,7 @@

#include "storezip.h"
#include "pass_level0/constant_unpooling.h"
#include "pass_level0/convert_half_to_float.h"
#include "pass_level0/flatten_input.h"
#include "pass_level0/inline_block.h"
#include "pass_level0/reset_device.h"
@@ -157,6 +158,8 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
torch::jit::Module mod2 = torch::jit::load(ptpath, (device == "gpu") ? c10::kCUDA : c10::kCPU);
mod2.eval();

convert_half_to_float(mod2);

auto graph2 = mod2.get_method("forward").graph();

inline_block(graph2, module_operators);


+ 8
- 1
tools/pnnx/src/pass_level2.cpp View File

@@ -79,8 +79,15 @@ static bool match_parameter(const Parameter& a, const Parameter& b, std::map<std
{
if (b.type == 4 && b.s[0] == '%')
{
std::string key = b.s.substr(1);
if (captured_params.find(key) != captured_params.end())
{
// match previous captured parameter
return captured_params.at(key) == a;
}

// captured parameter
captured_params[b.s.substr(1)] = a;
captured_params[key] = a;
return true;
}



+ 51
- 0
tools/pnnx/src/pass_level2/torch_arange.cpp View File

@@ -68,4 +68,55 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_arange_1, 20)

// Graph rewriter pattern for the six-input form of aten::arange:
//   (start, end, dtype, layout, device, pin_memory)
// start and end are pattern inputs; dtype, layout, device and pin_memory must
// each come from their own prim::Constant, whose value is matched with the
// wildcard "*". The matched subgraph is reported as "torch.arange".
class torch_arange_2 : public GraphRewriterPass
{
public:
// Pattern text in PNNXIR form: 8 operators and 7 operands.
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
8 7
pnnx.Input input_0 0 1 start
pnnx.Input input_1 0 1 end
prim::Constant op_0 0 1 dtype value=*
prim::Constant op_1 0 1 layout value=*
prim::Constant op_2 0 1 device value=*
prim::Constant op_3 0 1 pin_memory value=*
aten::arange op_4 6 1 start end dtype layout device pin_memory out
pnnx.Output output 1 0 out
)PNNXIR";
}

// Type string returned for a matched subgraph.
const char* type_str() const
{
return "torch.arange";
}
};

// NOTE(review): the second macro argument (20) is presumably the pass
// priority - confirm against the macro definition in pass_level2.h.
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_arange_2, 20)

// Graph rewriter pattern for a six-input aten::arange in which the 4th and
// 5th inputs are fed by the SAME constant: the aten::arange line lists the
// operand `layout` twice and the pattern declares no separate `device`
// constant (7 operators, 6 operands).
// NOTE(review): presumably this covers graphs where layout and device share a
// single constant node (e.g. both None) - confirm this is intentional and not
// a typo for `device`.
class torch_arange_3 : public GraphRewriterPass
{
public:
// Pattern text in PNNXIR form; every constant value is matched with "*".
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input_0 0 1 start
pnnx.Input input_1 0 1 end
prim::Constant op_0 0 1 dtype value=*
prim::Constant op_1 0 1 layout value=*
prim::Constant op_2 0 1 pin_memory value=*
aten::arange op_3 6 1 start end dtype layout layout pin_memory out
pnnx.Output output 1 0 out
)PNNXIR";
}

// Type string returned for a matched subgraph.
const char* type_str() const
{
return "torch.arange";
}
};

// NOTE(review): the second macro argument (20) is presumably the pass
// priority - confirm against the macro definition in pass_level2.h.
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_arange_3, 20)

} // namespace pnnx

+ 44
- 0
tools/pnnx/src/pass_level2/torch_baddbmm.cpp View File

@@ -0,0 +1,44 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

// Graph rewriter pattern for aten::baddbmm with five inputs, in order:
//   input, batch1, batch2, beta, alpha
// All five are pattern inputs (no constants are required), and the matched
// node is reported as "torch.baddbmm".
class torch_baddbmm : public GraphRewriterPass
{
public:
// Pattern text in PNNXIR form: 7 operators and 6 operands.
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 batch1
pnnx.Input input_2 0 1 batch2
pnnx.Input input_3 0 1 beta
pnnx.Input input_4 0 1 alpha
aten::baddbmm op_0 5 1 input batch1 batch2 beta alpha out
pnnx.Output output 1 0 out
)PNNXIR";
}

// Type string returned for a matched subgraph.
const char* type_str() const
{
return "torch.baddbmm";
}
};

// NOTE(review): the second macro argument (20) is presumably the pass
// priority - confirm against the macro definition in pass_level2.h.
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_baddbmm, 20)

} // namespace pnnx

+ 1
- 1
tools/pnnx/src/pass_level3/eliminate_noop_math.cpp View File

@@ -257,7 +257,7 @@ void eliminate_noop_math(Graph& graph)
if (!need_eliminate)
continue;

fprintf(stderr, "eliminate_noop_math %s %s\n", op->type.c_str(), op->name.c_str());
// fprintf(stderr, "eliminate_noop_math %s %s\n", op->type.c_str(), op->name.c_str());

for (auto& x : op->inputs)
{


+ 2
- 0
tools/pnnx/src/pass_level5.cpp View File

@@ -36,6 +36,7 @@
#include "pass_level5/fuse_convtranspose2d_batchnorm2d.h"
#include "pass_level5/fuse_contiguous_view.h"
#include "pass_level5/fuse_linear_batchnorm1d.h"
#include "pass_level5/fuse_multiheadattention.h"
#include "pass_level5/fuse_pad_conv1d.h"
#include "pass_level5/fuse_pad_conv2d.h"
#include "pass_level5/fuse_select_to_unbind.h"
@@ -123,6 +124,7 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons
eliminate_reshape_shape_expression(g);

fuse_channel_shuffle(g);
fuse_multiheadattention(g);

fuse_index_expression(g);



+ 2
- 2
tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp View File

@@ -45,7 +45,7 @@ pnnx.Output output 1 0 out
return "channelshuffle";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
const std::string& expr = captured_params.at("expr").s;
const std::string& expr2 = captured_params.at("expr2").s;
@@ -97,7 +97,7 @@ pnnx.Output output 1 0 out
return "channelshuffle";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// (1,2,58,28,28)
// (1,-1,28,28)


+ 996
- 0
tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp View File

@@ -0,0 +1,996 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_multiheadattention.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

// Approximate float comparison.
//
// Returns true when a and b are exactly equal, when their absolute difference
// is at most epsilon, or when that difference is strictly smaller than
// epsilon scaled by the larger of |a| and |b| (relative error).
static bool NearlyEqual(float a, float b, float epsilon)
{
    const bool identical = (a == b);
    if (identical)
        return true;

    const float abs_diff = (float)fabs(a - b);

    // absolute error
    if (!(abs_diff > epsilon))
        return true;

    // relative error, measured against the larger magnitude
    const auto magnitude = std::max(fabs(a), fabs(b));
    return abs_diff < epsilon * magnitude;
}

// Fuses a self-attention subgraph built from one packed qkv nn.Linear into a
// single nn.MultiheadAttention operator.
//
// Matched structure (see the pattern below): Linear(qkv) -> reshape ->
// permute(2,0,3,1,4) -> unbind(dim=0) into q/k/v -> scaled q @ permuted k ->
// softmax(dim=-1) -> @ v -> permute(0,2,1,3) -> reshape -> Linear(out_proj).
class fuse_multiheadattention_pass : public GraphRewriterPass
{
public:
    // Pattern text in PNNXIR form; %name entries are captured parameters and
    // @bias / @weight capture the attributes of the two nn.Linear operators.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
14 13
pnnx.Input input 0 1 input
nn.Linear op_0 1 1 input 1 bias=%qkv_bias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
Tensor.reshape op_1 1 1 1 2 shape=%shape
torch.permute op_2 1 1 2 3 dims=(2,0,3,1,4)
torch.unbind op_3 1 3 3 4 5 6 dim=0
pnnx.Expression op_4 1 1 4 7 expr=%expr
torch.permute op_5 1 1 5 8 dims=(0,1,3,2)
torch.matmul op_6 2 1 7 8 9
F.softmax op_7 1 1 9 10 dim=-1
torch.matmul op_8 2 1 10 6 11
torch.permute op_9 1 1 11 12 dims=(0,2,1,3)
Tensor.reshape op_10 1 1 12 13 shape=%shape2
nn.Linear out_proj 1 1 13 out bias=%out_proj_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Type string of the operator that replaces the matched subgraph.
    const char* type_str() const
    {
        return "nn.MultiheadAttention";
    }

    // Name string of the operator that replaces the matched subgraph.
    const char* name_str() const
    {
        return "attention";
    }

    // Extra validation of the captured parameters.
    //
    // Returns true only when
    //  - the qkv linear produces exactly 3 * embed_dim features,
    //  - shape has 5 entries and shape2 has 3, with shape[0] == shape2[0],
    //    shape[2] == 3 and shape[3] * shape[4] == shape2[2],
    //  - num_heads (shape[3]) is positive and divides embed_dim,
    //  - expr has the form "mul(@0,<scale>)" with scale nearly equal to
    //    1 / sqrt(embed_dim / num_heads) (tolerance 0.001).
    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const int embed_dim = captured_params.at("embed_dim").i;
        const int qkv_out_features = captured_params.at("qkv_out_features").i;
        if (qkv_out_features != embed_dim * 3)
            return false;

        // (1,-1,3,4,16)
        // (1,-1,64)
        const std::vector<int>& shape = captured_params.at("shape").ai;
        const std::vector<int>& shape2 = captured_params.at("shape2").ai;
        if (shape.size() != 5 || shape2.size() != 3)
            return false;

        const int num_heads = shape[3];
        if (shape[0] != shape2[0] || shape[2] != 3 || shape[3] * shape[4] != shape2[2])
            return false;

        // embed_dim is divided by num_heads below; a zero head count would be
        // an integer division by zero, and a head count that does not divide
        // embed_dim cannot describe a multi-head attention layout
        if (num_heads <= 0 || embed_dim % num_heads != 0)
            return false;

        // mul(@0,2.581989e-01)
        const std::string& expr = captured_params.at("expr").s;
        float inv_sqrt_embed_dim_per_head = 0.f;
        int nscan = sscanf(expr.c_str(), "mul(@0,%f)", &inv_sqrt_embed_dim_per_head);
        if (nscan != 1)
            return false;

        if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001))
            return false;

        return true;
    }

    // Fill params and attrs of the replacement operator from the captures.
    //
    // The fused operator is marked as biased when either linear has a bias;
    // whichever bias is missing in that case is synthesized as zeros so both
    // in_proj_bias and out_proj.bias exist.
    void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
    {
        int num_heads = captured_params.at("shape").ai[3];

        bool qkv_bias = captured_params.at("qkv_bias").b;
        bool out_proj_bias = captured_params.at("out_proj_bias").b;
        bool bias = qkv_bias || out_proj_bias;

        op->params["num_heads"] = num_heads;
        op->params["batch_first"] = true;
        op->params["add_zero_attn"] = false;
        op->params["add_bias_kv"] = false;
        op->params["bias"] = bias;

        int embed_dim = captured_params.at("embed_dim").i;

        op->params["embed_dim"] = embed_dim;
        op->params["kdim"] = embed_dim;
        op->params["vdim"] = embed_dim;

        // the packed qkv weight is reused as-is
        op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight");
        if (bias)
        {
            if (qkv_bias)
            {
                op->attrs["in_proj_bias"] = captured_attrs.at("op_0.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["in_proj_bias"] = Attribute();
                op->attrs["in_proj_bias"].type = 1;
                op->attrs["in_proj_bias"].shape = {embed_dim * 3};

                op->attrs["in_proj_bias"].data.resize(embed_dim * 3 * sizeof(float));
                memset(op->attrs["in_proj_bias"].data.data(), 0, embed_dim * 3 * sizeof(float));
            }
        }

        op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight");
        if (bias)
        {
            if (out_proj_bias)
            {
                op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["out_proj.bias"] = Attribute();
                op->attrs["out_proj.bias"].type = 1;
                op->attrs["out_proj.bias"].shape = {embed_dim};

                op->attrs["out_proj.bias"].data.resize(embed_dim * sizeof(float));
                memset(op->attrs["out_proj.bias"].data.data(), 0, embed_dim * sizeof(float));
            }
        }
    }
};

class fuse_multiheadattention_pass_sameqkv : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
23 22
pnnx.Input input 0 1 input
nn.Linear op_0 1 1 input 31 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input 32 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input 33 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
pnnx.Expression op_3 1 1 32 34 expr=%expr
Tensor.reshape op_4 1 1 31 35 shape=%q_shape
Tensor.reshape op_5 1 1 34 36 shape=%kv_shape
Tensor.reshape op_6 1 1 33 37 shape=%kv_shape
torch.permute op_7 1 1 36 38 dims=(0,2,1,3)
Tensor.reshape op_8 1 1 38 39 shape=%kv_shape2
torch.permute op_9 1 1 35 40 dims=(0,2,1,3)
Tensor.reshape op_10 1 1 40 41 shape=%q_shape2
torch.permute op_11 1 1 39 42 dims=(0,2,1)
torch.matmul op_12 2 1 41 42 43
F.softmax op_13 1 1 43 44 dim=-1
torch.permute op_14 1 1 37 45 dims=(0,2,1,3)
Tensor.reshape op_15 1 1 45 46 shape=%kv_shape2
torch.matmul op_16 2 1 44 46 47
Tensor.reshape op_18 1 1 47 48 shape=%qkv_shape
torch.permute op_19 1 1 48 49 dims=(0,2,1,3)
Tensor.reshape op_20 1 1 49 50 shape=%qkv_shape2
nn.Linear out_proj 1 1 50 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.MultiheadAttention";
}

const char* name_str() const
{
return "attention";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
// q_shape = (1,q,8,40)
// kv_shape = (1,kv,8,40)
// q_shape2 = (8,q,40)
// kv_shape2 = (8,kv,40)
// qkv_shape = (1,8,q,40)
// qkv_shape2 = (1,q,320)
const std::vector<int>& q_shape = captured_params.at("q_shape").ai;
const std::vector<int>& kv_shape = captured_params.at("kv_shape").ai;
const std::vector<int>& q_shape2 = captured_params.at("q_shape2").ai;
const std::vector<int>& kv_shape2 = captured_params.at("kv_shape2").ai;
const std::vector<int>& qkv_shape = captured_params.at("qkv_shape").ai;
const std::vector<int>& qkv_shape2 = captured_params.at("qkv_shape2").ai;
if (q_shape.size() != 4 || kv_shape.size() != 4 || q_shape2.size() != 3 || kv_shape2.size() != 3 || qkv_shape.size() != 4 || qkv_shape2.size() != 3)
return false;

const int batch_size = q_shape[0];
const int q_size = q_shape[1];
const int num_heads = q_shape[2];
const int feat_per_head = q_shape[3];
const int kv_size = kv_shape[1];
if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size || qkv_shape2[0] != batch_size)
return false;

if (q_shape2[1] != q_size || kv_shape2[1] != kv_size || qkv_shape[2] != q_size || qkv_shape2[1] != q_size)
return false;

if (kv_shape[2] != num_heads || q_shape2[0] != num_heads || kv_shape2[0] != num_heads || qkv_shape[1] != num_heads)
return false;

if (kv_shape[3] != feat_per_head || q_shape2[2] != feat_per_head || kv_shape2[2] != feat_per_head || qkv_shape[3] != feat_per_head || qkv_shape2[2] != feat_per_head * num_heads)
return false;

// mul(@0,1.581139e-01)
const std::string& expr = captured_params.at("expr").s;
float inv_sqrt_embed_dim_per_head = 0.f;
int nscan = sscanf(expr.c_str(), "mul(@0,%f)", &inv_sqrt_embed_dim_per_head);
if (nscan != 1)
return false;

const int embed_dim = captured_params.at("embed_dim").i;
if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001))
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
int embed_dim = captured_params.at("embed_dim").i;
int kdim = captured_params.at("kdim").i;
int vdim = captured_params.at("vdim").i;

// (1,*,8,40)
int num_heads = captured_params.at("q_shape").ai[2];

bool q_bias = captured_params.at("q_bias").b;
bool k_bias = captured_params.at("k_bias").b;
bool v_bias = captured_params.at("v_bias").b;
bool out_bias = captured_params.at("out_bias").b;
bool bias = q_bias || k_bias || v_bias || out_bias;

op->params["embed_dim"] = embed_dim;
op->params["kdim"] = kdim;
op->params["vdim"] = vdim;

op->params["num_heads"] = num_heads;
op->params["batch_first"] = true;
op->params["add_zero_attn"] = false;
op->params["add_bias_kv"] = false;
op->params["bias"] = bias;

op->attrs["in_proj_weight"] = Attribute();
op->attrs["in_proj_weight"].type = 1;
op->attrs["in_proj_weight"].shape = {embed_dim * 3, embed_dim};
op->attrs["in_proj_weight"].data.resize(embed_dim * 3 * embed_dim * sizeof(float));

// combine qkv weight
{
float* in_proj_weight_ptr = (float*)op->attrs["in_proj_weight"].data.data();
memcpy(in_proj_weight_ptr, captured_attrs.at("op_0.weight").data.data(), embed_dim * embed_dim * sizeof(float));
in_proj_weight_ptr += embed_dim * embed_dim;
memcpy(in_proj_weight_ptr, captured_attrs.at("op_1.weight").data.data(), embed_dim * embed_dim * sizeof(float));
in_proj_weight_ptr += embed_dim * embed_dim;
memcpy(in_proj_weight_ptr, captured_attrs.at("op_2.weight").data.data(), embed_dim * embed_dim * sizeof(float));
}

op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight");

if (bias)
{
op->attrs["in_proj_bias"] = Attribute();
op->attrs["in_proj_bias"].type = 1;
op->attrs["in_proj_bias"].shape = {embed_dim * 3};
op->attrs["in_proj_bias"].data.resize(embed_dim * 3 * sizeof(float));

// combine qkv bias
{
float* in_proj_bias_ptr = (float*)op->attrs["in_proj_bias"].data.data();
if (q_bias)
{
memcpy(in_proj_bias_ptr, captured_attrs.at("op_0.bias").data.data(), embed_dim * sizeof(float));
}
else
{
memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float));
}
in_proj_bias_ptr += embed_dim;
if (k_bias)
{
memcpy(in_proj_bias_ptr, captured_attrs.at("op_1.bias").data.data(), embed_dim * sizeof(float));
}
else
{
memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float));
}
in_proj_bias_ptr += embed_dim;
if (v_bias)
{
memcpy(in_proj_bias_ptr, captured_attrs.at("op_2.bias").data.data(), embed_dim * sizeof(float));
}
else
{
memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float));
}
}

if (out_bias)
{
op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias");
}
else
{
// init bias as zero
op->attrs["out_proj.bias"] = Attribute();
op->attrs["out_proj.bias"].type = 1;
op->attrs["out_proj.bias"].shape = {embed_dim};

op->attrs["out_proj.bias"].data.resize(embed_dim * sizeof(float));
memset(op->attrs["out_proj.bias"].data.data(), 0, embed_dim * sizeof(float));
}
}
}
};

// Fuses an attention subgraph with three distinct inputs (q, k, v), each fed
// through its own nn.Linear, into a single nn.MultiheadAttention operator
// that keeps separate q/k/v projection weights.
//
// In this pattern the scaling expression is applied to the k projection
// before the reshapes.
class fuse_multiheadattention_pass_qkv : public GraphRewriterPass
{
public:
    // Pattern text in PNNXIR form; %name entries are captured parameters and
    // @bias / @weight capture the attributes of the nn.Linear operators.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
25 24
pnnx.Input input_q 0 1 q
pnnx.Input input_k 0 1 k
pnnx.Input input_v 0 1 v
nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 k 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 v 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
pnnx.Expression op_3 1 1 33 35 expr=%expr
Tensor.reshape op_4 1 1 32 36 shape=%q_shape
Tensor.reshape op_5 1 1 35 37 shape=%kv_shape
Tensor.reshape op_6 1 1 34 38 shape=%kv_shape
torch.permute op_7 1 1 37 39 dims=(0,2,1,3)
Tensor.reshape op_8 1 1 39 40 shape=%kv_shape2
torch.permute op_9 1 1 36 41 dims=(0,2,1,3)
Tensor.reshape op_10 1 1 41 42 shape=%q_shape2
torch.permute op_11 1 1 40 43 dims=(0,2,1)
torch.matmul op_12 2 1 42 43 44
F.softmax op_13 1 1 44 45 dim=-1
torch.permute op_14 1 1 38 46 dims=(0,2,1,3)
Tensor.reshape op_15 1 1 46 47 shape=%kv_shape2
torch.matmul op_16 2 1 45 47 48
Tensor.reshape op_17 1 1 48 49 shape=%qkv_shape
torch.permute op_18 1 1 49 50 dims=(0,2,1,3)
Tensor.reshape op_19 1 1 50 51 shape=%qkv_shape2
nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Type string of the operator that replaces the matched subgraph.
    const char* type_str() const
    {
        return "nn.MultiheadAttention";
    }

    // Name string of the operator that replaces the matched subgraph.
    const char* name_str() const
    {
        return "attention";
    }

    // Extra validation of the captured parameters.
    //
    // Returns true only when the six captured reshape shapes have the
    // expected ranks and agree on batch size, sequence sizes, head count and
    // per-head feature count, when num_heads is positive and divides
    // embed_dim, and when expr has the form "mul(@0,<scale>)" with scale
    // nearly equal to 1 / sqrt(embed_dim / num_heads) (tolerance 0.001).
    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        // q_shape = (1,q,8,40)
        // kv_shape = (1,kv,8,40)
        // q_shape2 = (8,q,40)
        // kv_shape2 = (8,kv,40)
        // qkv_shape = (1,8,q,40)
        // qkv_shape2 = (1,q,320)
        const std::vector<int>& q_shape = captured_params.at("q_shape").ai;
        const std::vector<int>& kv_shape = captured_params.at("kv_shape").ai;
        const std::vector<int>& q_shape2 = captured_params.at("q_shape2").ai;
        const std::vector<int>& kv_shape2 = captured_params.at("kv_shape2").ai;
        const std::vector<int>& qkv_shape = captured_params.at("qkv_shape").ai;
        const std::vector<int>& qkv_shape2 = captured_params.at("qkv_shape2").ai;
        if (q_shape.size() != 4 || kv_shape.size() != 4 || q_shape2.size() != 3 || kv_shape2.size() != 3 || qkv_shape.size() != 4 || qkv_shape2.size() != 3)
            return false;

        const int batch_size = q_shape[0];
        const int q_size = q_shape[1];
        const int num_heads = q_shape[2];
        const int feat_per_head = q_shape[3];
        const int kv_size = kv_shape[1];
        if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size || qkv_shape2[0] != batch_size)
            return false;

        if (q_shape2[1] != q_size || kv_shape2[1] != kv_size || qkv_shape[2] != q_size || qkv_shape2[1] != q_size)
            return false;

        if (kv_shape[2] != num_heads || q_shape2[0] != num_heads || kv_shape2[0] != num_heads || qkv_shape[1] != num_heads)
            return false;

        if (kv_shape[3] != feat_per_head || q_shape2[2] != feat_per_head || kv_shape2[2] != feat_per_head || qkv_shape[3] != feat_per_head || qkv_shape2[2] != feat_per_head * num_heads)
            return false;

        const int embed_dim = captured_params.at("embed_dim").i;

        // embed_dim is divided by num_heads below; a zero head count would be
        // an integer division by zero, and a head count that does not divide
        // embed_dim cannot describe a multi-head attention layout
        if (num_heads <= 0 || embed_dim % num_heads != 0)
            return false;

        // mul(@0,1.581139e-01)
        const std::string& expr = captured_params.at("expr").s;
        float inv_sqrt_embed_dim_per_head = 0.f;
        int nscan = sscanf(expr.c_str(), "mul(@0,%f)", &inv_sqrt_embed_dim_per_head);
        if (nscan != 1)
            return false;

        if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001))
            return false;

        return true;
    }

    // Fill params and attrs of the replacement operator from the captures.
    //
    // The q, k and v weights are stored separately as q_proj_weight,
    // k_proj_weight and v_proj_weight. When any of the four linears has a
    // bias, in_proj_bias is built by writing the q, k and v biases
    // back-to-back (zeros for a projection that has no bias), and
    // out_proj.bias is either the captured bias or zeros.
    void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
    {
        int embed_dim = captured_params.at("embed_dim").i;
        int kdim = captured_params.at("kdim").i;
        int vdim = captured_params.at("vdim").i;

        // (1,*,8,40)
        int num_heads = captured_params.at("q_shape").ai[2];

        bool q_bias = captured_params.at("q_bias").b;
        bool k_bias = captured_params.at("k_bias").b;
        bool v_bias = captured_params.at("v_bias").b;
        bool out_bias = captured_params.at("out_bias").b;
        bool bias = q_bias || k_bias || v_bias || out_bias;

        op->params["embed_dim"] = embed_dim;
        op->params["kdim"] = kdim;
        op->params["vdim"] = vdim;

        op->params["num_heads"] = num_heads;
        op->params["batch_first"] = true;
        op->params["add_zero_attn"] = false;
        op->params["add_bias_kv"] = false;
        op->params["bias"] = bias;

        op->attrs["q_proj_weight"] = captured_attrs.at("op_0.weight");
        op->attrs["k_proj_weight"] = captured_attrs.at("op_1.weight");
        op->attrs["v_proj_weight"] = captured_attrs.at("op_2.weight");
        op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight");

        if (bias)
        {
            // captured bias keys of the three projections, in q, k, v order
            const char* const proj_bias_keys[3] = {"op_0.bias", "op_1.bias", "op_2.bias"};
            const bool proj_has_bias[3] = {q_bias, k_bias, v_bias};

            op->attrs["in_proj_bias"] = Attribute();
            op->attrs["in_proj_bias"].type = 1;
            op->attrs["in_proj_bias"].shape = {embed_dim * 3};
            op->attrs["in_proj_bias"].data.resize(embed_dim * 3 * sizeof(float));

            // combine qkv bias
            {
                float* in_proj_bias_ptr = (float*)op->attrs["in_proj_bias"].data.data();
                for (int i = 0; i < 3; i++)
                {
                    if (proj_has_bias[i])
                    {
                        memcpy(in_proj_bias_ptr, captured_attrs.at(proj_bias_keys[i]).data.data(), embed_dim * sizeof(float));
                    }
                    else
                    {
                        // this projection has no bias of its own
                        memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float));
                    }
                    in_proj_bias_ptr += embed_dim;
                }
            }

            if (out_bias)
            {
                op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["out_proj.bias"] = Attribute();
                op->attrs["out_proj.bias"].type = 1;
                op->attrs["out_proj.bias"].shape = {embed_dim};

                op->attrs["out_proj.bias"].data.resize(embed_dim * sizeof(float));
                memset(op->attrs["out_proj.bias"].data.data(), 0, embed_dim * sizeof(float));
            }
        }
    }
};

// Variant of fuse_multiheadattention_pass_qkv with two pattern inputs instead
// of three: `q` feeds op_0 while a single shared input `kv` feeds both op_1
// and op_2. Only the pattern is overridden here; the class declares no other
// members, so everything else comes from fuse_multiheadattention_pass_qkv.
class fuse_multiheadattention_pass_q_samekv : public fuse_multiheadattention_pass_qkv
{
public:
// Pattern text in PNNXIR form: 24 operators and 23 operands. It captures the
// same %parameter names as the pattern of the base class.
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
24 23
pnnx.Input input_q 0 1 q
pnnx.Input input_kv 0 1 kv
nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 kv 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 kv 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
pnnx.Expression op_3 1 1 33 35 expr=%expr
Tensor.reshape op_4 1 1 32 36 shape=%q_shape
Tensor.reshape op_5 1 1 35 37 shape=%kv_shape
Tensor.reshape op_6 1 1 34 38 shape=%kv_shape
torch.permute op_7 1 1 37 39 dims=(0,2,1,3)
Tensor.reshape op_8 1 1 39 40 shape=%kv_shape2
torch.permute op_9 1 1 36 41 dims=(0,2,1,3)
Tensor.reshape op_10 1 1 41 42 shape=%q_shape2
torch.permute op_11 1 1 40 43 dims=(0,2,1)
torch.matmul op_12 2 1 42 43 44
F.softmax op_13 1 1 44 45 dim=-1
torch.permute op_14 1 1 38 46 dims=(0,2,1,3)
Tensor.reshape op_15 1 1 46 47 shape=%kv_shape2
torch.matmul op_16 2 1 45 47 48
Tensor.reshape op_17 1 1 48 49 shape=%qkv_shape
torch.permute op_18 1 1 49 50 dims=(0,2,1,3)
Tensor.reshape op_19 1 1 50 51 shape=%qkv_shape2
nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

// Variant of fuse_multiheadattention_pass_sameqkv for graphs that compute the
// attention with torch.einsum instead of permute + matmul: the scores use
// equation "ijl,ikl->ijk" and the weighted sum uses "ijl,ilk->ijk". Here the
// scaling expression is applied to the scores (output of the first einsum)
// rather than to the k projection. Only the pattern is overridden; the class
// declares no other members, so everything else comes from
// fuse_multiheadattention_pass_sameqkv.
class fuse_multiheadattention_pass_1 : public fuse_multiheadattention_pass_sameqkv
{
public:
// Pattern text in PNNXIR form: 22 operators and 21 operands. It captures the
// same %parameter names as the pattern of the base class.
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
22 21
pnnx.Input input 0 1 input
nn.Linear op_0 1 1 input 31 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input 32 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input 33 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 31 35 shape=%q_shape
Tensor.reshape op_4 1 1 32 36 shape=%kv_shape
Tensor.reshape op_5 1 1 33 37 shape=%kv_shape
torch.permute op_6 1 1 36 38 dims=(0,2,1,3)
Tensor.reshape op_7 1 1 38 39 shape=%kv_shape2
torch.permute op_8 1 1 35 40 dims=(0,2,1,3)
Tensor.reshape op_9 1 1 40 41 shape=%q_shape2
torch.einsum op_10 2 1 41 39 42 equation=ijl,ikl->ijk
pnnx.Expression op_11 1 1 42 43 expr=%expr
F.softmax op_12 1 1 43 44 dim=-1
torch.permute op_13 1 1 37 45 dims=(0,2,1,3)
Tensor.reshape op_14 1 1 45 46 shape=%kv_shape2
torch.einsum op_15 2 1 44 46 47 equation=ijl,ilk->ijk
Tensor.reshape op_16 1 1 47 48 shape=%qkv_shape
torch.permute op_17 1 1 48 49 dims=(0,2,1,3)
Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape2
nn.Linear out_proj 1 1 50 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

// Pattern variant with three separate inputs ("q", "k", "v") in which both
// batched products are written as torch.einsum.
// Only the pattern graph is overridden here; every other member is inherited
// from fuse_multiheadattention_pass_qkv.
class fuse_multiheadattention_pass_2 : public fuse_multiheadattention_pass_qkv
{
public:
// Returns the pnnx IR text of the subgraph to look for. The "24 23" header line
// is the operator count followed by the operand count of the pattern.
// Data flow of the pattern:
//   op_0 / op_1 / op_2  nn.Linear projections reading q, k and v respectively
//   op_3 .. op_9        reshape / permute / reshape of q and k into per-head layout
//   op_10               torch.einsum "ijl,ikl->ijk" over q and k
//   op_11               captured pnnx.Expression applied to the einsum result
//   op_12               F.softmax(dim=-1)
//   op_13 .. op_15      per-head layout for v, then torch.einsum "ijl,ilk->ijk"
//   op_16 .. op_18      reshape / permute / reshape back to the merged layout
//   out_proj            final nn.Linear
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
24 23
pnnx.Input input_q 0 1 q
pnnx.Input input_k 0 1 k
pnnx.Input input_v 0 1 v
nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 k 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 v 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 32 36 shape=%q_shape
Tensor.reshape op_4 1 1 33 37 shape=%kv_shape
Tensor.reshape op_5 1 1 34 38 shape=%kv_shape
torch.permute op_6 1 1 37 39 dims=(0,2,1,3)
Tensor.reshape op_7 1 1 39 40 shape=%kv_shape2
torch.permute op_8 1 1 36 41 dims=(0,2,1,3)
Tensor.reshape op_9 1 1 41 42 shape=%q_shape2
torch.einsum op_10 2 1 42 40 43 equation=ijl,ikl->ijk
pnnx.Expression op_11 1 1 43 44 expr=%expr
F.softmax op_12 1 1 44 45 dim=-1
torch.permute op_13 1 1 38 46 dims=(0,2,1,3)
Tensor.reshape op_14 1 1 46 47 shape=%kv_shape2
torch.einsum op_15 2 1 45 47 48 equation=ijl,ilk->ijk
Tensor.reshape op_16 1 1 48 49 shape=%qkv_shape
torch.permute op_17 1 1 49 50 dims=(0,2,1,3)
Tensor.reshape op_18 1 1 50 51 shape=%qkv_shape2
nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

// Pattern variant with a dedicated query input ("q") and a shared key/value
// input ("kv") in which both batched products are written as torch.einsum.
// Only the pattern graph is overridden here; every other member is inherited
// from fuse_multiheadattention_pass_q_samekv.
class fuse_multiheadattention_pass_3 : public fuse_multiheadattention_pass_q_samekv
{
public:
// Returns the pnnx IR text of the subgraph to look for. The "23 22" header line
// is the operator count followed by the operand count of the pattern.
// Data flow of the pattern:
//   op_0 / op_1 / op_2  nn.Linear projections (q from "q", k and v from "kv")
//   op_3 .. op_9        reshape / permute / reshape of q and k into per-head layout
//   op_10               torch.einsum "ijl,ikl->ijk" over q and k
//   op_11               captured pnnx.Expression applied to the einsum result
//   op_12               F.softmax(dim=-1)
//   op_13 .. op_15      per-head layout for v, then torch.einsum "ijl,ilk->ijk"
//   op_16 .. op_18      reshape / permute / reshape back to the merged layout
//   out_proj            final nn.Linear
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
23 22
pnnx.Input input_q 0 1 q
pnnx.Input input_kv 0 1 kv
nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 kv 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 kv 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 32 36 shape=%q_shape
Tensor.reshape op_4 1 1 33 37 shape=%kv_shape
Tensor.reshape op_5 1 1 34 38 shape=%kv_shape
torch.permute op_6 1 1 37 39 dims=(0,2,1,3)
Tensor.reshape op_7 1 1 39 40 shape=%kv_shape2
torch.permute op_8 1 1 36 41 dims=(0,2,1,3)
Tensor.reshape op_9 1 1 41 42 shape=%q_shape2
torch.einsum op_10 2 1 42 40 43 equation=ijl,ikl->ijk
pnnx.Expression op_11 1 1 43 44 expr=%expr
F.softmax op_12 1 1 44 45 dim=-1
torch.permute op_13 1 1 38 46 dims=(0,2,1,3)
Tensor.reshape op_14 1 1 46 47 shape=%kv_shape2
torch.einsum op_15 2 1 45 47 48 equation=ijl,ilk->ijk
Tensor.reshape op_16 1 1 48 49 shape=%qkv_shape
torch.permute op_17 1 1 49 50 dims=(0,2,1,3)
Tensor.reshape op_18 1 1 50 51 shape=%qkv_shape2
nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

// Single-input attention pattern whose scores are produced by
// torch.baddbmm(zeros, q, transpose(k), alpha=%alpha, beta=0), where "zeros" is a
// pnnx.Attribute, followed by softmax and torch.bmm with v.
// The pattern graph and the captured-parameter check are defined here; the rest
// is inherited from fuse_multiheadattention_pass_sameqkv.
class fuse_multiheadattention_pass_5 : public fuse_multiheadattention_pass_sameqkv
{
public:
    // Pattern text in pnnx IR form ("23 22" = operator count, operand count).
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
23 22
pnnx.Input input 0 1 input
nn.Linear op_0 1 1 input 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 33 36 shape=%q_shape
Tensor.reshape op_4 1 1 34 37 shape=%kv_shape
Tensor.reshape op_5 1 1 35 38 shape=%kv_shape
torch.permute op_6 1 1 36 39 dims=(0,2,1,3)
Tensor.reshape op_7 1 1 39 40 shape=%q_shape2
torch.permute op_8 1 1 37 41 dims=(0,2,1,3)
Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2
pnnx.Attribute op_10 0 1 43 @zeros
torch.transpose op_11 1 1 42 44 dim0=-1 dim1=-2
torch.baddbmm op_12 3 1 43 40 44 45 alpha=%alpha beta=0
F.softmax op_13 1 1 45 46 dim=-1
torch.permute op_14 1 1 38 47 dims=(0,2,1,3)
Tensor.reshape op_15 1 1 47 48 shape=%kv_shape2
torch.bmm op_16 2 1 46 48 49
Tensor.reshape op_17 1 1 49 50 shape=%qkv_shape
torch.permute op_18 1 1 50 51 dims=(0,2,1,3)
Tensor.reshape op_19 1 1 51 52 shape=%qkv_shape2
nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Accepts a candidate only when the six captured reshape targets agree with
    // each other and the captured baddbmm alpha is close to
    // 1 / sqrt(embed_dim / heads).
    //
    // Layouts the checks enforce:
    //   q_shape    = (batch, q_len,  heads, head_dim)
    //   kv_shape   = (batch, kv_len, heads, head_dim)
    //   q_shape2   = (heads, q_len,  head_dim)
    //   kv_shape2  = (heads, kv_len, head_dim)
    //   qkv_shape  = (batch, heads,  q_len, head_dim)
    //   qkv_shape2 = (batch, q_len,  heads * head_dim)
    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const std::vector<int>& q4 = captured_params.at("q_shape").ai;
        const std::vector<int>& kv4 = captured_params.at("kv_shape").ai;
        const std::vector<int>& q3 = captured_params.at("q_shape2").ai;
        const std::vector<int>& kv3 = captured_params.at("kv_shape2").ai;
        const std::vector<int>& out4 = captured_params.at("qkv_shape").ai;
        const std::vector<int>& out3 = captured_params.at("qkv_shape2").ai;

        // ranks first, so the indexing below stays in bounds
        const bool ranks_ok = q4.size() == 4 && kv4.size() == 4 && q3.size() == 3 && kv3.size() == 3 && out4.size() == 4 && out3.size() == 3;
        if (!ranks_ok)
            return false;

        const int batch = q4[0];
        const int q_len = q4[1];
        const int heads = q4[2];
        const int head_dim = q4[3];
        const int kv_len = kv4[1];

        const bool batch_ok = kv4[0] == batch && out4[0] == batch && out3[0] == batch;
        if (!batch_ok)
            return false;

        const bool lengths_ok = q3[1] == q_len && kv3[1] == kv_len && out4[2] == q_len && out3[1] == q_len;
        if (!lengths_ok)
            return false;

        const bool heads_ok = kv4[2] == heads && q3[0] == heads && kv3[0] == heads && out4[1] == heads;
        if (!heads_ok)
            return false;

        const bool head_dim_ok = kv4[3] == head_dim && q3[2] == head_dim && kv3[2] == head_dim && out4[3] == head_dim && out3[2] == head_dim * heads;
        if (!head_dim_ok)
            return false;

        // alpha and embed_dim are only looked up once the shapes have passed
        const float alpha = captured_params.at("alpha").f;
        const int embed_dim = captured_params.at("embed_dim").i;
        return NearlyEqual(alpha, 1.f / sqrt(embed_dim / heads), 0.001);
    }
};

// Single-input baddbmm pattern variant: the first baddbmm operand ("zeros") is
// produced by torch.empty (op_11), whose input comes from a two-input
// pnnx.Expression (op_10, reading operands 40 and 42) capturing expr_zero_shape.
// Only the pattern graph is overridden here; match() and every other member are
// inherited from fuse_multiheadattention_pass_5.
class fuse_multiheadattention_pass_6 : public fuse_multiheadattention_pass_5
{
public:
// Returns the pnnx IR text of the subgraph to look for. The "24 23" header line
// is the operator count followed by the operand count of the pattern.
// Data flow of the pattern:
//   op_0 / op_1 / op_2  nn.Linear projections, all reading "input"
//   op_3 .. op_9        reshape / permute / reshape of q and k into per-head layout
//   op_10 / op_11       pnnx.Expression -> torch.empty producing "zeros"
//   op_12 / op_13       torch.transpose of k, then torch.baddbmm(zeros, q, k^T)
//                       with captured alpha and beta=0
//   op_14 .. op_17      F.softmax(dim=-1), per-head layout for v, torch.bmm
//   op_18 .. op_20      reshape / permute / reshape back to the merged layout
//   out_proj            final nn.Linear
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
24 23
pnnx.Input input 0 1 input
nn.Linear op_0 1 1 input 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 33 36 shape=%q_shape
Tensor.reshape op_4 1 1 34 37 shape=%kv_shape
Tensor.reshape op_5 1 1 35 38 shape=%kv_shape
torch.permute op_6 1 1 36 39 dims=(0,2,1,3)
Tensor.reshape op_7 1 1 39 40 shape=%q_shape2
torch.permute op_8 1 1 37 41 dims=(0,2,1,3)
Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2
pnnx.Expression op_10 2 1 40 42 43 expr=%expr_zero_shape
torch.empty op_11 1 1 43 zeros
torch.transpose op_12 1 1 42 44 dim0=-1 dim1=-2
torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%alpha beta=0
F.softmax op_14 1 1 45 46 dim=-1
torch.permute op_15 1 1 38 47 dims=(0,2,1,3)
Tensor.reshape op_16 1 1 47 48 shape=%kv_shape2
torch.bmm op_17 2 1 46 48 49
Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape
torch.permute op_19 1 1 50 51 dims=(0,2,1,3)
Tensor.reshape op_20 1 1 51 52 shape=%qkv_shape2
nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

// Three-input (q, k, v) attention pattern whose scores are produced by
// torch.baddbmm(zeros, q, transpose(k), alpha=%alpha, beta=0), where "zeros" is a
// pnnx.Attribute, followed by softmax and torch.bmm with v.
// The pattern graph and the captured-parameter check are defined here; the rest
// is inherited from fuse_multiheadattention_pass_qkv.
class fuse_multiheadattention_pass_7 : public fuse_multiheadattention_pass_qkv
{
public:
    // Pattern text in pnnx IR form ("25 24" = operator count, operand count).
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
25 24
pnnx.Input input_q 0 1 q
pnnx.Input input_k 0 1 k
pnnx.Input input_v 0 1 v
nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 k 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 v 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 33 36 shape=%q_shape
Tensor.reshape op_4 1 1 34 37 shape=%kv_shape
Tensor.reshape op_5 1 1 35 38 shape=%kv_shape
torch.permute op_6 1 1 36 39 dims=(0,2,1,3)
Tensor.reshape op_7 1 1 39 40 shape=%q_shape2
torch.permute op_8 1 1 37 41 dims=(0,2,1,3)
Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2
pnnx.Attribute op_10 0 1 43 @zeros
torch.transpose op_11 1 1 42 44 dim0=-1 dim1=-2
torch.baddbmm op_12 3 1 43 40 44 45 alpha=%alpha beta=0
F.softmax op_13 1 1 45 46 dim=-1
torch.permute op_14 1 1 38 47 dims=(0,2,1,3)
Tensor.reshape op_15 1 1 47 48 shape=%kv_shape2
torch.bmm op_16 2 1 46 48 49
Tensor.reshape op_17 1 1 49 50 shape=%qkv_shape
torch.permute op_18 1 1 50 51 dims=(0,2,1,3)
Tensor.reshape op_19 1 1 51 52 shape=%qkv_shape2
nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Accepts a candidate only when the six captured reshape targets agree with
    // each other and the captured baddbmm alpha is close to
    // 1 / sqrt(embed_dim / heads).
    //
    // Layouts the checks enforce:
    //   q_shape    = (batch, q_len,  heads, head_dim)
    //   kv_shape   = (batch, kv_len, heads, head_dim)
    //   q_shape2   = (heads, q_len,  head_dim)
    //   kv_shape2  = (heads, kv_len, head_dim)
    //   qkv_shape  = (batch, heads,  q_len, head_dim)
    //   qkv_shape2 = (batch, q_len,  heads * head_dim)
    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const std::vector<int>& q4 = captured_params.at("q_shape").ai;
        const std::vector<int>& kv4 = captured_params.at("kv_shape").ai;
        const std::vector<int>& q3 = captured_params.at("q_shape2").ai;
        const std::vector<int>& kv3 = captured_params.at("kv_shape2").ai;
        const std::vector<int>& out4 = captured_params.at("qkv_shape").ai;
        const std::vector<int>& out3 = captured_params.at("qkv_shape2").ai;

        // ranks first, so the indexing below stays in bounds
        const bool ranks_ok = q4.size() == 4 && kv4.size() == 4 && q3.size() == 3 && kv3.size() == 3 && out4.size() == 4 && out3.size() == 3;
        if (!ranks_ok)
            return false;

        const int batch = q4[0];
        const int q_len = q4[1];
        const int heads = q4[2];
        const int head_dim = q4[3];
        const int kv_len = kv4[1];

        const bool batch_ok = kv4[0] == batch && out4[0] == batch && out3[0] == batch;
        if (!batch_ok)
            return false;

        const bool lengths_ok = q3[1] == q_len && kv3[1] == kv_len && out4[2] == q_len && out3[1] == q_len;
        if (!lengths_ok)
            return false;

        const bool heads_ok = kv4[2] == heads && q3[0] == heads && kv3[0] == heads && out4[1] == heads;
        if (!heads_ok)
            return false;

        const bool head_dim_ok = kv4[3] == head_dim && q3[2] == head_dim && kv3[2] == head_dim && out4[3] == head_dim && out3[2] == head_dim * heads;
        if (!head_dim_ok)
            return false;

        // alpha and embed_dim are only looked up once the shapes have passed
        const float alpha = captured_params.at("alpha").f;
        const int embed_dim = captured_params.at("embed_dim").i;
        return NearlyEqual(alpha, 1.f / sqrt(embed_dim / heads), 0.001);
    }
};

// baddbmm pattern variant with a dedicated query input ("q") and a shared
// key/value input ("kv"); the first baddbmm operand is a pnnx.Attribute (@zeros).
// Only the pattern graph is overridden here; match() and every other member are
// inherited from fuse_multiheadattention_pass_7.
class fuse_multiheadattention_pass_8 : public fuse_multiheadattention_pass_7
{
public:
// Returns the pnnx IR text of the subgraph to look for. The "24 23" header line
// is the operator count followed by the operand count of the pattern.
// Data flow of the pattern:
//   op_0 / op_1 / op_2  nn.Linear projections (q from "q", k and v from "kv")
//   op_3 .. op_9        reshape / permute / reshape of q and k into per-head layout
//   op_10               pnnx.Attribute supplying the first baddbmm operand
//   op_11 / op_12       torch.transpose of k, then torch.baddbmm with captured
//                       alpha and beta=0
//   op_13 .. op_16      F.softmax(dim=-1), per-head layout for v, torch.bmm
//   op_17 .. op_19      reshape / permute / reshape back to the merged layout
//   out_proj            final nn.Linear
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
24 23
pnnx.Input input_q 0 1 q
pnnx.Input input_kv 0 1 kv
nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 kv 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 kv 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 33 36 shape=%q_shape
Tensor.reshape op_4 1 1 34 37 shape=%kv_shape
Tensor.reshape op_5 1 1 35 38 shape=%kv_shape
torch.permute op_6 1 1 36 39 dims=(0,2,1,3)
Tensor.reshape op_7 1 1 39 40 shape=%q_shape2
torch.permute op_8 1 1 37 41 dims=(0,2,1,3)
Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2
pnnx.Attribute op_10 0 1 43 @zeros
torch.transpose op_11 1 1 42 44 dim0=-1 dim1=-2
torch.baddbmm op_12 3 1 43 40 44 45 alpha=%alpha beta=0
F.softmax op_13 1 1 45 46 dim=-1
torch.permute op_14 1 1 38 47 dims=(0,2,1,3)
Tensor.reshape op_15 1 1 47 48 shape=%kv_shape2
torch.bmm op_16 2 1 46 48 49
Tensor.reshape op_17 1 1 49 50 shape=%qkv_shape
torch.permute op_18 1 1 50 51 dims=(0,2,1,3)
Tensor.reshape op_19 1 1 51 52 shape=%qkv_shape2
nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

// baddbmm pattern variant with three separate inputs ("q", "k", "v"): the first
// baddbmm operand ("zeros") is produced by torch.empty (op_11), whose input comes
// from a one-input pnnx.Expression (op_10, reading operand 40) capturing
// expr_zero_shape.
// Only the pattern graph is overridden here; match() and every other member are
// inherited from fuse_multiheadattention_pass_7.
class fuse_multiheadattention_pass_9 : public fuse_multiheadattention_pass_7
{
public:
// Returns the pnnx IR text of the subgraph to look for. The "26 25" header line
// is the operator count followed by the operand count of the pattern.
// Data flow of the pattern:
//   op_0 / op_1 / op_2  nn.Linear projections reading q, k and v respectively
//   op_3 .. op_9        reshape / permute / reshape of q and k into per-head layout
//   op_10 / op_11       pnnx.Expression -> torch.empty producing "zeros"
//   op_12 / op_13       torch.transpose of k, then torch.baddbmm(zeros, q, k^T)
//                       with captured alpha and beta=0
//   op_14 .. op_17      F.softmax(dim=-1), per-head layout for v, torch.bmm
//   op_18 .. op_20      reshape / permute / reshape back to the merged layout
//   out_proj            final nn.Linear
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
26 25
pnnx.Input input_q 0 1 q
pnnx.Input input_k 0 1 k
pnnx.Input input_v 0 1 v
nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 k 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 v 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 33 36 shape=%q_shape
Tensor.reshape op_4 1 1 34 37 shape=%kv_shape
Tensor.reshape op_5 1 1 35 38 shape=%kv_shape
torch.permute op_6 1 1 36 39 dims=(0,2,1,3)
Tensor.reshape op_7 1 1 39 40 shape=%q_shape2
torch.permute op_8 1 1 37 41 dims=(0,2,1,3)
Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2
pnnx.Expression op_10 1 1 40 43 expr=%expr_zero_shape
torch.empty op_11 1 1 43 zeros
torch.transpose op_12 1 1 42 44 dim0=-1 dim1=-2
torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%alpha beta=0
F.softmax op_14 1 1 45 46 dim=-1
torch.permute op_15 1 1 38 47 dims=(0,2,1,3)
Tensor.reshape op_16 1 1 47 48 shape=%kv_shape2
torch.bmm op_17 2 1 46 48 49
Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape
torch.permute op_19 1 1 50 51 dims=(0,2,1,3)
Tensor.reshape op_20 1 1 51 52 shape=%qkv_shape2
nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

// baddbmm pattern variant with a dedicated query input ("q") and a shared
// key/value input ("kv"): the first baddbmm operand ("zeros") is produced by
// torch.empty (op_11), whose input comes from a one-input pnnx.Expression
// (op_10, reading operand 40) capturing expr_zero_shape.
// Only the pattern graph is overridden here; match() and every other member are
// inherited from fuse_multiheadattention_pass_7.
class fuse_multiheadattention_pass_10 : public fuse_multiheadattention_pass_7
{
public:
// Returns the pnnx IR text of the subgraph to look for. The "25 24" header line
// is the operator count followed by the operand count of the pattern.
// Data flow of the pattern:
//   op_0 / op_1 / op_2  nn.Linear projections (q from "q", k and v from "kv")
//   op_3 .. op_9        reshape / permute / reshape of q and k into per-head layout
//   op_10 / op_11       pnnx.Expression -> torch.empty producing "zeros"
//   op_12 / op_13       torch.transpose of k, then torch.baddbmm(zeros, q, k^T)
//                       with captured alpha and beta=0
//   op_14 .. op_17      F.softmax(dim=-1), per-head layout for v, torch.bmm
//   op_18 .. op_20      reshape / permute / reshape back to the merged layout
//   out_proj            final nn.Linear
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
25 24
pnnx.Input input_q 0 1 q
pnnx.Input input_kv 0 1 kv
nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 kv 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 kv 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 33 36 shape=%q_shape
Tensor.reshape op_4 1 1 34 37 shape=%kv_shape
Tensor.reshape op_5 1 1 35 38 shape=%kv_shape
torch.permute op_6 1 1 36 39 dims=(0,2,1,3)
Tensor.reshape op_7 1 1 39 40 shape=%q_shape2
torch.permute op_8 1 1 37 41 dims=(0,2,1,3)
Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2
pnnx.Expression op_10 1 1 40 43 expr=%expr_zero_shape
torch.empty op_11 1 1 43 zeros
torch.transpose op_12 1 1 42 44 dim0=-1 dim1=-2
torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%alpha beta=0
F.softmax op_14 1 1 45 46 dim=-1
torch.permute op_15 1 1 38 47 dims=(0,2,1,3)
Tensor.reshape op_16 1 1 47 48 shape=%kv_shape2
torch.bmm op_17 2 1 46 48 49
Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape
torch.permute op_19 1 1 50 51 dims=(0,2,1,3)
Tensor.reshape op_20 1 1 51 52 shape=%qkv_shape2
nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

// Applies every multi-head-attention pattern rewriter to graph, one after the
// other in a fixed order. The same opindex variable is handed to each
// pnnx_graph_rewrite call.
void fuse_multiheadattention(Graph& graph)
{
    // base patterns
    fuse_multiheadattention_pass rewriter_base;
    fuse_multiheadattention_pass_sameqkv rewriter_sameqkv;
    fuse_multiheadattention_pass_qkv rewriter_qkv;
    fuse_multiheadattention_pass_q_samekv rewriter_q_samekv;

    // torch.einsum based patterns
    fuse_multiheadattention_pass_1 rewriter_einsum_sameqkv;
    fuse_multiheadattention_pass_2 rewriter_einsum_qkv;
    fuse_multiheadattention_pass_3 rewriter_einsum_q_samekv;

    // torch.baddbmm based patterns
    fuse_multiheadattention_pass_5 rewriter_baddbmm_5;
    fuse_multiheadattention_pass_6 rewriter_baddbmm_6;
    fuse_multiheadattention_pass_7 rewriter_baddbmm_7;
    fuse_multiheadattention_pass_8 rewriter_baddbmm_8;
    fuse_multiheadattention_pass_9 rewriter_baddbmm_9;
    fuse_multiheadattention_pass_10 rewriter_baddbmm_10;

    int opindex = 0;

    pnnx_graph_rewrite(graph, &rewriter_base, opindex);
    pnnx_graph_rewrite(graph, &rewriter_sameqkv, opindex);
    pnnx_graph_rewrite(graph, &rewriter_qkv, opindex);
    pnnx_graph_rewrite(graph, &rewriter_q_samekv, opindex);

    pnnx_graph_rewrite(graph, &rewriter_einsum_sameqkv, opindex);
    pnnx_graph_rewrite(graph, &rewriter_einsum_qkv, opindex);
    pnnx_graph_rewrite(graph, &rewriter_einsum_q_samekv, opindex);

    pnnx_graph_rewrite(graph, &rewriter_baddbmm_5, opindex);
    pnnx_graph_rewrite(graph, &rewriter_baddbmm_6, opindex);
    pnnx_graph_rewrite(graph, &rewriter_baddbmm_7, opindex);
    pnnx_graph_rewrite(graph, &rewriter_baddbmm_8, opindex);
    pnnx_graph_rewrite(graph, &rewriter_baddbmm_9, opindex);
    pnnx_graph_rewrite(graph, &rewriter_baddbmm_10, opindex);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_multiheadattention.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

// Applies the multi-head-attention pattern rewriters (implemented in
// fuse_multiheadattention.cpp) to graph. The graph is taken by mutable
// reference and handed to the rewriters.
void fuse_multiheadattention(Graph& graph);

} // namespace pnnx

+ 5
- 5
tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp View File

@@ -45,7 +45,7 @@ pnnx.Output output 1 0 out
return "conv1d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// constant-0 + zeros
float pad_value = 0.f;
@@ -122,7 +122,7 @@ pnnx.Output output 1 0 out
return "conv1d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// reflect/replicate + nopad
if (captured_params.at("mode").s != "reflect" && captured_params.at("mode").s != "replicate")
@@ -193,7 +193,7 @@ pnnx.Output output 1 0 out
return "conv1d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// constant-0 + zeros
float pad_value = 0.f;
@@ -270,7 +270,7 @@ pnnx.Output output 1 0 out
return "conv1d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// replicate + nopad
const std::vector<int>& pad = captured_params.at("pad").ai;
@@ -338,7 +338,7 @@ pnnx.Output output 1 0 out
return "conv1d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// reflect + nopad
const std::vector<int>& pad = captured_params.at("pad").ai;


+ 6
- 6
tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp View File

@@ -45,7 +45,7 @@ pnnx.Output output 1 0 out
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// constant-0 + zeros
float pad_value = 0.f;
@@ -134,7 +134,7 @@ pnnx.Output output 1 0 out
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// reflect/replicate + nopad
if (captured_params.at("mode").s != "reflect" && captured_params.at("mode").s != "replicate")
@@ -218,7 +218,7 @@ pnnx.Output output 1 0 out
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// constant-0 + zeros
float pad_value = 0.f;
@@ -296,7 +296,7 @@ pnnx.Output output 1 0 out
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// constant-0 + zeros
const std::vector<int>& pad = captured_params.at("pad").ai;
@@ -365,7 +365,7 @@ pnnx.Output output 1 0 out
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// replicate + nopad
const std::vector<int>& pad = captured_params.at("pad").ai;
@@ -434,7 +434,7 @@ pnnx.Output output 1 0 out
return "conv2d";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// reflect + nopad
const std::vector<int>& pad = captured_params.at("pad").ai;


+ 5
- 0
tools/pnnx/src/pass_level5/normalize_einsum_equation.cpp View File

@@ -14,6 +14,7 @@

#include "normalize_einsum_equation.h"

#include <ctype.h>
#include <algorithm>
#include <map>
#include <vector>
@@ -41,6 +42,10 @@ void normalize_einsum_equation(Graph& graph)
continue;

std::string equation = op->params.at("equation").s;

// remove all spaces
equation.erase(std::remove_if(equation.begin(), equation.end(), isspace), equation.end());

size_t equation_len = equation.size();

std::map<char, char> xset;


+ 1
- 1
tools/pnnx/src/pass_ncnn/fuse_convert_shufflechannel_slice.cpp View File

@@ -54,7 +54,7 @@ pnnx.Output output 2 0 out0 out1
return "shufflechannel_slice";
}

bool match_captured_params_attrs(const std::map<std::string, Parameter>& captured_params) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
// (116,2,1024)
// (1,0,2)


+ 1
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -293,6 +293,7 @@ pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d)
pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d)
pnnx_add_test(pnnx_fuse_input_unpack)
pnnx_add_test(pnnx_fuse_linear_batchnorm1d)
pnnx_add_test(pnnx_fuse_multiheadattention)
pnnx_add_test(pnnx_fuse_select_to_unbind)
pnnx_add_test(pnnx_fuse_slice_to_tensor_split)
pnnx_add_test(pnnx_fuse_adjacent_reshape)


+ 220
- 0
tools/pnnx/tests/test_pnnx_fuse_multiheadattention.py View File

@@ -0,0 +1,220 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

class Attention(nn.Module):
    """Self-attention that projects q, k and v with one packed linear layer.

    embed_dim is the input/output feature size, num_heads the number of heads,
    and qkv_bias toggles the bias of the packed projection.
    """

    def __init__(self, embed_dim, num_heads, qkv_bias=True):
        super().__init__()
        self.num_heads = num_heads
        # scaling applied to q before the dot product
        self.scale = (embed_dim // num_heads) ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        _, seq_len, channels = x.shape

        # (3, batch, heads, seq_len, head_dim)
        packed = self.qkv(x).reshape((-1, seq_len, 3, self.num_heads, channels // self.num_heads))
        packed = packed.permute((2, 0, 3, 1, 4))
        q = packed[0] * self.scale
        k = packed[1]
        v = packed[2]

        scores = q.matmul(k.permute((0, 1, 3, 2)))
        weights = F.softmax(scores, dim=-1)

        # merge the heads back into the channel dimension
        merged = weights.matmul(v).permute((0, 2, 1, 3)).reshape((-1, seq_len, channels))
        return self.proj(merged)

class CrossAttention(nn.Module):
    """Attention with bias-free q/k/v projections and einsum-based products.

    When context_dim is None the key/value projections take query_dim features.
    When forward() receives no context, x is used for keys and values as well.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None):
        num_heads = self.heads

        def split_heads(t):
            # fold the head axis into the batch axis
            return rearrange(t, 'b n (h d) -> (b h) n d', h=num_heads)

        source = x if context is None else context
        q = self.to_q(x)
        k = self.to_k(source)
        v = self.to_v(source)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        scores = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        # softmax over the key axis
        attn = scores.softmax(dim=-1)

        merged = torch.einsum('b i j, b j d -> b i d', attn, v)
        merged = rearrange(merged, '(b h) n d -> b n (h d)', h=num_heads)
        return self.to_out(merged)

class diffusers_CrossAttnProcessor:
    """Callable that runs the attention computation using the layers of attn.

    attn supplies to_q / to_k / to_v, head_to_batch_dim, get_attention_scores,
    batch_to_head_dim and the two-element to_out. When encoder_hidden_states is
    None, hidden_states is used for keys and values as well.
    """

    def __call__(self, attn, hidden_states, encoder_hidden_states=None):
        # kept from the original: the unpack requires a three-dimensional input
        _, _, _ = hidden_states.shape

        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        probs = attn.get_attention_scores(query, key)
        out = torch.bmm(probs, value)
        out = attn.batch_to_head_dim(out)

        # to_out[0] is the linear projection, to_out[1] the dropout
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out

class diffusers_CrossAttention(nn.Module):
    """Attention module whose forward pass is delegated to a processor object.

    Scores are computed with torch.baddbmm on an uninitialised buffer and beta=0.
    When cross_attention_dim is None the key/value projections take query_dim
    features.
    """

    def __init__(self, query_dim, cross_attention_dim=None, heads=8, dim_head=64, dropout=0.0, bias=False):
        super().__init__()
        inner_dim = dim_head * heads
        if cross_attention_dim is None:
            cross_attention_dim = query_dim

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

        # index 0: output projection, index 1: dropout
        self.to_out = nn.ModuleList([
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        ])

        self.processor = diffusers_CrossAttnProcessor()

    def forward(self, hidden_states, encoder_hidden_states=None):
        return self.processor(self, hidden_states, encoder_hidden_states=encoder_hidden_states)

    def batch_to_head_dim(self, tensor):
        # (batch * heads, seq, dim) -> (batch, seq, dim * heads)
        heads = self.heads
        merged_batch, seq_len, dim = tensor.shape
        tensor = tensor.reshape(merged_batch // heads, heads, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor.reshape(merged_batch // heads, seq_len, dim * heads)

    def head_to_batch_dim(self, tensor):
        # (batch, seq, dim) -> (batch * heads, seq, dim // heads)
        heads = self.heads
        batch, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch, seq_len, heads, dim // heads)
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor.reshape(batch * heads, seq_len, dim // heads)

    def get_attention_scores(self, query, key):
        out_dtype = query.dtype

        # uninitialised buffer; baddbmm is called with beta=0 below
        placeholder = torch.empty(
            query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
        )

        scores = torch.baddbmm(
            placeholder,
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )

        return scores.softmax(dim=-1).to(out_dtype)

class Model(nn.Module):
    """Test network chaining two instances of each attention implementation.

    The second module of the CrossAttention and diffusers_CrossAttention pairs
    takes the 17-feature input y as key/value context.
    """

    def __init__(self):
        super().__init__()

        # packed-qkv attention, with and without qkv bias
        self.attention_0_0 = Attention(embed_dim=64, num_heads=4)
        self.attention_0_1 = Attention(embed_dim=64, num_heads=8, qkv_bias=False)

        # einsum attention, self then cross
        self.attention_1_0 = CrossAttention(query_dim=64, heads=4, dim_head=16)
        self.attention_1_1 = CrossAttention(query_dim=64, heads=8, dim_head=8, context_dim=17)

        # baddbmm attention, self then cross
        self.attention_2_0 = diffusers_CrossAttention(query_dim=64, heads=4, dim_head=16)
        self.attention_2_1 = diffusers_CrossAttention(query_dim=64, heads=8, dim_head=8, cross_attention_dim=17)

    def forward(self, x, y):
        out_packed = self.attention_0_1(self.attention_0_0(x))
        out_einsum = self.attention_1_1(self.attention_1_0(x), y)
        out_baddbmm = self.attention_2_1(self.attention_2_0(x), y)
        return out_packed, out_einsum, out_baddbmm

def test():
    """Trace Model, convert it with pnnx and compare against eager outputs.

    Returns True when every output of the pnnx-generated inference matches the
    eager result within 1e-4 relative and absolute tolerance, otherwise False.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 20, 64)
    y = torch.rand(1, 20, 17)

    expected = net(x, y)

    # export torchscript
    traced = torch.jit.trace(net, (x, y))
    traced.save("test_pnnx_fuse_multiheadattention.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_pnnx_fuse_multiheadattention.pt inputshape=[1,20,64],[1,20,17]")

    # pnnx inference, using the module written by the converter
    import test_pnnx_fuse_multiheadattention_pnnx
    actual = test_pnnx_fuse_multiheadattention_pnnx.test_inference()

    return all(torch.allclose(e, a, 1e-4, 1e-4) for e, a in zip(expected, actual))

if __name__ == "__main__":
    # exit status 0 on success, 1 on mismatch
    exit(0 if test() else 1)

Loading…
Cancel
Save