diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index bb468ab2f..ede82a87a 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -62,7 +62,7 @@ jobs: python3 -m pip install --upgrade pip apt-get remove -y python3-setuptools pip3 install -U setuptools - pip3 install -U pytest wheel twine requests + pip3 install -U pytest wheel twine requests einops - name: setup pytorch run: | diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 5695e4a95..8b13ef702 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -3,6 +3,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) set(pnnx_pass_level0_SRCS pass_level0/constant_unpooling.cpp + pass_level0/convert_half_to_float.cpp pass_level0/inline_block.cpp pass_level0/reset_device.cpp pass_level0/flatten_input.cpp @@ -194,6 +195,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_arange.cpp pass_level2/torch_argmax.cpp pass_level2/torch_argmin.cpp + pass_level2/torch_baddbmm.cpp pass_level2/torch_bmm.cpp pass_level2/torch_bitwise_not.cpp pass_level2/torch_bitwise_and.cpp @@ -324,6 +326,7 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_pad_conv1d.cpp pass_level5/fuse_pad_conv2d.cpp pass_level5/fuse_pixel_unshuffle.cpp + pass_level5/fuse_multiheadattention.cpp pass_level5/fuse_select_to_unbind.cpp pass_level5/fuse_slice_copy.cpp pass_level5/fuse_slice_indices.cpp diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 2a8c5d54f..bd605df42 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -160,9 +160,16 @@ Parameter::Parameter(const torch::jit::Node* value_node) if (value_node->kind() == c10::prim::Constant) { + if (value_node->output()->type()->kind() == c10::TypeKind::NoneType) + { + type = 0; + return; + } + if (!value_node->hasAttribute(torch::jit::attr::value)) { fprintf(stderr, "no attribute value\n"); + value_node->dump(); return; } @@ -200,6 +207,12 @@ Parameter::Parameter(const torch::jit::Node* value_node) s = value_node->s(torch::jit::attr::value); break; } + case c10::TypeKind::DeviceObjType: + { + type = 4; + s = value_node->s(torch::jit::attr::value); + break; + } case c10::TypeKind::TensorType: { at::Tensor t = value_node->t(torch::jit::attr::value); @@ -246,7 +259,7 @@ Parameter::Parameter(const torch::jit::Node* value_node) } default: { - fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString()); + fprintf(stderr, "unknown Parameter value kind %s\n", c10::typeKindToString(value_node->output()->type()->kind())); break; } } @@ -284,14 +297,14 @@ Parameter::Parameter(const torch::jit::Node* value_node) } default: { - fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString()); + fprintf(stderr, "unknown Parameter value list element kind %s\n", c10::typeKindToString(value_node->output()->type()->cast()->getElementType()->kind())); break; } } } else { - fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString()); + fprintf(stderr, "unknown Parameter value_node kind %s\n", value_node->kind().toDisplayString()); } } diff --git a/tools/pnnx/src/pass_level0/convert_half_to_float.cpp b/tools/pnnx/src/pass_level0/convert_half_to_float.cpp new file mode 100644 index 000000000..8395f3335 --- /dev/null +++ b/tools/pnnx/src/pass_level0/convert_half_to_float.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convert_half_to_float.h" + +namespace pnnx { + +void convert_half_to_float(torch::jit::Module& mod) +{ + for (auto submod : mod.children()) + { + convert_half_to_float(submod); + } + + for (auto named_attr : mod.named_attributes(false)) + { + const std::string& name = named_attr.name; + auto attr = named_attr.value; + + if (attr.type()->kind() == c10::TypeKind::TensorType) + { + auto t = attr.toTensor(); + + if (t.scalar_type() == c10::ScalarType::Half) + { + at::Tensor t_fp32 = t.toType(c10::ScalarType::Float); + + mod.setattr(name, t_fp32); + } + } + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/convert_half_to_float.h b/tools/pnnx/src/pass_level0/convert_half_to_float.h new file mode 100644 index 000000000..9eef6d1d1 --- /dev/null +++ b/tools/pnnx/src/pass_level0/convert_half_to_float.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include + +namespace pnnx { + +void convert_half_to_float(torch::jit::Module& mod); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/reset_device.cpp b/tools/pnnx/src/pass_level0/reset_device.cpp index a0475c561..fc6edba9a 100644 --- a/tools/pnnx/src/pass_level0/reset_device.cpp +++ b/tools/pnnx/src/pass_level0/reset_device.cpp @@ -20,8 +20,22 @@ void reset_device(std::shared_ptr& graph, const std::string& { for (torch::jit::Node* n : graph->nodes()) { - if (n->kind().toDisplayString() == std::string("aten::to")) + if (n->kind().is_aten()) { + if (n->hasNamedInput("dtype")) + { + torch::jit::Node* dtype_node = n->namedInput("dtype")->node(); + + if (dtype_node->hasAttribute(torch::jit::attr::value)) + { + // change dtype=half to dtype=float + if (dtype_node->i(torch::jit::attr::value) == 5) + { + dtype_node->i_(torch::jit::attr::value, 6); + } + } + } + if (n->hasNamedInput("device")) { torch::jit::Node* device_node = n->namedInput("device")->node(); diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp index 4059fcda5..170963feb 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.cpp +++ b/tools/pnnx/src/pass_level0/shape_inference.cpp @@ -17,6 +17,7 @@ #include "storezip.h" #include "pass_level0/constant_unpooling.h" +#include "pass_level0/convert_half_to_float.h" #include "pass_level0/flatten_input.h" #include "pass_level0/inline_block.h" #include "pass_level0/reset_device.h" @@ -157,6 +158,8 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptrtype.c_str(), op->name.c_str()); + // fprintf(stderr, "eliminate_noop_math %s %s\n", op->type.c_str(), op->name.c_str()); for (auto& x : op->inputs) { diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index 001765348..279d6019b 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -36,6 +36,7 @@ #include "pass_level5/fuse_convtranspose2d_batchnorm2d.h" #include "pass_level5/fuse_contiguous_view.h" #include "pass_level5/fuse_linear_batchnorm1d.h" +#include "pass_level5/fuse_multiheadattention.h" #include "pass_level5/fuse_pad_conv1d.h" #include "pass_level5/fuse_pad_conv2d.h" #include "pass_level5/fuse_select_to_unbind.h" @@ -123,6 +124,7 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons eliminate_reshape_shape_expression(g); fuse_channel_shuffle(g); + fuse_multiheadattention(g); fuse_index_expression(g); diff --git a/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp index 70ff10057..9811a5bbe 100644 --- a/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp +++ b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp @@ -45,7 +45,7 @@ pnnx.Output output 1 0 out return "channelshuffle"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { const std::string& expr = captured_params.at("expr").s; const std::string& expr2 = captured_params.at("expr2").s; @@ -97,7 +97,7 @@ pnnx.Output output 1 0 out return "channelshuffle"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // (1,2,58,28,28) // (1,-1,28,28) diff --git a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp new file mode 100644 index 000000000..24099bf86 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp @@ -0,0 +1,996 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_multiheadattention.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +static bool NearlyEqual(float a, float b, float epsilon) +{ + if (a == b) + return true; + + float diff = (float)fabs(a - b); + if (diff <= epsilon) + return true; + + // relative error + return diff < epsilon * std::max(fabs(a), fabs(b)); +} + +class fuse_multiheadattention_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +14 13 +pnnx.Input input 0 1 input +nn.Linear op_0 1 1 input 1 bias=%qkv_bias in_features=%embed_dim out_features=%qkv_out_features @bias @weight +Tensor.reshape op_1 1 1 1 2 shape=%shape +torch.permute op_2 1 1 2 3 dims=(2,0,3,1,4) +torch.unbind op_3 1 3 3 4 5 6 dim=0 +pnnx.Expression op_4 1 1 4 7 expr=%expr +torch.permute op_5 1 1 5 8 dims=(0,1,3,2) +torch.matmul op_6 2 1 7 8 9 +F.softmax op_7 1 1 9 10 dim=-1 +torch.matmul op_8 2 1 10 6 11 +torch.permute op_9 1 1 11 12 dims=(0,2,1,3) +Tensor.reshape op_10 1 1 12 13 shape=%shape2 +nn.Linear out_proj 1 1 13 out bias=%out_proj_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.MultiheadAttention"; + } + + const char* name_str() const + { + return "attention"; + } + + bool match(const std::map& captured_params) const + { + const int embed_dim = captured_params.at("embed_dim").i; + const int qkv_out_features = captured_params.at("qkv_out_features").i; + if (qkv_out_features != embed_dim * 3) + return false; + + // (1,-1,3,4,16) + // (1,-1,64) + const std::vector& shape = captured_params.at("shape").ai; + const std::vector& shape2 = captured_params.at("shape2").ai; + if (shape.size() != 5 || shape2.size() != 3) + return false; + + const int num_heads = shape[3]; + if (shape[0] != shape2[0] || shape[2] != 3 || shape[3] * shape[4] != shape2[2]) + return false; + + // mul(@0,2.581989e-01) + const std::string& expr = captured_params.at("expr").s; + float inv_sqrt_embed_dim_per_head = 0.f; + int nscan = sscanf(expr.c_str(), "mul(@0,%f)", &inv_sqrt_embed_dim_per_head); + if (nscan != 1) + return false; + + if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + int num_heads = captured_params.at("shape").ai[3]; + + bool qkv_bias = captured_params.at("qkv_bias").b; + bool out_proj_bias = captured_params.at("out_proj_bias").b; + bool bias = qkv_bias || out_proj_bias; + + op->params["num_heads"] = num_heads; + op->params["batch_first"] = true; + op->params["add_zero_attn"] = false; + op->params["add_bias_kv"] = false; + op->params["bias"] = bias; + + int embed_dim = captured_params.at("embed_dim").i; + + op->params["embed_dim"] = embed_dim; + op->params["kdim"] = embed_dim; + op->params["vdim"] = embed_dim; + + op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight"); + if (bias) + { + if (qkv_bias) + { + op->attrs["in_proj_bias"] = captured_attrs.at("op_0.bias"); + } + else + { + // init bias as zero + op->attrs["in_proj_bias"] = Attribute(); + op->attrs["in_proj_bias"].type = 1; + op->attrs["in_proj_bias"].shape = {embed_dim * 3}; + + op->attrs["in_proj_bias"].data.resize(embed_dim * 3 * sizeof(float)); + memset(op->attrs["in_proj_bias"].data.data(), 0, embed_dim * 3 * sizeof(float)); + } + } + + op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight"); + if (bias) + { + if (out_proj_bias) + { + op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias"); + } + else + { + // init bias as zero + op->attrs["out_proj.bias"] = Attribute(); + op->attrs["out_proj.bias"].type = 1; + op->attrs["out_proj.bias"].shape = {embed_dim}; + + op->attrs["out_proj.bias"].data.resize(embed_dim * sizeof(float)); + memset(op->attrs["out_proj.bias"].data.data(), 0, embed_dim * sizeof(float)); + } + } + } +}; + +class fuse_multiheadattention_pass_sameqkv : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +23 22 +pnnx.Input input 0 1 input +nn.Linear op_0 1 1 input 31 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 32 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 33 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +pnnx.Expression op_3 1 1 32 34 expr=%expr +Tensor.reshape op_4 1 1 31 35 shape=%q_shape +Tensor.reshape op_5 1 1 34 36 shape=%kv_shape +Tensor.reshape op_6 1 1 33 37 shape=%kv_shape +torch.permute op_7 1 1 36 38 dims=(0,2,1,3) +Tensor.reshape op_8 1 1 38 39 shape=%kv_shape2 +torch.permute op_9 1 1 35 40 dims=(0,2,1,3) +Tensor.reshape op_10 1 1 40 41 shape=%q_shape2 +torch.permute op_11 1 1 39 42 dims=(0,2,1) +torch.matmul op_12 2 1 41 42 43 +F.softmax op_13 1 1 43 44 dim=-1 +torch.permute op_14 1 1 37 45 dims=(0,2,1,3) +Tensor.reshape op_15 1 1 45 46 shape=%kv_shape2 +torch.matmul op_16 2 1 44 46 47 +Tensor.reshape op_18 1 1 47 48 shape=%qkv_shape +torch.permute op_19 1 1 48 49 dims=(0,2,1,3) +Tensor.reshape op_20 1 1 49 50 shape=%qkv_shape2 +nn.Linear out_proj 1 1 50 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.MultiheadAttention"; + } + + const char* name_str() const + { + return "attention"; + } + + bool match(const std::map& captured_params) const + { + // q_shape = (1,q,8,40) + // kv_shape = (1,kv,8,40) + // q_shape2 = (8,q,40) + // kv_shape2 = (8,kv,40) + // qkv_shape = (1,8,q,40) + // qkv_shape2 = (1,q,320) + const std::vector& q_shape = captured_params.at("q_shape").ai; + const std::vector& kv_shape = captured_params.at("kv_shape").ai; + const std::vector& q_shape2 = captured_params.at("q_shape2").ai; + const std::vector& kv_shape2 = captured_params.at("kv_shape2").ai; + const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; + const std::vector& qkv_shape2 = captured_params.at("qkv_shape2").ai; + if (q_shape.size() != 4 || kv_shape.size() != 4 || q_shape2.size() != 3 || kv_shape2.size() != 3 || qkv_shape.size() != 4 || qkv_shape2.size() != 3) + return false; + + const int batch_size = q_shape[0]; + const int q_size = q_shape[1]; + const int num_heads = q_shape[2]; + const int feat_per_head = q_shape[3]; + const int kv_size = kv_shape[1]; + if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size || qkv_shape2[0] != batch_size) + return false; + + if (q_shape2[1] != q_size || kv_shape2[1] != kv_size || qkv_shape[2] != q_size || qkv_shape2[1] != q_size) + return false; + + if (kv_shape[2] != num_heads || q_shape2[0] != num_heads || kv_shape2[0] != num_heads || qkv_shape[1] != num_heads) + return false; + + if (kv_shape[3] != feat_per_head || q_shape2[2] != feat_per_head || kv_shape2[2] != feat_per_head || qkv_shape[3] != feat_per_head || qkv_shape2[2] != feat_per_head * num_heads) + return false; + + // mul(@0,1.581139e-01) + const std::string& expr = captured_params.at("expr").s; + float inv_sqrt_embed_dim_per_head = 0.f; + int nscan = sscanf(expr.c_str(), "mul(@0,%f)", &inv_sqrt_embed_dim_per_head); + if (nscan != 1) + return false; + + const int embed_dim = captured_params.at("embed_dim").i; + if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + int embed_dim = captured_params.at("embed_dim").i; + int kdim = captured_params.at("kdim").i; + int vdim = captured_params.at("vdim").i; + + // (1,*,8,40) + int num_heads = captured_params.at("q_shape").ai[2]; + + bool q_bias = captured_params.at("q_bias").b; + bool k_bias = captured_params.at("k_bias").b; + bool v_bias = captured_params.at("v_bias").b; + bool out_bias = captured_params.at("out_bias").b; + bool bias = q_bias || k_bias || v_bias || out_bias; + + op->params["embed_dim"] = embed_dim; + op->params["kdim"] = kdim; + op->params["vdim"] = vdim; + + op->params["num_heads"] = num_heads; + op->params["batch_first"] = true; + op->params["add_zero_attn"] = false; + op->params["add_bias_kv"] = false; + op->params["bias"] = bias; + + op->attrs["in_proj_weight"] = Attribute(); + op->attrs["in_proj_weight"].type = 1; + op->attrs["in_proj_weight"].shape = {embed_dim * 3, embed_dim}; + op->attrs["in_proj_weight"].data.resize(embed_dim * 3 * embed_dim * sizeof(float)); + + // combine qkv weight + { + float* in_proj_weight_ptr = (float*)op->attrs["in_proj_weight"].data.data(); + memcpy(in_proj_weight_ptr, captured_attrs.at("op_0.weight").data.data(), embed_dim * embed_dim * sizeof(float)); + in_proj_weight_ptr += embed_dim * embed_dim; + memcpy(in_proj_weight_ptr, captured_attrs.at("op_1.weight").data.data(), embed_dim * embed_dim * sizeof(float)); + in_proj_weight_ptr += embed_dim * embed_dim; + memcpy(in_proj_weight_ptr, captured_attrs.at("op_2.weight").data.data(), embed_dim * embed_dim * sizeof(float)); + } + + op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight"); + + if (bias) + { + op->attrs["in_proj_bias"] = Attribute(); + op->attrs["in_proj_bias"].type = 1; + op->attrs["in_proj_bias"].shape = {embed_dim * 3}; + op->attrs["in_proj_bias"].data.resize(embed_dim * 3 * sizeof(float)); + + // combine qkv bias + { + float* in_proj_bias_ptr = (float*)op->attrs["in_proj_bias"].data.data(); + if (q_bias) + { + memcpy(in_proj_bias_ptr, captured_attrs.at("op_0.bias").data.data(), embed_dim * sizeof(float)); + } + else + { + memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); + } + in_proj_bias_ptr += embed_dim; + if (k_bias) + { + memcpy(in_proj_bias_ptr, captured_attrs.at("op_1.bias").data.data(), embed_dim * sizeof(float)); + } + else + { + memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); + } + in_proj_bias_ptr += embed_dim; + if (v_bias) + { + memcpy(in_proj_bias_ptr, captured_attrs.at("op_2.bias").data.data(), embed_dim * sizeof(float)); + } + else + { + memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); + } + } + + if (out_bias) + { + op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias"); + } + else + { + // init bias as zero + op->attrs["out_proj.bias"] = Attribute(); + op->attrs["out_proj.bias"].type = 1; + op->attrs["out_proj.bias"].shape = {embed_dim}; + + op->attrs["out_proj.bias"].data.resize(embed_dim * sizeof(float)); + memset(op->attrs["out_proj.bias"].data.data(), 0, embed_dim * sizeof(float)); + } + } + } +}; + +class fuse_multiheadattention_pass_qkv : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +25 24 +pnnx.Input input_q 0 1 q +pnnx.Input input_k 0 1 k +pnnx.Input input_v 0 1 v +nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 k 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 v 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +pnnx.Expression op_3 1 1 33 35 expr=%expr +Tensor.reshape op_4 1 1 32 36 shape=%q_shape +Tensor.reshape op_5 1 1 35 37 shape=%kv_shape +Tensor.reshape op_6 1 1 34 38 shape=%kv_shape +torch.permute op_7 1 1 37 39 dims=(0,2,1,3) +Tensor.reshape op_8 1 1 39 40 shape=%kv_shape2 +torch.permute op_9 1 1 36 41 dims=(0,2,1,3) +Tensor.reshape op_10 1 1 41 42 shape=%q_shape2 +torch.permute op_11 1 1 40 43 dims=(0,2,1) +torch.matmul op_12 2 1 42 43 44 +F.softmax op_13 1 1 44 45 dim=-1 +torch.permute op_14 1 1 38 46 dims=(0,2,1,3) +Tensor.reshape op_15 1 1 46 47 shape=%kv_shape2 +torch.matmul op_16 2 1 45 47 48 +Tensor.reshape op_17 1 1 48 49 shape=%qkv_shape +torch.permute op_18 1 1 49 50 dims=(0,2,1,3) +Tensor.reshape op_19 1 1 50 51 shape=%qkv_shape2 +nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.MultiheadAttention"; + } + + const char* name_str() const + { + return "attention"; + } + + bool match(const std::map& captured_params) const + { + // q_shape = (1,q,8,40) + // kv_shape = (1,kv,8,40) + // q_shape2 = (8,q,40) + // kv_shape2 = (8,kv,40) + // qkv_shape = (1,8,q,40) + // qkv_shape2 = (1,q,320) + const std::vector& q_shape = captured_params.at("q_shape").ai; + const std::vector& kv_shape = captured_params.at("kv_shape").ai; + const std::vector& q_shape2 = captured_params.at("q_shape2").ai; + const std::vector& kv_shape2 = captured_params.at("kv_shape2").ai; + const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; + const std::vector& qkv_shape2 = captured_params.at("qkv_shape2").ai; + if (q_shape.size() != 4 || kv_shape.size() != 4 || q_shape2.size() != 3 || kv_shape2.size() != 3 || qkv_shape.size() != 4 || qkv_shape2.size() != 3) + return false; + + const int batch_size = q_shape[0]; + const int q_size = q_shape[1]; + const int num_heads = q_shape[2]; + const int feat_per_head = q_shape[3]; + const int kv_size = kv_shape[1]; + if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size || qkv_shape2[0] != batch_size) + return false; + + if (q_shape2[1] != q_size || kv_shape2[1] != kv_size || qkv_shape[2] != q_size || qkv_shape2[1] != q_size) + return false; + + if (kv_shape[2] != num_heads || q_shape2[0] != num_heads || kv_shape2[0] != num_heads || qkv_shape[1] != num_heads) + return false; + + if (kv_shape[3] != feat_per_head || q_shape2[2] != feat_per_head || kv_shape2[2] != feat_per_head || qkv_shape[3] != feat_per_head || qkv_shape2[2] != feat_per_head * num_heads) + return false; + + // mul(@0,1.581139e-01) + const std::string& expr = captured_params.at("expr").s; + float inv_sqrt_embed_dim_per_head = 0.f; + int nscan = sscanf(expr.c_str(), "mul(@0,%f)", &inv_sqrt_embed_dim_per_head); + if (nscan != 1) + return false; + + const int embed_dim = captured_params.at("embed_dim").i; + if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + int embed_dim = captured_params.at("embed_dim").i; + int kdim = captured_params.at("kdim").i; + int vdim = captured_params.at("vdim").i; + + // (1,*,8,40) + int num_heads = captured_params.at("q_shape").ai[2]; + + bool q_bias = captured_params.at("q_bias").b; + bool k_bias = captured_params.at("k_bias").b; + bool v_bias = captured_params.at("v_bias").b; + bool out_bias = captured_params.at("out_bias").b; + bool bias = q_bias || k_bias || v_bias || out_bias; + + op->params["embed_dim"] = embed_dim; + op->params["kdim"] = kdim; + op->params["vdim"] = vdim; + + op->params["num_heads"] = num_heads; + op->params["batch_first"] = true; + op->params["add_zero_attn"] = false; + op->params["add_bias_kv"] = false; + op->params["bias"] = bias; + + op->attrs["q_proj_weight"] = captured_attrs.at("op_0.weight"); + op->attrs["k_proj_weight"] = captured_attrs.at("op_1.weight"); + op->attrs["v_proj_weight"] = captured_attrs.at("op_2.weight"); + op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight"); + + if (bias) + { + op->attrs["in_proj_bias"] = Attribute(); + op->attrs["in_proj_bias"].type = 1; + op->attrs["in_proj_bias"].shape = {embed_dim * 3}; + op->attrs["in_proj_bias"].data.resize(embed_dim * 3 * sizeof(float)); + + // combine qkv bias + { + float* in_proj_bias_ptr = (float*)op->attrs["in_proj_bias"].data.data(); + if (q_bias) + { + memcpy(in_proj_bias_ptr, captured_attrs.at("op_0.bias").data.data(), embed_dim * sizeof(float)); + } + else + { + memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); + } + in_proj_bias_ptr += embed_dim; + if (k_bias) + { + memcpy(in_proj_bias_ptr, captured_attrs.at("op_1.bias").data.data(), embed_dim * sizeof(float)); + } + else + { + memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); + } + in_proj_bias_ptr += embed_dim; + if (v_bias) + { + memcpy(in_proj_bias_ptr, captured_attrs.at("op_2.bias").data.data(), embed_dim * sizeof(float)); + } + else + { + memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); + } + } + + if (out_bias) + { + op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias"); + } + else + { + // init bias as zero + op->attrs["out_proj.bias"] = Attribute(); + op->attrs["out_proj.bias"].type = 1; + op->attrs["out_proj.bias"].shape = {embed_dim}; + + op->attrs["out_proj.bias"].data.resize(embed_dim * sizeof(float)); + memset(op->attrs["out_proj.bias"].data.data(), 0, embed_dim * sizeof(float)); + } + } + } +}; + +class fuse_multiheadattention_pass_q_samekv : public fuse_multiheadattention_pass_qkv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +24 23 +pnnx.Input input_q 0 1 q +pnnx.Input input_kv 0 1 kv +nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 kv 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 kv 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +pnnx.Expression op_3 1 1 33 35 expr=%expr +Tensor.reshape op_4 1 1 32 36 shape=%q_shape +Tensor.reshape op_5 1 1 35 37 shape=%kv_shape +Tensor.reshape op_6 1 1 34 38 shape=%kv_shape +torch.permute op_7 1 1 37 39 dims=(0,2,1,3) +Tensor.reshape op_8 1 1 39 40 shape=%kv_shape2 +torch.permute op_9 1 1 36 41 dims=(0,2,1,3) +Tensor.reshape op_10 1 1 41 42 shape=%q_shape2 +torch.permute op_11 1 1 40 43 dims=(0,2,1) +torch.matmul op_12 2 1 42 43 44 +F.softmax op_13 1 1 44 45 dim=-1 +torch.permute op_14 1 1 38 46 dims=(0,2,1,3) +Tensor.reshape op_15 1 1 46 47 shape=%kv_shape2 +torch.matmul op_16 2 1 45 47 48 +Tensor.reshape op_17 1 1 48 49 shape=%qkv_shape +torch.permute op_18 1 1 49 50 dims=(0,2,1,3) +Tensor.reshape op_19 1 1 50 51 shape=%qkv_shape2 +nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_1 : public fuse_multiheadattention_pass_sameqkv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +22 21 +pnnx.Input input 0 1 input +nn.Linear op_0 1 1 input 31 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 32 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 33 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 31 35 shape=%q_shape +Tensor.reshape op_4 1 1 32 36 shape=%kv_shape +Tensor.reshape op_5 1 1 33 37 shape=%kv_shape +torch.permute op_6 1 1 36 38 dims=(0,2,1,3) +Tensor.reshape op_7 1 1 38 39 shape=%kv_shape2 +torch.permute op_8 1 1 35 40 dims=(0,2,1,3) +Tensor.reshape op_9 1 1 40 41 shape=%q_shape2 +torch.einsum op_10 2 1 41 39 42 equation=ijl,ikl->ijk +pnnx.Expression op_11 1 1 42 43 expr=%expr +F.softmax op_12 1 1 43 44 dim=-1 +torch.permute op_13 1 1 37 45 dims=(0,2,1,3) +Tensor.reshape op_14 1 1 45 46 shape=%kv_shape2 +torch.einsum op_15 2 1 44 46 47 equation=ijl,ilk->ijk +Tensor.reshape op_16 1 1 47 48 shape=%qkv_shape +torch.permute op_17 1 1 48 49 dims=(0,2,1,3) +Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape2 +nn.Linear out_proj 1 1 50 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_2 : public fuse_multiheadattention_pass_qkv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +24 23 +pnnx.Input input_q 0 1 q +pnnx.Input input_k 0 1 k +pnnx.Input input_v 0 1 v +nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 k 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 v 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 32 36 shape=%q_shape +Tensor.reshape op_4 1 1 33 37 shape=%kv_shape +Tensor.reshape op_5 1 1 34 38 shape=%kv_shape +torch.permute op_6 1 1 37 39 dims=(0,2,1,3) +Tensor.reshape op_7 1 1 39 40 shape=%kv_shape2 +torch.permute op_8 1 1 36 41 dims=(0,2,1,3) +Tensor.reshape op_9 1 1 41 42 shape=%q_shape2 +torch.einsum op_10 2 1 42 40 43 equation=ijl,ikl->ijk +pnnx.Expression op_11 1 1 43 44 expr=%expr +F.softmax op_12 1 1 44 45 dim=-1 +torch.permute op_13 1 1 38 46 dims=(0,2,1,3) +Tensor.reshape op_14 1 1 46 47 shape=%kv_shape2 +torch.einsum op_15 2 1 45 47 48 equation=ijl,ilk->ijk +Tensor.reshape op_16 1 1 48 49 shape=%qkv_shape +torch.permute op_17 1 1 49 50 dims=(0,2,1,3) +Tensor.reshape op_18 1 1 50 51 shape=%qkv_shape2 +nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_3 : public fuse_multiheadattention_pass_q_samekv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +23 22 +pnnx.Input input_q 0 1 q +pnnx.Input input_kv 0 1 kv +nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 kv 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 kv 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 32 36 shape=%q_shape +Tensor.reshape op_4 1 1 33 37 shape=%kv_shape +Tensor.reshape op_5 1 1 34 38 shape=%kv_shape +torch.permute op_6 1 1 37 39 dims=(0,2,1,3) +Tensor.reshape op_7 1 1 39 40 shape=%kv_shape2 +torch.permute op_8 1 1 36 41 dims=(0,2,1,3) +Tensor.reshape op_9 1 1 41 42 shape=%q_shape2 +torch.einsum op_10 2 1 42 40 43 equation=ijl,ikl->ijk +pnnx.Expression op_11 1 1 43 44 expr=%expr +F.softmax op_12 1 1 44 45 dim=-1 +torch.permute op_13 1 1 38 46 dims=(0,2,1,3) +Tensor.reshape op_14 1 1 46 47 shape=%kv_shape2 +torch.einsum op_15 2 1 45 47 48 equation=ijl,ilk->ijk +Tensor.reshape op_16 1 1 48 49 shape=%qkv_shape +torch.permute op_17 1 1 49 50 dims=(0,2,1,3) +Tensor.reshape op_18 1 1 50 51 shape=%qkv_shape2 +nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_5 : public fuse_multiheadattention_pass_sameqkv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +23 22 +pnnx.Input input 0 1 input +nn.Linear op_0 1 1 input 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=%q_shape +Tensor.reshape op_4 1 1 34 37 shape=%kv_shape +Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +torch.permute op_6 1 1 36 39 dims=(0,2,1,3) +Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +torch.permute op_8 1 1 37 41 dims=(0,2,1,3) +Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 +pnnx.Attribute op_10 0 1 43 @zeros +torch.transpose op_11 1 1 42 44 dim0=-1 dim1=-2 +torch.baddbmm op_12 3 1 43 40 44 45 alpha=%alpha beta=0 +F.softmax op_13 1 1 45 46 dim=-1 +torch.permute op_14 1 1 38 47 dims=(0,2,1,3) +Tensor.reshape op_15 1 1 47 48 shape=%kv_shape2 +torch.bmm op_16 2 1 46 48 49 +Tensor.reshape op_17 1 1 49 50 shape=%qkv_shape +torch.permute op_18 1 1 50 51 dims=(0,2,1,3) +Tensor.reshape op_19 1 1 51 52 shape=%qkv_shape2 +nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } + + bool match(const std::map& captured_params) const + { + // q_shape = (1,q,8,40) + // kv_shape = (1,kv,8,40) + // q_shape2 = (8,q,40) + // kv_shape2 = (8,kv,40) + // qkv_shape = (1,8,q,40) + // qkv_shape2 = (1,q,320) + const std::vector& q_shape = captured_params.at("q_shape").ai; + const std::vector& kv_shape = captured_params.at("kv_shape").ai; + const std::vector& q_shape2 = captured_params.at("q_shape2").ai; + const std::vector& kv_shape2 = captured_params.at("kv_shape2").ai; + const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; + const std::vector& qkv_shape2 = captured_params.at("qkv_shape2").ai; + if (q_shape.size() != 4 || kv_shape.size() != 4 || q_shape2.size() != 3 || kv_shape2.size() != 3 || qkv_shape.size() != 4 || qkv_shape2.size() != 3) + return false; + + const int batch_size = q_shape[0]; + const int q_size = q_shape[1]; + const int num_heads = q_shape[2]; + const int feat_per_head = q_shape[3]; + const int kv_size = kv_shape[1]; + if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size || qkv_shape2[0] != batch_size) + return false; + + if (q_shape2[1] != q_size || kv_shape2[1] != kv_size || qkv_shape[2] != q_size || qkv_shape2[1] != q_size) + return false; + + if (kv_shape[2] != num_heads || q_shape2[0] != num_heads || kv_shape2[0] != num_heads || qkv_shape[1] != num_heads) + return false; + + if (kv_shape[3] != feat_per_head || q_shape2[2] != feat_per_head || kv_shape2[2] != feat_per_head || qkv_shape[3] != feat_per_head || qkv_shape2[2] != feat_per_head * num_heads) + return false; + + const float inv_sqrt_embed_dim_per_head = captured_params.at("alpha").f; + const int embed_dim = captured_params.at("embed_dim").i; + if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) + return false; + + return true; + } +}; + +class fuse_multiheadattention_pass_6 : public fuse_multiheadattention_pass_5 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +24 23 +pnnx.Input input 0 1 input +nn.Linear op_0 1 1 input 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=%q_shape +Tensor.reshape op_4 1 1 34 37 shape=%kv_shape +Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +torch.permute op_6 1 1 36 39 dims=(0,2,1,3) +Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +torch.permute op_8 1 1 37 41 dims=(0,2,1,3) +Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 +pnnx.Expression op_10 2 1 40 42 43 expr=%expr_zero_shape +torch.empty op_11 1 1 43 zeros +torch.transpose op_12 1 1 42 44 dim0=-1 dim1=-2 +torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%alpha beta=0 +F.softmax op_14 1 1 45 46 dim=-1 +torch.permute op_15 1 1 38 47 dims=(0,2,1,3) +Tensor.reshape op_16 1 1 47 48 shape=%kv_shape2 +torch.bmm op_17 2 1 46 48 49 +Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape +torch.permute op_19 1 1 50 51 dims=(0,2,1,3) +Tensor.reshape op_20 1 1 51 52 shape=%qkv_shape2 +nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_7 : public fuse_multiheadattention_pass_qkv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +25 24 +pnnx.Input input_q 0 1 q +pnnx.Input input_k 0 1 k +pnnx.Input input_v 0 1 v +nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 k 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 v 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=%q_shape +Tensor.reshape op_4 1 1 34 37 shape=%kv_shape +Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +torch.permute op_6 1 1 36 39 dims=(0,2,1,3) +Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +torch.permute op_8 1 1 37 41 dims=(0,2,1,3) +Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 +pnnx.Attribute op_10 0 1 43 @zeros +torch.transpose op_11 1 1 42 44 dim0=-1 dim1=-2 +torch.baddbmm op_12 3 1 43 40 44 45 alpha=%alpha beta=0 +F.softmax op_13 1 1 45 46 dim=-1 +torch.permute op_14 1 1 38 47 dims=(0,2,1,3) +Tensor.reshape op_15 1 1 47 48 shape=%kv_shape2 +torch.bmm op_16 2 1 46 48 49 +Tensor.reshape op_17 1 1 49 50 shape=%qkv_shape +torch.permute op_18 1 1 50 51 dims=(0,2,1,3) +Tensor.reshape op_19 1 1 51 52 shape=%qkv_shape2 +nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } + + bool match(const std::map& captured_params) const + { + // q_shape = (1,q,8,40) + // kv_shape = (1,kv,8,40) + // q_shape2 = (8,q,40) + // kv_shape2 = (8,kv,40) + // qkv_shape = (1,8,q,40) + // qkv_shape2 = (1,q,320) + const std::vector& q_shape = captured_params.at("q_shape").ai; + const std::vector& kv_shape = captured_params.at("kv_shape").ai; + const std::vector& q_shape2 = captured_params.at("q_shape2").ai; + const std::vector& kv_shape2 = captured_params.at("kv_shape2").ai; + const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; + const std::vector& qkv_shape2 = captured_params.at("qkv_shape2").ai; + if (q_shape.size() != 4 || kv_shape.size() != 4 || q_shape2.size() != 3 || kv_shape2.size() != 3 || qkv_shape.size() != 4 || qkv_shape2.size() != 3) + return false; + + const int batch_size = q_shape[0]; + const int q_size = q_shape[1]; + const int num_heads = q_shape[2]; + const int feat_per_head = q_shape[3]; + const int kv_size = kv_shape[1]; + if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size || qkv_shape2[0] != batch_size) + return false; + + if (q_shape2[1] != q_size || kv_shape2[1] != kv_size || qkv_shape[2] != q_size || qkv_shape2[1] != q_size) + return false; + + if (kv_shape[2] != num_heads || q_shape2[0] != num_heads || kv_shape2[0] != num_heads || qkv_shape[1] != num_heads) + return false; + + if (kv_shape[3] != feat_per_head || q_shape2[2] != feat_per_head || kv_shape2[2] != feat_per_head || qkv_shape[3] != feat_per_head || qkv_shape2[2] != feat_per_head * num_heads) + return false; + + const float inv_sqrt_embed_dim_per_head = captured_params.at("alpha").f; + const int embed_dim = captured_params.at("embed_dim").i; + if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) + return false; + + return true; + } +}; + +class fuse_multiheadattention_pass_8 : public fuse_multiheadattention_pass_7 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +24 23 +pnnx.Input input_q 0 1 q +pnnx.Input input_kv 0 1 kv +nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 kv 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 kv 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=%q_shape +Tensor.reshape op_4 1 1 34 37 shape=%kv_shape +Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +torch.permute op_6 1 1 36 39 dims=(0,2,1,3) +Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +torch.permute op_8 1 1 37 41 dims=(0,2,1,3) +Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 +pnnx.Attribute op_10 0 1 43 @zeros +torch.transpose op_11 1 1 42 44 dim0=-1 dim1=-2 +torch.baddbmm op_12 3 1 43 40 44 45 alpha=%alpha beta=0 +F.softmax op_13 1 1 45 46 dim=-1 +torch.permute op_14 1 1 38 47 dims=(0,2,1,3) +Tensor.reshape op_15 1 1 47 48 shape=%kv_shape2 +torch.bmm op_16 2 1 46 48 49 +Tensor.reshape op_17 1 1 49 50 shape=%qkv_shape +torch.permute op_18 1 1 50 51 dims=(0,2,1,3) +Tensor.reshape op_19 1 1 51 52 shape=%qkv_shape2 +nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_9 : public fuse_multiheadattention_pass_7 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +26 25 +pnnx.Input input_q 0 1 q +pnnx.Input input_k 0 1 k +pnnx.Input input_v 0 1 v +nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 k 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 v 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=%q_shape +Tensor.reshape op_4 1 1 34 37 shape=%kv_shape +Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +torch.permute op_6 1 1 36 39 dims=(0,2,1,3) +Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +torch.permute op_8 1 1 37 41 dims=(0,2,1,3) +Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 +pnnx.Expression op_10 1 1 40 43 expr=%expr_zero_shape +torch.empty op_11 1 1 43 zeros +torch.transpose op_12 1 1 42 44 dim0=-1 dim1=-2 +torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%alpha beta=0 +F.softmax op_14 1 1 45 46 dim=-1 +torch.permute op_15 1 1 38 47 dims=(0,2,1,3) +Tensor.reshape op_16 1 1 47 48 shape=%kv_shape2 +torch.bmm op_17 2 1 46 48 49 +Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape +torch.permute op_19 1 1 50 51 dims=(0,2,1,3) +Tensor.reshape op_20 1 1 51 52 shape=%qkv_shape2 +nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_10 : public fuse_multiheadattention_pass_7 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +25 24 +pnnx.Input input_q 0 1 q +pnnx.Input input_kv 0 1 kv +nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 kv 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 kv 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=%q_shape +Tensor.reshape op_4 1 1 34 37 shape=%kv_shape +Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +torch.permute op_6 1 1 36 39 dims=(0,2,1,3) +Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +torch.permute op_8 1 1 37 41 dims=(0,2,1,3) +Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 +pnnx.Expression op_10 1 1 40 43 expr=%expr_zero_shape +torch.empty op_11 1 1 43 zeros +torch.transpose op_12 1 1 42 44 dim0=-1 dim1=-2 +torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%alpha beta=0 +F.softmax op_14 1 1 45 46 dim=-1 +torch.permute op_15 1 1 38 47 dims=(0,2,1,3) +Tensor.reshape op_16 1 1 47 48 shape=%kv_shape2 +torch.bmm op_17 2 1 46 48 49 +Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape +torch.permute op_19 1 1 50 51 dims=(0,2,1,3) +Tensor.reshape op_20 1 1 51 52 shape=%qkv_shape2 +nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +void fuse_multiheadattention(Graph& graph) +{ + fuse_multiheadattention_pass a; + fuse_multiheadattention_pass_sameqkv b; + fuse_multiheadattention_pass_qkv c; + fuse_multiheadattention_pass_q_samekv d; + fuse_multiheadattention_pass_1 b1; + fuse_multiheadattention_pass_2 c1; + fuse_multiheadattention_pass_3 d1; + fuse_multiheadattention_pass_5 e; + fuse_multiheadattention_pass_6 f; + fuse_multiheadattention_pass_7 g; + fuse_multiheadattention_pass_8 h; + fuse_multiheadattention_pass_9 i; + fuse_multiheadattention_pass_10 j; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); + pnnx_graph_rewrite(graph, &b1, opindex); + pnnx_graph_rewrite(graph, &c1, opindex); + pnnx_graph_rewrite(graph, &d1, opindex); + pnnx_graph_rewrite(graph, &e, opindex); + pnnx_graph_rewrite(graph, &f, opindex); + pnnx_graph_rewrite(graph, &g, opindex); + pnnx_graph_rewrite(graph, &h, opindex); + pnnx_graph_rewrite(graph, &i, opindex); + pnnx_graph_rewrite(graph, &j, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_multiheadattention.h b/tools/pnnx/src/pass_level5/fuse_multiheadattention.h new file mode 100644 index 000000000..d8c1914d2 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_multiheadattention.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_multiheadattention(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp b/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp index 2f1260061..c78db4c66 100644 --- a/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp @@ -45,7 +45,7 @@ pnnx.Output output 1 0 out return "conv1d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // constant-0 + zeros float pad_value = 0.f; @@ -122,7 +122,7 @@ pnnx.Output output 1 0 out return "conv1d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // reflect/replicate + nopad if (captured_params.at("mode").s != "reflect" && captured_params.at("mode").s != "replicate") @@ -193,7 +193,7 @@ pnnx.Output output 1 0 out return "conv1d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // constant-0 + zeros float pad_value = 0.f; @@ -270,7 +270,7 @@ pnnx.Output output 1 0 out return "conv1d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // replicate + nopad const std::vector& pad = captured_params.at("pad").ai; @@ -338,7 +338,7 @@ pnnx.Output output 1 0 out return "conv1d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // reflect + nopad const std::vector& pad = captured_params.at("pad").ai; diff --git a/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp b/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp index 3723ed9c0..0823168ce 100644 --- a/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp @@ -45,7 +45,7 @@ pnnx.Output output 1 0 out return "conv2d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // constant-0 + zeros float pad_value = 0.f; @@ -134,7 +134,7 @@ pnnx.Output output 1 0 out return "conv2d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // reflect/replicate + nopad if (captured_params.at("mode").s != "reflect" && captured_params.at("mode").s != "replicate") @@ -218,7 +218,7 @@ pnnx.Output output 1 0 out return "conv2d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // constant-0 + zeros float pad_value = 0.f; @@ -296,7 +296,7 @@ pnnx.Output output 1 0 out return "conv2d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // constant-0 + zeros const std::vector& pad = captured_params.at("pad").ai; @@ -365,7 +365,7 @@ pnnx.Output output 1 0 out return "conv2d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // replicate + nopad const std::vector& pad = captured_params.at("pad").ai; @@ -434,7 +434,7 @@ pnnx.Output output 1 0 out return "conv2d"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // reflect + nopad const std::vector& pad = captured_params.at("pad").ai; diff --git a/tools/pnnx/src/pass_level5/normalize_einsum_equation.cpp b/tools/pnnx/src/pass_level5/normalize_einsum_equation.cpp index f02ac81ff..0f1902ab9 100644 --- a/tools/pnnx/src/pass_level5/normalize_einsum_equation.cpp +++ b/tools/pnnx/src/pass_level5/normalize_einsum_equation.cpp @@ -14,6 +14,7 @@ #include "normalize_einsum_equation.h" +#include #include #include #include @@ -41,6 +42,10 @@ void normalize_einsum_equation(Graph& graph) continue; std::string equation = op->params.at("equation").s; + + // remove all spaces + equation.erase(std::remove_if(equation.begin(), equation.end(), isspace), equation.end()); + size_t equation_len = equation.size(); std::map xset; diff --git a/tools/pnnx/src/pass_ncnn/fuse_convert_shufflechannel_slice.cpp b/tools/pnnx/src/pass_ncnn/fuse_convert_shufflechannel_slice.cpp index 23bdc4d5e..9e7265c15 100644 --- a/tools/pnnx/src/pass_ncnn/fuse_convert_shufflechannel_slice.cpp +++ b/tools/pnnx/src/pass_ncnn/fuse_convert_shufflechannel_slice.cpp @@ -54,7 +54,7 @@ pnnx.Output output 2 0 out0 out1 return "shufflechannel_slice"; } - bool match_captured_params_attrs(const std::map& captured_params) const + bool match(const std::map& captured_params) const { // (116,2,1024) // (1,0,2) diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 5c7450b2e..ab7358041 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -293,6 +293,7 @@ pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d) pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d) pnnx_add_test(pnnx_fuse_input_unpack) pnnx_add_test(pnnx_fuse_linear_batchnorm1d) +pnnx_add_test(pnnx_fuse_multiheadattention) pnnx_add_test(pnnx_fuse_select_to_unbind) pnnx_add_test(pnnx_fuse_slice_to_tensor_split) pnnx_add_test(pnnx_fuse_adjacent_reshape) diff --git a/tools/pnnx/tests/test_pnnx_fuse_multiheadattention.py b/tools/pnnx/tests/test_pnnx_fuse_multiheadattention.py new file mode 100644 index 000000000..3de0b203d --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_multiheadattention.py @@ -0,0 +1,220 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +class Attention(nn.Module): + def __init__(self, embed_dim, num_heads, qkv_bias=True): + super().__init__() + self.num_heads = num_heads + head_dim = embed_dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, embed_dim) + + def forward(self, x): + _, N, C = x.shape + qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4)) + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + + attn = q.matmul(k.permute((0, 1, 3, 2))) + attn = F.softmax(attn, dim=-1) + + x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C)) + x = self.proj(x) + return x + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + if context_dim is None: + context_dim = query_dim + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None): + h = self.heads + + q = self.to_q(x) + if context is None: + context = x + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + +class diffusers_CrossAttnProcessor: + def __call__(self, attn, hidden_states, encoder_hidden_states=None): + batch_size, sequence_length, _ = hidden_states.shape + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + +class diffusers_CrossAttention(nn.Module): + def __init__(self, query_dim, cross_attention_dim=None, heads=8, dim_head=64, dropout=0.0, bias=False): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + self.processor = diffusers_CrossAttnProcessor() + + def forward(self, hidden_states, encoder_hidden_states=None): + return self.processor(self, hidden_states, encoder_hidden_states=encoder_hidden_states) + + def batch_to_head_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def get_attention_scores(self, query, key): + dtype = query.dtype + + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype) + + return attention_probs + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.attention_0_0 = Attention(embed_dim=64, num_heads=4) + self.attention_0_1 = Attention(embed_dim=64, num_heads=8, qkv_bias=False) + + self.attention_1_0 = CrossAttention(query_dim=64, heads=4, dim_head=16) + self.attention_1_1 = CrossAttention(query_dim=64, heads=8, dim_head=8, context_dim=17) + + self.attention_2_0 = diffusers_CrossAttention(query_dim=64, heads=4, dim_head=16) + self.attention_2_1 = diffusers_CrossAttention(query_dim=64, heads=8, dim_head=8, cross_attention_dim=17) + + def forward(self, x, y): + a = self.attention_0_0(x) + a = self.attention_0_1(a) + + b = self.attention_1_0(x) + b = self.attention_1_1(b, y) + + c = self.attention_2_0(x) + c = self.attention_2_1(c, y) + return a, b, c + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 20, 64) + y = torch.rand(1, 20, 17) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_pnnx_fuse_multiheadattention.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_multiheadattention.pt inputshape=[1,20,64],[1,20,17]") + + # pnnx inference + import test_pnnx_fuse_multiheadattention_pnnx + b = test_pnnx_fuse_multiheadattention_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)