// Tencent is pleased to support the open source community by making ncnn available. // // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at // // https://opensource.org/licenses/BSD-3-Clause // // Unless required by applicable law or agreed to in writing, software distributed // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. // #include "pass_level1.h" // // #include // // #include "../utils.h" #include "fuse_module_pass.h" namespace pnnx { class MultiheadAttention : public FuseModulePass { public: const char* match_type_str() const { return "__torch__.torch.nn.modules.activation.MultiheadAttention"; } const char* type_str() const { return "nn.MultiheadAttention"; } void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const { // mod.dump(false, false, false); // graph->dump(); const TorchNodeProxy* multi_head_attention = graph.find_node_by_kind("aten::_native_multi_head_attention"); if (multi_head_attention) { op->params["num_heads"] = multi_head_attention->namedInput("num_head"); op->params["batch_first"] = true; op->params["add_zero_attn"] = false; if (multi_head_attention->hasNamedInput("mask") && multi_head_attention->namedInput("mask") == graph.input(graph.input_count() - 1)) { size_t input_count = op->inputs.size(); op->inputnames.resize(input_count); op->inputnames[input_count - 1] = "attn_mask"; } } else { const TorchNodeProxy* div_num_heads = graph.find_node_by_kind("aten::div"); const TorchNodeProxy* div_num_heads_18 = graph.find_node_by_kind("aten::floor_divide"); if (div_num_heads_18) { div_num_heads = div_num_heads_18; } // const TorchNodeProxy* div_num_heads_input_1 = graph.find_producer_node_by_value(div_num_heads->input(1)); // op->params["num_heads"] = (int)div_num_heads_input_1->t(torch::jit::attr::value).item(); op->params["num_heads"] = div_num_heads->input(1); const TorchNodeProxy* transpose_batch_seq = graph.find_node_by_kind("aten::transpose"); Parameter transpose_dim0 = transpose_batch_seq->input(1); Parameter transpose_dim1 = transpose_batch_seq->input(2); if (transpose_dim0.i == 1 && transpose_dim1.i == 0) { op->params["batch_first"] = true; } const TorchNodeProxy* add_zero_attn = graph.find_node_by_kind("aten::zeros"); if (add_zero_attn) { op->params["add_zero_attn"] = true; } else { op->params["add_zero_attn"] = false; } const TorchNodeProxy* scaled_dot_product_attention = graph.find_node_by_kind("aten::scaled_dot_product_attention"); if (scaled_dot_product_attention) { if (!scaled_dot_product_attention->is_input_none(3)) { size_t input_count = op->inputs.size(); op->inputnames.resize(input_count); op->inputnames[input_count - 1] = "attn_mask"; } } // find attention mask addition pattern pre torch-2.1 const TorchNodeProxy* has_attn_mask = graph.find_node_by_kind("aten::baddbmm"); if (has_attn_mask) { size_t input_count = op->inputs.size(); op->inputnames.resize(input_count); op->inputnames[input_count - 1] = "attn_mask"; } // find attention mask addition pattern pre torch-1.12 // attn = torch.bmm(Q, K) // input0 = torch.add_(attn, attn_mask) // attn0 = torch.softmax(input0, -1) const TorchNodeProxy* softmax = graph.find_node_by_kind("aten::softmax"); if (softmax) { const TorchNodeProxy* add_ = graph.find_producer_node_by_value(softmax->input(0)); if (add_ && add_->kind() == "aten::add_") { const TorchNodeProxy* bmm = graph.find_producer_node_by_value(add_->input(0)); if (bmm && bmm->kind() == "aten::bmm") { size_t input_count = op->inputs.size(); op->inputnames.resize(input_count); op->inputnames[input_count - 1] = "attn_mask"; } } } } if (mod.hasattr("in_proj_weight")) { const TorchTensorProxy& in_proj_weight = mod.attr("in_proj_weight"); op->params["embed_dim"] = in_proj_weight.size(1); op->params["kdim"] = in_proj_weight.size(1); op->params["vdim"] = in_proj_weight.size(1); op->attrs["in_proj_weight"] = in_proj_weight; } else { const TorchTensorProxy& q_proj_weight = mod.attr("q_proj_weight"); const TorchTensorProxy& k_proj_weight = mod.attr("k_proj_weight"); const TorchTensorProxy& v_proj_weight = mod.attr("v_proj_weight"); op->params["embed_dim"] = q_proj_weight.size(1); op->params["kdim"] = k_proj_weight.size(1); op->params["vdim"] = v_proj_weight.size(1); op->attrs["q_proj_weight"] = q_proj_weight; op->attrs["k_proj_weight"] = k_proj_weight; op->attrs["v_proj_weight"] = v_proj_weight; } const TorchTensorProxy& out_proj_weight = mod.attr("out_proj.weight"); op->attrs["out_proj.weight"] = out_proj_weight; if (mod.hasattr("in_proj_bias") && mod.hasattr("out_proj.bias")) { // bias=True const TorchTensorProxy& in_proj_bias = mod.attr("in_proj_bias"); const TorchTensorProxy& out_proj_bias = mod.attr("out_proj.bias"); op->params["bias"] = true; op->attrs["in_proj_bias"] = in_proj_bias; op->attrs["out_proj.bias"] = out_proj_bias; } else { op->params["bias"] = false; // the output projection bias always there no matter bias is False in pytorch 1.8 // this behavior changes since https://github.com/pytorch/pytorch/commit/58d1b3639bc07f9519de18e5a18e575f260c7eeb if (mod.hasattr("out_proj.bias")) { const TorchTensorProxy& out_proj_bias = mod.attr("out_proj.bias"); op->attrs["out_proj.bias"] = out_proj_bias; } } if (mod.hasattr("bias_k") && mod.hasattr("bias_v")) { // add_bias_kv=True const TorchTensorProxy& bias_k = mod.attr("bias_k"); const TorchTensorProxy& bias_v = mod.attr("bias_v"); op->params["add_bias_kv"] = true; op->attrs["bias_k"] = bias_k; op->attrs["bias_v"] = bias_v; } else { op->params["add_bias_kv"] = false; } } }; REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(MultiheadAttention) } // namespace pnnx