diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 2c4329501..d374ed46e 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -242,6 +242,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_randn.cpp pass_level2/torch_randn_like.cpp pass_level2/torch_real.cpp + pass_level2/torch_repeat_interleave.cpp pass_level2/torch_roll.cpp pass_level2/torch_scatter_add.cpp pass_level2/torch_split.cpp @@ -332,7 +333,9 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_pad_conv1d.cpp pass_level5/fuse_pad_conv2d.cpp pass_level5/fuse_pixel_unshuffle.cpp + pass_level5/fuse_layernorm.cpp pass_level5/fuse_multiheadattention.cpp + pass_level5/fuse_scaled_dot_product_attention.cpp pass_level5/fuse_select_to_unbind.cpp pass_level5/fuse_slice_copy.cpp pass_level5/fuse_slice_indices.cpp diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index eb270a94e..8eae58264 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -588,6 +588,14 @@ Attribute operator+(const Attribute& a, const Attribute& b) Parameter Parameter::parse_from_string(const std::string& value) { + if (value.find('%') != std::string::npos) + { + Parameter p; + p.type = 4; + p.s = value; + return p; + } + Parameter p; p.type = 0; @@ -659,6 +667,96 @@ Parameter Parameter::parse_from_string(const std::string& value) return p; } +std::string Parameter::encode_to_string(const Parameter& param) +{ + if (param.type == 0) + { + return std::string("None"); + } + if (param.type == 1) + { + if (param.b) + return std::string("True"); + else + return std::string("False"); + } + if (param.type == 2) + { + return std::to_string(param.i); + } + if (param.type == 3) + { + char buf[64]; + sprintf(buf, "%e", param.f); + return std::string(buf); + } + if (param.type == 4) + { + return param.s; + } + if (param.type == 5) + { + std::string s("("); + for (size_t i = 0; i < param.ai.size(); i++) + { + s += std::to_string(param.ai[i]); + if (i + 1 != param.ai.size()) + s += std::string(","); + } + s += std::string(")"); + return s; + } + if (param.type == 6) + { + std::string s("("); + for (size_t i = 0; i < param.af.size(); i++) + { + char buf[64]; + sprintf(buf, "%e", param.af[i]); + s += std::string(buf); + if (i + 1 != param.af.size()) + s += std::string(","); + } + s += std::string(")"); + return s; + } + if (param.type == 7) + { + std::string s("("); + for (size_t i = 0; i < param.as.size(); i++) + { + s += param.as[i]; + if (i + 1 != param.as.size()) + s += std::string(","); + } + s += std::string(")"); + return s; + } + if (param.type == 10) + { + char buf[128]; + sprintf(buf, "%e+%ej", param.c.real(), param.c.imag()); + return std::string(buf); + } + if (param.type == 11) + { + std::string s("("); + for (size_t i = 0; i < param.ac.size(); i++) + { + char buf[128]; + sprintf(buf, "%e+%ej", param.ac[i].real(), param.ac[i].imag()); + s += std::string(buf); + if (i + 1 != param.ac.size()) + s += std::string(","); + } + s += std::string(")"); + return s; + } + + fprintf(stderr, "unknown parameter type %d\n", param.type); + return std::string(); +} + Graph::Graph() { } @@ -752,6 +850,14 @@ static void load_shape(Operator* op, const std::string& key, const std::string& { operand->shape.push_back(-1); } + else if (elem[0] == '%') + { + // encode %abc as symbolic tag + operand->shape.push_back(-233); + int index = operand->shape.size() - 1; + std::string key = elem.substr(1); + operand->params[std::string("__shape_") + std::to_string(index)] = key; + } else { int i = std::stoi(elem); @@ -965,77 +1071,8 @@ int Graph::save(const std::string& parampath, const std::string& binpath) fprintf(paramfp, " %s=", it.first.c_str()); const Parameter& param = it.second; - if (param.type == 0) - { - fprintf(paramfp, "None"); - } - if (param.type == 1) - { - if (param.b) - fprintf(paramfp, "True"); - else - fprintf(paramfp, "False"); - } - if (param.type == 2) - { - fprintf(paramfp, "%d", param.i); - } - if (param.type == 3) - { - fprintf(paramfp, "%e", param.f); - } - if (param.type == 4) - { - fprintf(paramfp, "%s", param.s.c_str()); - } - if (param.type == 5) - { - fprintf(paramfp, "("); - for (size_t i = 0; i < param.ai.size(); i++) - { - fprintf(paramfp, "%d", param.ai[i]); - if (i + 1 != param.ai.size()) - fprintf(paramfp, ","); - } - fprintf(paramfp, ")"); - } - if (param.type == 6) - { - fprintf(paramfp, "("); - for (size_t i = 0; i < param.af.size(); i++) - { - fprintf(paramfp, "%e", param.af[i]); - if (i + 1 != param.af.size()) - fprintf(paramfp, ","); - } - fprintf(paramfp, ")"); - } - if (param.type == 7) - { - fprintf(paramfp, "("); - for (size_t i = 0; i < param.as.size(); i++) - { - fprintf(paramfp, "%s", param.as[i].c_str()); - if (i + 1 != param.as.size()) - fprintf(paramfp, ","); - } - fprintf(paramfp, ")"); - } - if (param.type == 10) - { - fprintf(paramfp, "%e+%ej", param.c.real(), param.c.imag()); - } - if (param.type == 11) - { - fprintf(paramfp, "("); - for (size_t i = 0; i < param.ac.size(); i++) - { - fprintf(paramfp, "%e+%ej", param.ac[i].real(), param.ac[i].imag()); - if (i + 1 != param.ac.size()) - fprintf(paramfp, ","); - } - fprintf(paramfp, ")"); - } + std::string s = Parameter::encode_to_string(param); + fprintf(paramfp, "%s", s.c_str()); } for (const auto& it : op->attrs) @@ -2638,11 +2675,62 @@ int Graph::parse(const std::string& param) { // attribute // load_attribute(op, key.substr(1), value, szr); + op->attrs[key.substr(1)] = Attribute(); + + Attribute& attr = op->attrs[key.substr(1)]; + + attr.type = 0; + if (value.empty()) + continue; + + if (value[0] == '%') + { + // @data=%op1.data + attr.data = std::vector(value.begin(), value.end()); + } + + if (value[0] == '(') + { + // @data=(1,%c,?,4)f32 + + // type + std::string typestr = value.substr(value.find_last_of(')') + 1); + attr.type = string_to_type(typestr.c_str()); + + // shape + std::string lc = value.substr(1, value.find_last_of(')') - 1); + std::istringstream lcss(lc); + + attr.shape.clear(); + while (!lcss.eof()) + { + std::string elem; + std::getline(lcss, elem, ','); + + if (elem == "?") + { + attr.shape.push_back(-1); + } + else if (elem[0] == '%') + { + // encode %abc as symbolic tag + attr.shape.push_back(-233); + int index = attr.shape.size() - 1; + std::string key = elem.substr(1); + attr.params[std::string("__shape_") + std::to_string(index)] = key; + } + else + { + int i = std::stoi(elem); + attr.shape.push_back(i); + } + } + } } else if (key[0] == '$') { // operand input key - // load_input_key(op, key.substr(1), value); + load_input_key(op, key.substr(1), value); } else if (key[0] == '#') { diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 46efbcb1f..96ad57cbf 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -171,6 +171,7 @@ public: #endif // BUILD_PNNX static Parameter parse_from_string(const std::string& value); + static std::string encode_to_string(const Parameter& param); // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others 10=c 11=ac int type; @@ -217,6 +218,8 @@ public: std::vector shape; std::vector data; + + std::map params; }; bool operator==(const Attribute& lhs, const Attribute& rhs); @@ -246,6 +249,7 @@ private: friend class Graph; Operand() { + type = 0; } }; diff --git a/tools/pnnx/src/pass_level1.cpp b/tools/pnnx/src/pass_level1.cpp index 7b79713a8..1b1d971d5 100644 --- a/tools/pnnx/src/pass_level1.cpp +++ b/tools/pnnx/src/pass_level1.cpp @@ -131,7 +131,7 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptrattrs[name] = sub_mod.attr(name).toTensor(); + op->attrs["data"] = sub_mod.attr(name).toTensor(); } } else if (n->kind() == c10::prim::Constant) // || n->kind() == c10::prim::ListConstruct) @@ -165,7 +165,7 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptrparams.erase("value"); - op->attrs[name] = n->t(torch::jit::attr::value); + op->attrs["data"] = n->t(torch::jit::attr::value); } } else if (n->kind() == c10::prim::CallMethod) diff --git a/tools/pnnx/src/pass_level2.cpp b/tools/pnnx/src/pass_level2.cpp index c9a96923c..199861774 100644 --- a/tools/pnnx/src/pass_level2.cpp +++ b/tools/pnnx/src/pass_level2.cpp @@ -24,6 +24,17 @@ GraphRewriterPass::~GraphRewriterPass() { } +const char* GraphRewriterPass::replace_pattern_graph() const +{ + return 0; +} + +const char* GraphRewriterPass::type_str() const +{ + fprintf(stderr, "GraphRewriterPass type_str() should be implemented\n"); + return "unk"; +} + const char* GraphRewriterPass::name_str() const { return type_str(); @@ -46,15 +57,96 @@ bool GraphRewriterPass::match(const std::map& /*ma void GraphRewriterPass::write(Operator* op, const std::map& captured_params) const { - for (auto x : captured_params) + if (replace_pattern_graph() == 0) + { + for (auto x : captured_params) + { + op->params[x.first] = x.second; + } + + return; + } + + for (auto x : op->params) { - op->params[x.first] = x.second; + if (x.second.type != 4) + continue; + + std::string str = x.second.s; + if (str.find('%') == std::string::npos) + continue; + + // search % token and replace with captured + size_t pos = str.find('%'); + while (pos != std::string::npos) + { + // %xyz + char buf[256]; + sscanf(str.c_str() + pos + 1, "%255[^][,() ]", buf); + std::string key(buf); + + if (captured_params.find(key) == captured_params.end()) + { + fprintf(stderr, "replace pattern param %%%s missing captured\n", key.c_str()); + return; + } + + // replace %xyz with encoded_str + std::string encoded_str = Parameter::encode_to_string(captured_params.at(key)); + str.replace(pos, key.size() + 1, encoded_str); + + pos = str.find('%', pos + 1); + } + + op->params[x.first] = Parameter::parse_from_string(str); } } -void GraphRewriterPass::write(Operator* op, const std::map& captured_params, const std::map& /*captured_attrs*/) const +void GraphRewriterPass::write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const { write(op, captured_params); + + for (auto x : op->attrs) + { + if (x.second.type != 0) + continue; + + std::string key((const char*)x.second.data.data()); + if (key.empty()) + continue; + + op->attrs[x.first] = captured_attrs.at(key); + } +} + +void GraphRewriterPass::write(const std::map& ops, const std::map& captured_params) const +{ + for (auto x : ops) + { + Operator* op = x.second; + write(op, captured_params); + } +} + +void GraphRewriterPass::write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const +{ + write(ops, captured_params); + + for (auto x : ops) + { + Operator* op = x.second; + for (auto x : op->attrs) + { + if (x.second.type != 0) + continue; + + std::string key(x.second.data.begin(), x.second.data.end()); + if (key.empty() || key[0] != '%') + continue; + + op->attrs[x.first] = captured_attrs.at(key.substr(1)); + } + } } static std::map > g_global_pnnx_graph_rewriter_passes; @@ -75,6 +167,126 @@ GraphRewriterPassRegister::~GraphRewriterPassRegister() delete pass; } +static bool token_is_argument(const std::string& t) +{ + if (t[0] != '@' || t.size() < 2) + return false; + + for (size_t i = 1; i < t.size(); i++) + { + if (t[i] < '0' || t[i] > '9') + return false; + } + + return true; +} + +static bool match_expression(const Operator* a, const Operator* b, std::map& captured_params) +{ + if (a->params.size() != 1 || a->params.find("expr") == a->params.end()) + return false; + + if (b->params.size() != 1 || b->params.find("expr") == b->params.end()) + return false; + + const std::string& a_expr = a->params.at("expr").s; + const std::string& b_expr = b->params.at("expr").s; + + if (a_expr == b_expr) + return true; + + // split into tokens + std::vector a_tokens; + std::vector b_tokens; + { + std::string t; + for (size_t i = 0; i < a_expr.size(); i++) + { + char ch = a_expr[i]; + + if (ch == '[') // list + { + t += ch; + a_tokens.push_back(t); + t.clear(); + } + else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') + { + if (!t.empty()) + { + a_tokens.push_back(t); + t.clear(); + } + } + else + { + t += ch; + } + } + + if (!t.empty()) + { + a_tokens.push_back(t); + } + } + { + std::string t; + for (size_t i = 0; i < b_expr.size(); i++) + { + char ch = b_expr[i]; + + if (ch == '[') // list + { + t += ch; + b_tokens.push_back(t); + t.clear(); + } + else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') + { + if (!t.empty()) + { + b_tokens.push_back(t); + t.clear(); + } + } + else + { + t += ch; + } + } + + if (!t.empty()) + { + b_tokens.push_back(t); + } + } + + if (a_tokens.size() != b_tokens.size()) + return false; + + // capture values + for (size_t i = 0; i < a_tokens.size(); i++) + { + const std::string& at = a_tokens[i]; + const std::string& bt = b_tokens[i]; + + if (at == bt) + continue; + + if (bt[0] != '%') + return false; + + if (token_is_argument(at)) + return false; + + std::string key = bt.substr(1); + + captured_params[key] = Parameter::parse_from_string(at); + } + + return true; +} + static bool match_parameter(const Parameter& a, const Parameter& b, std::map& captured_params) { if (b.type == 4 && b.s[0] == '%') @@ -97,6 +309,77 @@ static bool match_parameter(const Parameter& a, const Parameter& b, std::map '9')) || (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9'))) + { + // string + if (a.type != 7) + return false; + + if (a.as[i] != elem) + return false; + } + else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos) + { + // float + if (a.type != 6) + return false; + + if (a.af[i] != std::stof(elem)) + return false; + } + else + { + // integer + if (a.type != 5) + return false; + + if (a.ai[i] != std::stoi(elem)) + return false; + } + + i++; + } + + return true; + } + if (a.type != b.type) { if (a.type == 2 && b.type == 3) @@ -174,6 +457,76 @@ static bool match_parameter(const Parameter& a, const Parameter& b, std::map& captured_params, const std::string& attrname, std::map& captured_attrs) +{ + // @data + // @data=(1,2,3,4)f32 + // @data=%op1.data + + if (b.type == 0) + { + std::string bs(b.data.begin(), b.data.end()); + if (bs.empty()) + { + // capture any shape + captured_attrs[attrname] = a; + return true; + } + + if (bs[0] == '%') + { + // the captured replace + return true; + } + + fprintf(stderr, "malformed attribute pattern %s\n", bs.c_str()); + return false; + } + + const std::vector& a_shape = a.shape; + const std::vector& b_shape = b.shape; + if (b_shape.empty()) + return false; + + if (a_shape.empty()) + return false; + + if (a_shape.size() != b_shape.size()) + return false; + + for (size_t j = 0; j < a_shape.size(); j++) + { + int ai = a_shape[j]; + int bi = b_shape[j]; + if (ai == bi) + continue; + + if (bi == -1) + continue; + + if (bi > 0) + return false; + + if (bi != -233) + return false; + + std::string key = b.params.at(std::string("__shape_") + std::to_string(j)).s; + + if (captured_params.find(key) != captured_params.end()) + { + // match previous captured parameter + if (captured_params.at(key).i != ai) + return false; + } + + // captured parameter + captured_params[key] = ai; + } + + captured_attrs[attrname] = a; + return true; +} + static bool match_operator(const Operator* a, const Operator* b, std::map& captured_params, std::map& captured_attrs) { if (a->type != b->type) @@ -197,6 +550,11 @@ static bool match_operator(const Operator* a, const Operator* b, std::mapname + '.' + pkey] = pp; } } + else if (a->type == "pnnx.Expression") + { + if (!match_expression(a, b, captured_params)) + return false; + } else { if (a->params.size() != b->params.size()) @@ -215,13 +573,120 @@ static bool match_operator(const Operator* a, const Operator* b, std::mapinputs.size(); i++) + { + int a_type = a->inputs[i]->type; + int b_type = b->inputs[i]->type; + if (b_type != 0 && a_type != b_type) + return false; + + const std::vector& a_shape = a->inputs[i]->shape; + const std::vector& b_shape = b->inputs[i]->shape; + if (b_shape.empty()) + continue; + + if (a_shape.empty()) + return false; + + if (a_shape.size() != b_shape.size()) + return false; + + for (size_t j = 0; j < a_shape.size(); j++) + { + int ai = a_shape[j]; + int bi = b_shape[j]; + if (ai == bi) + continue; + + if (bi == -1) + continue; + + if (bi > 0) + return false; + + if (bi != -233) + return false; + + std::string key = b->inputs[i]->params.at(std::string("__shape_") + std::to_string(j)).s; + + if (captured_params.find(key) != captured_params.end()) + { + // match previous captured parameter + if (captured_params.at(key).i != ai) + return false; + } + + // captured parameter + captured_params[key] = ai; + } + } + + for (size_t i = 0; i < a->outputs.size(); i++) + { + int a_type = a->outputs[i]->type; + int b_type = b->outputs[i]->type; + if (b_type != 0 && a_type != b_type) + return false; + + const std::vector& a_shape = a->outputs[i]->shape; + const std::vector& b_shape = b->outputs[i]->shape; + if (b_shape.empty()) + continue; + + if (a_shape.empty()) + return false; + + if (a_shape.size() != b_shape.size()) + return false; + + for (size_t j = 0; j < a_shape.size(); j++) + { + int ai = a_shape[j]; + int bi = b_shape[j]; + if (ai == bi) + continue; + + if (bi == -1) + continue; + + if (bi > 0) + return false; + + if (bi != -233) + return false; + + std::string key = b->outputs[i]->params.at(std::string("__shape_") + std::to_string(j)).s; + + if (captured_params.find(key) != captured_params.end()) + { + // match previous captured parameter + if (captured_params.at(key).i != ai) + return false; + } + + // captured parameter + captured_params[key] = ai; + } + } + for (const auto& p : a->attrs) { const std::string& akey = p.first; const Attribute& aa = p.second; - // capture all attributes - captured_attrs[b->name + '.' + akey] = aa; + std::string attrname = b->name + '.' + akey; + + if (b->attrs.find(akey) == b->attrs.end()) + { + // capture all attributes + captured_attrs[attrname] = aa; + } + else + { + if (!match_attribute(aa, b->attrs.at(akey), captured_params, attrname, captured_attrs)) + return false; + } } return true; @@ -484,27 +949,111 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde cur = graph.ops[cur_index]; } - Operator* op = graph.new_operator_before(pass->type_str(), std::string(pass->name_str()), cur); - - for (const auto& k : pattern_graph_inputs) + if (pass->replace_pattern_graph() == 0) { - Operand* r = (Operand*)matched_inputs.at(k); - r->consumers.push_back(op); - op->inputs.push_back(r); + // insert single + Operator* op = graph.new_operator_before(pass->type_str(), std::string(pass->name_str()), cur); - op->inputnames.push_back(k); - } + for (const auto& k : pattern_graph_inputs) + { + Operand* r = (Operand*)matched_inputs.at(k); + r->consumers.push_back(op); + op->inputs.push_back(r); - for (const auto& k : pattern_graph_outputs) - { - Operand* r = (Operand*)matched_outputs.at(k); - r->producer = op; - op->outputs.push_back(r); + op->inputnames.push_back(k); + } + + for (const auto& k : pattern_graph_outputs) + { + Operand* r = (Operand*)matched_outputs.at(k); + r->producer = op; + op->outputs.push_back(r); + } + + pass->write(op, captured_params, captured_attrs); + + new_ops.push_back(op); } + else + { + // insert multiple + Graph replace_graph; + replace_graph.parse(pass->replace_pattern_graph()); + + // move operators and operands from replace_graph to graph except input and output + std::map ops; + for (size_t i = 0; i < replace_graph.ops.size(); i++) + { + Operator* op = replace_graph.ops[i]; + if (op->type == "pnnx.Input" || op->type == "pnnx.Output") + continue; + + graph.ops.insert(std::find(graph.ops.begin(), graph.ops.end(), cur), op); + replace_graph.ops[i] = 0; + ops[op->name] = op; + } + + for (size_t i = 0; i < replace_graph.operands.size(); i++) + { + Operand* r = replace_graph.operands[i]; + if (r->producer->type == "pnnx.Input" || (r->consumers.size() == 1 && r->consumers[0]->type == "pnnx.Output")) + continue; + + graph.operands.push_back(r); + replace_graph.operands[i] = 0; + } + + replace_graph.ops.erase(std::remove(replace_graph.ops.begin(), replace_graph.ops.end(), (Operator*)0), replace_graph.ops.end()); + replace_graph.operands.erase(std::remove(replace_graph.operands.begin(), replace_graph.operands.end(), (Operand*)0), replace_graph.operands.end()); - pass->write(op, captured_params, captured_attrs); + for (size_t i = 0; i < pattern_graph_inputs.size(); i++) + { + const std::string& k = pattern_graph_inputs[i]; + Operand* r = (Operand*)matched_inputs.at(k); + const Operand* rr = replace_graph.get_operand(k); + + for (auto x : rr->consumers) + { + r->consumers.push_back(x); - new_ops.push_back(op); + x->inputnames.resize(x->inputs.size()); + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j]->name == k) + { + x->inputs[j] = r; + x->inputnames[j] = k; + break; + } + } + } + } + + for (size_t i = 0; i < pattern_graph_outputs.size(); i++) + { + const std::string& k = pattern_graph_outputs[i]; + Operand* r = (Operand*)matched_outputs.at(k); + const Operand* rr = replace_graph.get_operand(k); + + r->producer = rr->producer; + + for (size_t j = 0; j < r->producer->outputs.size(); j++) + { + if (r->producer->outputs[j]->name == k) + { + r->producer->outputs[j] = r; + break; + } + } + } + + pass->write(ops, captured_params, captured_attrs); + + for (auto x : ops) + { + new_ops.push_back(x.second); + } + } } // assign new op name number diff --git a/tools/pnnx/src/pass_level2.h b/tools/pnnx/src/pass_level2.h index af0fb8346..335e32174 100644 --- a/tools/pnnx/src/pass_level2.h +++ b/tools/pnnx/src/pass_level2.h @@ -26,7 +26,9 @@ public: virtual const char* match_pattern_graph() const = 0; - virtual const char* type_str() const = 0; + virtual const char* replace_pattern_graph() const; + + virtual const char* type_str() const; virtual const char* name_str() const; @@ -39,6 +41,10 @@ public: virtual void write(Operator* op, const std::map& captured_params) const; virtual void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const; + + virtual void write(const std::map& ops, const std::map& captured_params) const; + + virtual void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const; }; class GraphRewriterPassRegister diff --git a/tools/pnnx/src/pass_level2/torch_repeat_interleave.cpp b/tools/pnnx/src/pass_level2/torch_repeat_interleave.cpp new file mode 100644 index 000000000..0552e19a8 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_repeat_interleave.cpp @@ -0,0 +1,68 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +#include + +namespace pnnx { + +class torch_repeat_interleave : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 repeats +pnnx.Input input_2 0 1 dim +prim::Constant op_0 0 1 output_size value=* +aten::repeat_interleave op_1 4 1 input repeats dim output_size out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.repeat_interleave"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_repeat_interleave, 20) + +class torch_repeat_interleave_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 repeats +pnnx.Input input_2 0 1 dim +aten::repeat_interleave op_0 3 1 input repeats dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.repeat_interleave"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_repeat_interleave_1, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index 279d6019b..08bd2c715 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -35,10 +35,12 @@ #include "pass_level5/fuse_convtranspose1d_batchnorm1d.h" #include "pass_level5/fuse_convtranspose2d_batchnorm2d.h" #include "pass_level5/fuse_contiguous_view.h" +#include "pass_level5/fuse_layernorm.h" #include "pass_level5/fuse_linear_batchnorm1d.h" #include "pass_level5/fuse_multiheadattention.h" #include "pass_level5/fuse_pad_conv1d.h" #include "pass_level5/fuse_pad_conv2d.h" +#include "pass_level5/fuse_scaled_dot_product_attention.h" #include "pass_level5/fuse_select_to_unbind.h" #include "pass_level5/fuse_slice_copy.h" #include "pass_level5/fuse_slice_indices.h" @@ -124,7 +126,9 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons eliminate_reshape_shape_expression(g); fuse_channel_shuffle(g); + fuse_layernorm(g); fuse_multiheadattention(g); + fuse_scaled_dot_product_attention(g); fuse_index_expression(g); diff --git a/tools/pnnx/src/pass_level5/fold_constants.cpp b/tools/pnnx/src/pass_level5/fold_constants.cpp index e5bccd498..47906f761 100644 --- a/tools/pnnx/src/pass_level5/fold_constants.cpp +++ b/tools/pnnx/src/pass_level5/fold_constants.cpp @@ -43,9 +43,9 @@ void fold_constants(Graph& graph, const std::set& foldable_constant // replace producer with attribute Operator* op_new = graph.new_operator_before("pnnx.Attribute", std::string("pnnx_fold_") + name, op); - op_new->attrs[std::string("pnnx_fold_") + name] = Attribute(); + op_new->attrs["data"] = Attribute(); - Attribute& t2 = op_new->attrs[std::string("pnnx_fold_") + name]; + Attribute& t2 = op_new->attrs["data"]; t2.type = operand->type; t2.shape = operand->shape; size_t size = zip.get_file_size(name); diff --git a/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp index 9811a5bbe..6b79bc059 100644 --- a/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp +++ b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp @@ -26,9 +26,9 @@ public: return R"PNNXIR(7767517 7 6 pnnx.Input input 0 1 input -pnnx.Expression op_0 1 1 input 13 expr=%expr +pnnx.Expression op_0 1 1 input 13 expr=[int(size(@0,0)),2,int(floor_divide(size(@0,1),%groups)),int(size(@0,2)),int(size(@0,3))] Tensor.view op_1 2 1 input 13 14 -pnnx.Expression op_2 1 1 input 15 expr=%expr2 +pnnx.Expression op_2 1 1 input 15 expr=[int(size(@0,0)),-1,int(size(@0,2)),int(size(@0,3))] torch.transpose op_3 1 1 14 16 dim0=1 dim1=2 Tensor.reshape op_4 2 1 16 15 out pnnx.Output output 1 0 out @@ -44,32 +44,6 @@ pnnx.Output output 1 0 out { return "channelshuffle"; } - - bool match(const std::map& captured_params) const - { - const std::string& expr = captured_params.at("expr").s; - const std::string& expr2 = captured_params.at("expr2").s; - - if (expr2 != "[int(size(@0,0)),-1,int(size(@0,2)),int(size(@0,3))]") - return false; - - int groups; - int nscan = sscanf(expr.c_str(), "[int(size(@0,0)),2,int(floor_divide(size(@0,1),%d)),int(size(@0,2)),int(size(@0,3))]", &groups); - if (nscan != 1) - return false; - - return true; - } - - void write(Operator* op, const std::map& captured_params) const - { - const std::string& expr = captured_params.at("expr").s; - - int groups; - sscanf(expr.c_str(), "[int(size(@0,0)),2,int(floor_divide(size(@0,1),%d)),int(size(@0,2)),int(size(@0,3))]", &groups); - - op->params["groups"] = groups; - } }; class fuse_channel_shuffle_pass_1 : public GraphRewriterPass @@ -80,9 +54,9 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -Tensor.view op_0 1 1 input 13 shape=%shape +Tensor.view op_0 1 1 input 13 shape=(%batch,%groups,%channels_per_group,%h,%w) torch.transpose op_1 1 1 13 14 dim0=1 dim1=2 -Tensor.reshape op_2 1 1 14 out shape=%shape2 +Tensor.reshape op_2 1 1 14 out shape=(%batch,-1,%h,%w) pnnx.Output output 1 0 out )PNNXIR"; } @@ -97,26 +71,9 @@ pnnx.Output output 1 0 out return "channelshuffle"; } - bool match(const std::map& captured_params) const - { - // (1,2,58,28,28) - // (1,-1,28,28) - const std::vector& shape = captured_params.at("shape").ai; - const std::vector& shape2 = captured_params.at("shape2").ai; - - if (shape[0] != 1 || shape2[0] != 1 || shape2[1] != -1 || shape2[2] != shape[3] || shape2[3] != shape[4]) - return false; - - return true; - } - void write(Operator* op, const std::map& captured_params) const { - const std::vector& shape = captured_params.at("shape").ai; - - int groups = shape[1]; - - op->params["groups"] = groups; + op->params["groups"] = captured_params.at("groups"); } }; diff --git a/tools/pnnx/src/pass_level5/fuse_layernorm.cpp b/tools/pnnx/src/pass_level5/fuse_layernorm.cpp new file mode 100644 index 000000000..a723b4417 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_layernorm.cpp @@ -0,0 +1,87 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_layernorm.h" + +#include "pass_level2.h" + +#include +#include + +#include + +namespace pnnx { + +class fuse_layernorm_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input 0 1 input #input=(1,%c,?,?)f32 +pnnx.Attribute op_0 0 1 weight @data #weight=(%c,1,1)f32 +pnnx.Attribute op_1 0 1 bias @data #bias=(%c,1,1)f32 +torch.mean op_2 1 1 input mean dim=(1) keepdim=True +pnnx.Expression op_3 2 1 input mean 2 expr=pow(sub(@0,@1),2) +torch.mean op_4 1 1 2 var dim=(1) keepdim=True +pnnx.Expression op_5 5 1 weight input mean var bias out expr=add(mul(@0,div(sub(@1,@2),sqrt(add(@3,%eps)))),@4) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { +#if TORCH_VERSION_MAJOR >= 2 || TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 9 + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +torch.permute op_0 1 1 input a dims=(0,2,3,1) +nn.LayerNorm op_1 1 1 a b elementwise_affine=True eps=%eps normalized_shape=(%c) @weight=%op_0.data @bias=%op_1.data +torch.permute op_2 1 1 b out dims=(0,3,1,2) +pnnx.Output output 1 0 out +)PNNXIR"; +#else + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +Tensor.permute op_0 1 1 input a dims=(0,2,3,1) +nn.LayerNorm op_1 1 1 a b elementwise_affine=True eps=%eps normalized_shape=(%c) @weight=%op_0.data @bias=%op_1.data +Tensor.permute op_2 1 1 b out dims=(0,3,1,2) +pnnx.Output output 1 0 out +)PNNXIR"; +#endif + } + + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const + { + GraphRewriterPass::write(ops, captured_params, captured_attrs); + + // fix weight bias shape from (c,1,1) to (c) + const int c = captured_params.at("c").i; + Operator* op_1 = ops.at("op_1"); + op_1->attrs["weight"].shape = {c}; + op_1->attrs["bias"].shape = {c}; + } +}; + +void fuse_layernorm(Graph& graph) +{ + fuse_layernorm_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_layernorm.h b/tools/pnnx/src/pass_level5/fuse_layernorm.h new file mode 100644 index 000000000..ac8c82cf8 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_layernorm.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_layernorm(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp index 2652b4fc1..4dd004264 100644 --- a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp +++ b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp @@ -44,87 +44,69 @@ public: return R"PNNXIR(7767517 14 13 pnnx.Input input 0 1 input -nn.Linear op_0 1 1 input 1 bias=%qkv_bias in_features=%embed_dim out_features=%qkv_out_features @bias @weight -Tensor.reshape op_1 1 1 1 2 shape=%shape +nn.Linear op_0 1 1 input 1 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight +Tensor.reshape op_1 1 1 1 2 shape=(%batch,%size,3,%num_heads,%feat_per_head) torch.permute op_2 1 1 2 3 dims=(2,0,3,1,4) torch.unbind op_3 1 3 3 4 5 6 dim=0 -pnnx.Expression op_4 1 1 4 7 expr=%expr +pnnx.Expression op_4 1 1 4 7 expr=mul(@0,%inv_sqrt_embed_dim_per_head) torch.permute op_5 1 1 5 8 dims=(0,1,3,2) torch.matmul op_6 2 1 7 8 9 F.softmax op_7 1 1 9 10 dim=-1 torch.matmul op_8 2 1 10 6 11 torch.permute op_9 1 1 11 12 dims=(0,2,1,3) -Tensor.reshape op_10 1 1 12 13 shape=%shape2 -nn.Linear out_proj 1 1 13 out bias=%out_proj_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_10 1 1 12 13 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 13 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.MultiheadAttention"; - } - - const char* name_str() const - { - return "attention"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.MultiheadAttention attention 1 1 input out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False +pnnx.Output output 1 0 out +)PNNXIR"; } bool match(const std::map& captured_params) const { const int embed_dim = captured_params.at("embed_dim").i; const int qkv_out_features = captured_params.at("qkv_out_features").i; - if (qkv_out_features != embed_dim * 3) - return false; + const int num_heads = captured_params.at("num_heads").i; + const int feat_per_head = captured_params.at("feat_per_head").i; + const float inv_sqrt_embed_dim_per_head = captured_params.at("inv_sqrt_embed_dim_per_head").f; - // (1,-1,3,4,16) - // (1,-1,64) - const std::vector& shape = captured_params.at("shape").ai; - const std::vector& shape2 = captured_params.at("shape2").ai; - if (shape.size() != 5 || shape2.size() != 3) - return false; - - const int num_heads = shape[3]; - if (shape[0] != shape2[0] || shape[2] != 3 || shape[3] * shape[4] != shape2[2]) + if (qkv_out_features != embed_dim * 3) return false; - // mul(@0,2.581989e-01) - const std::string& expr = captured_params.at("expr").s; - float inv_sqrt_embed_dim_per_head = 0.f; - int nscan = sscanf(expr.c_str(), "mul(@0,%f)", &inv_sqrt_embed_dim_per_head); - if (nscan != 1) + if (embed_dim != num_heads * feat_per_head) return false; - if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) + if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(feat_per_head), 0.001)) return false; return true; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - int num_heads = captured_params.at("shape").ai[captured_params.at("shape").ai.size() - 2]; + GraphRewriterPass::write(ops, captured_params, captured_attrs); - bool qkv_bias = captured_params.at("qkv_bias").b; - bool out_proj_bias = captured_params.at("out_proj_bias").b; - bool bias = qkv_bias || out_proj_bias; - - op->params["num_heads"] = num_heads; - op->params["batch_first"] = true; - op->params["add_zero_attn"] = false; - op->params["add_bias_kv"] = false; - op->params["bias"] = bias; + Operator* op = ops.at("attention"); - int embed_dim = captured_params.at("embed_dim").i; + const int embed_dim = captured_params.at("embed_dim").i; + const bool qkvbias = captured_params.at("qkvbias").b; + const bool outbias = captured_params.at("outbias").b; + const bool bias = qkvbias || outbias; - op->params["embed_dim"] = embed_dim; - op->params["kdim"] = embed_dim; - op->params["vdim"] = embed_dim; + op->params["bias"] = bias; op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight"); if (bias) { - if (qkv_bias) + if (qkvbias) { op->attrs["in_proj_bias"] = captured_attrs.at("op_0.bias"); } @@ -141,7 +123,7 @@ pnnx.Output output 1 0 out op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight"); if (bias) { - if (out_proj_bias) + if (outbias) { op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias"); } @@ -165,56 +147,25 @@ public: return R"PNNXIR(7767517 18 17 pnnx.Input input 0 1 input -nn.Linear op_0 1 1 input 1 bias=%qkv_bias in_features=%embed_dim out_features=%qkv_out_features @bias @weight +nn.Linear op_0 1 1 input 1 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight torch.chunk op_1 1 3 1 2 3 4 chunks=3 dim=-1 -Tensor.reshape op_2 1 1 2 5 shape=%shape -Tensor.reshape op_3 1 1 3 6 shape=%shape -Tensor.reshape op_4 1 1 4 7 shape=%shape +Tensor.reshape op_2 1 1 2 5 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_3 1 1 3 6 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 4 7 shape=(%batch,%size,%num_heads,%feat_per_head) torch.permute op_5 1 1 6 8 dims=(0,2,1,3) torch.permute op_6 1 1 5 9 dims=(0,2,1,3) torch.transpose op_7 1 1 8 10 dim0=-1 dim1=-2 torch.matmul op_8 2 1 9 10 11 -pnnx.Expression op_9 1 1 11 12 expr=%expr +pnnx.Expression op_9 1 1 11 12 expr=mul(@0,%inv_sqrt_embed_dim_per_head) nn.Softmax op_10 1 1 12 13 dim=-1 torch.permute op_11 1 1 7 14 dims=(0,2,1,3) torch.matmul op_12 2 1 13 14 15 torch.permute op_13 1 1 15 16 dims=(0,2,1,3) -Tensor.reshape op_14 1 1 16 17 shape=%shape2 -nn.Linear out_proj 1 1 17 out bias=%out_proj_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_14 1 1 16 17 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 17 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } - - bool match(const std::map& captured_params) const - { - const int embed_dim = captured_params.at("embed_dim").i; - const int qkv_out_features = captured_params.at("qkv_out_features").i; - if (qkv_out_features != embed_dim * 3) - return false; - - // (1,20,4,16) - // (1,20,64) - const std::vector& shape = captured_params.at("shape").ai; - const std::vector& shape2 = captured_params.at("shape2").ai; - if (shape.size() != 4 || shape2.size() != 3) - return false; - - const int num_heads = shape[2]; - if (shape[0] != shape2[0] || shape[1] != shape2[1] || shape[2] * shape[3] != shape2[2]) - return false; - - // mul(@0,2.581989e-01) - const std::string& expr = captured_params.at("expr").s; - float inv_sqrt_embed_dim_per_head = 0.f; - int nscan = sscanf(expr.c_str(), "mul(@0,%f)", &inv_sqrt_embed_dim_per_head); - if (nscan != 1) - return false; - - if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) - return false; - - return true; - } }; class fuse_multiheadattention_pass_sameqkv : public GraphRewriterPass @@ -225,112 +176,70 @@ public: return R"PNNXIR(7767517 23 22 pnnx.Input input 0 1 input -nn.Linear op_0 1 1 input 31 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 input 32 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 input 33 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -pnnx.Expression op_3 1 1 32 34 expr=%expr -Tensor.reshape op_4 1 1 31 35 shape=%q_shape -Tensor.reshape op_5 1 1 34 36 shape=%kv_shape -Tensor.reshape op_6 1 1 33 37 shape=%kv_shape +nn.Linear op_0 1 1 input 31 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 32 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 33 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Expression op_3 1 1 32 34 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +Tensor.reshape op_4 1 1 31 35 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 34 36 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_6 1 1 33 37 shape=(%batch,%size,%num_heads,%feat_per_head) torch.permute op_7 1 1 36 38 dims=(0,2,1,3) -Tensor.reshape op_8 1 1 38 39 shape=%kv_shape2 +Tensor.reshape op_8 1 1 38 39 shape=(%num_heads,%size,%feat_per_head) torch.permute op_9 1 1 35 40 dims=(0,2,1,3) -Tensor.reshape op_10 1 1 40 41 shape=%q_shape2 +Tensor.reshape op_10 1 1 40 41 shape=(%num_heads,%size,%feat_per_head) torch.permute op_11 1 1 39 42 dims=(0,2,1) torch.matmul op_12 2 1 41 42 43 F.softmax op_13 1 1 43 44 dim=-1 torch.permute op_14 1 1 37 45 dims=(0,2,1,3) -Tensor.reshape op_15 1 1 45 46 shape=%kv_shape2 +Tensor.reshape op_15 1 1 45 46 shape=(%num_heads,%size,%feat_per_head) torch.matmul op_16 2 1 44 46 47 -Tensor.reshape op_17 1 1 47 48 shape=%qkv_shape +Tensor.reshape op_17 1 1 47 48 shape=(%batch,%num_heads,%size,%feat_per_head) torch.permute op_18 1 1 48 49 dims=(0,2,1,3) -Tensor.reshape op_19 1 1 49 50 shape=%qkv_shape2 -nn.Linear out_proj 1 1 50 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_19 1 1 49 50 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 50 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.MultiheadAttention"; - } - - const char* name_str() const - { - return "attention"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.MultiheadAttention attention 1 1 input out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False +pnnx.Output output 1 0 out +)PNNXIR"; } bool match(const std::map& captured_params) const { - // q_shape = (1,q,8,40) - // kv_shape = (1,kv,8,40) - // q_shape2 = (8,q,40) - // kv_shape2 = (8,kv,40) - // qkv_shape = (1,8,q,40) - // qkv_shape2 = (1,q,320) - const std::vector& q_shape = captured_params.at("q_shape").ai; - const std::vector& kv_shape = captured_params.at("kv_shape").ai; - const std::vector& q_shape2 = captured_params.at("q_shape2").ai; - const std::vector& kv_shape2 = captured_params.at("kv_shape2").ai; - const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; - const std::vector& qkv_shape2 = captured_params.at("qkv_shape2").ai; - if (q_shape.size() != 4 || kv_shape.size() != 4 || q_shape2.size() != 3 || kv_shape2.size() != 3 || qkv_shape.size() != 4 || qkv_shape2.size() != 3) - return false; + const int embed_dim = captured_params.at("embed_dim").i; + const int num_heads = captured_params.at("num_heads").i; + const int feat_per_head = captured_params.at("feat_per_head").i; + const float inv_sqrt_embed_dim_per_head = captured_params.at("inv_sqrt_embed_dim_per_head").f; - const int batch_size = q_shape[0]; - const int q_size = q_shape[1]; - const int num_heads = q_shape[2]; - const int feat_per_head = q_shape[3]; - const int kv_size = kv_shape[1]; - if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size || qkv_shape2[0] != batch_size) + if (embed_dim != num_heads * feat_per_head) return false; - if ((q_shape2[1] != q_size && q_shape2[1] != -1) || kv_shape2[1] != kv_size || qkv_shape[2] != q_size || qkv_shape2[1] != q_size) + if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(feat_per_head), 0.001)) return false; - if (kv_shape[2] != num_heads || q_shape2[0] != num_heads || kv_shape2[0] != num_heads || qkv_shape[1] != num_heads) - return false; + return true; + } - if (kv_shape[3] != feat_per_head || q_shape2[2] != feat_per_head || kv_shape2[2] != feat_per_head || qkv_shape[3] != feat_per_head || qkv_shape2[2] != feat_per_head * num_heads) - return false; + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const + { + GraphRewriterPass::write(ops, captured_params, captured_attrs); - // mul(@0,1.581139e-01) - const std::string& expr = captured_params.at("expr").s; - float inv_sqrt_embed_dim_per_head = 0.f; - int nscan = sscanf(expr.c_str(), "mul(@0,%f)", &inv_sqrt_embed_dim_per_head); - if (nscan != 1) - return false; + Operator* op = ops.at("attention"); const int embed_dim = captured_params.at("embed_dim").i; - if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) - return false; + const bool qbias = captured_params.at("qbias").b; + const bool kbias = captured_params.at("kbias").b; + const bool vbias = captured_params.at("vbias").b; + const bool outbias = captured_params.at("outbias").b; + const bool bias = qbias || kbias || vbias || outbias; - return true; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - int embed_dim = captured_params.at("embed_dim").i; - int kdim = captured_params.at("kdim").i; - int vdim = captured_params.at("vdim").i; - - // (1,*,8,40) - int num_heads = captured_params.at("q_shape").ai[2]; - - bool q_bias = captured_params.at("q_bias").b; - bool k_bias = captured_params.at("k_bias").b; - bool v_bias = captured_params.at("v_bias").b; - bool out_bias = captured_params.at("out_bias").b; - bool bias = q_bias || k_bias || v_bias || out_bias; - - op->params["embed_dim"] = embed_dim; - op->params["kdim"] = kdim; - op->params["vdim"] = vdim; - - op->params["num_heads"] = num_heads; - op->params["batch_first"] = true; - op->params["add_zero_attn"] = false; - op->params["add_bias_kv"] = false; op->params["bias"] = bias; op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight") + captured_attrs.at("op_1.weight") + captured_attrs.at("op_2.weight"); @@ -346,7 +255,7 @@ pnnx.Output output 1 0 out std::vector in_proj_bias(embed_dim * 3); { float* in_proj_bias_ptr = (float*)in_proj_bias.data(); - if (q_bias) + if (qbias) { auto qb = captured_attrs.at("op_0.bias").get_float32_data(); memcpy(in_proj_bias_ptr, (const void*)qb.data(), embed_dim * sizeof(float)); @@ -356,7 +265,7 @@ pnnx.Output output 1 0 out memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); } in_proj_bias_ptr += embed_dim; - if (k_bias) + if (kbias) { auto kb = captured_attrs.at("op_1.bias").get_float32_data(); memcpy(in_proj_bias_ptr, (const void*)kb.data(), embed_dim * sizeof(float)); @@ -366,7 +275,7 @@ pnnx.Output output 1 0 out memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); } in_proj_bias_ptr += embed_dim; - if (v_bias) + if (vbias) { auto vb = captured_attrs.at("op_2.bias").get_float32_data(); memcpy(in_proj_bias_ptr, (const void*)vb.data(), embed_dim * sizeof(float)); @@ -378,7 +287,7 @@ pnnx.Output output 1 0 out } op->attrs["in_proj_bias"].set_float32_data(in_proj_bias); - if (out_bias) + if (outbias) { op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias"); } @@ -394,139 +303,95 @@ pnnx.Output output 1 0 out } }; -class fuse_multiheadattention_pass_qkv : public GraphRewriterPass +class fuse_multiheadattention_pass_qkv : public fuse_multiheadattention_pass_sameqkv { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 25 24 -pnnx.Input input_q 0 1 q -pnnx.Input input_k 0 1 k -pnnx.Input input_v 0 1 v -nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 k 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 v 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -pnnx.Expression op_3 1 1 33 35 expr=%expr -Tensor.reshape op_4 1 1 32 36 shape=%q_shape -Tensor.reshape op_5 1 1 35 37 shape=%kv_shape -Tensor.reshape op_6 1 1 34 38 shape=%kv_shape +pnnx.Input input_q 0 1 query +pnnx.Input input_k 0 1 key +pnnx.Input input_v 0 1 value +nn.Linear op_0 1 1 query 32 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 key 33 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 value 34 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight +pnnx.Expression op_3 1 1 33 35 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +Tensor.reshape op_4 1 1 32 36 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 35 37 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_6 1 1 34 38 shape=(%batch,%kvsize,%num_heads,%feat_per_head) torch.permute op_7 1 1 37 39 dims=(0,2,1,3) -Tensor.reshape op_8 1 1 39 40 shape=%kv_shape2 +Tensor.reshape op_8 1 1 39 40 shape=(%num_heads,%kvsize,%feat_per_head) torch.permute op_9 1 1 36 41 dims=(0,2,1,3) -Tensor.reshape op_10 1 1 41 42 shape=%q_shape2 +Tensor.reshape op_10 1 1 41 42 shape=(%num_heads,%qsize,%feat_per_head) torch.permute op_11 1 1 40 43 dims=(0,2,1) torch.matmul op_12 2 1 42 43 44 F.softmax op_13 1 1 44 45 dim=-1 torch.permute op_14 1 1 38 46 dims=(0,2,1,3) -Tensor.reshape op_15 1 1 46 47 shape=%kv_shape2 +Tensor.reshape op_15 1 1 46 47 shape=(%num_heads,%kvsize,%feat_per_head) torch.matmul op_16 2 1 45 47 48 -Tensor.reshape op_17 1 1 48 49 shape=%qkv_shape +Tensor.reshape op_17 1 1 48 49 shape=(%batch,%num_heads,%qsize,%feat_per_head) torch.permute op_18 1 1 49 50 dims=(0,2,1,3) -Tensor.reshape op_19 1 1 50 51 shape=%qkv_shape2 -nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_19 1 1 50 51 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 51 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.MultiheadAttention"; - } - - const char* name_str() const - { - return "attention"; + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input input_2 0 1 value +nn.MultiheadAttention attention 3 1 query key value out embed_dim=%embed_dim kdim=%kdim vdim=%vdim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False +pnnx.Output output 1 0 out +)PNNXIR"; } - bool match(const std::map& captured_params) const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - // q_shape = (1,q,8,40) - // kv_shape = (1,kv,8,40) - // q_shape2 = (8,q,40) - // kv_shape2 = (8,kv,40) - // qkv_shape = (1,8,q,40) - // qkv_shape2 = (1,q,320) - const std::vector& q_shape = captured_params.at("q_shape").ai; - const std::vector& kv_shape = captured_params.at("kv_shape").ai; - const std::vector& q_shape2 = captured_params.at("q_shape2").ai; - const std::vector& kv_shape2 = captured_params.at("kv_shape2").ai; - const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; - const std::vector& qkv_shape2 = captured_params.at("qkv_shape2").ai; - if (q_shape.size() != 4 || kv_shape.size() != 4 || q_shape2.size() != 3 || kv_shape2.size() != 3 || qkv_shape.size() != 4 || qkv_shape2.size() != 3) - return false; + GraphRewriterPass::write(ops, captured_params, captured_attrs); - const int batch_size = q_shape[0]; - const int q_size = q_shape[1]; - const int num_heads = q_shape[2]; - const int feat_per_head = q_shape[3]; - const int kv_size = kv_shape[1]; - if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size || qkv_shape2[0] != batch_size) - return false; - - if (q_shape2[1] != q_size || kv_shape2[1] != kv_size || qkv_shape[2] != q_size || qkv_shape2[1] != q_size) - return false; - - if (kv_shape[2] != num_heads || q_shape2[0] != num_heads || kv_shape2[0] != num_heads || qkv_shape[1] != num_heads) - return false; - - if (kv_shape[3] != feat_per_head || q_shape2[2] != feat_per_head || kv_shape2[2] != feat_per_head || qkv_shape[3] != feat_per_head || qkv_shape2[2] != feat_per_head * num_heads) - return false; - - // mul(@0,1.581139e-01) - const std::string& expr = captured_params.at("expr").s; - float inv_sqrt_embed_dim_per_head = 0.f; - int nscan = sscanf(expr.c_str(), "mul(@0,%f)", &inv_sqrt_embed_dim_per_head); - if (nscan != 1) - return false; + Operator* op = ops.at("attention"); const int embed_dim = captured_params.at("embed_dim").i; - if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) - return false; + const int kdim = captured_params.at("kdim").i; + const int vdim = captured_params.at("vdim").i; + const bool qbias = captured_params.at("qbias").b; + const bool kbias = captured_params.at("kbias").b; + const bool vbias = captured_params.at("vbias").b; + const bool outbias = captured_params.at("outbias").b; + const bool bias = qbias || kbias || vbias || outbias; + const bool same_qkv_dim = (embed_dim == kdim && embed_dim == vdim); - return true; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - int embed_dim = captured_params.at("embed_dim").i; - int kdim = captured_params.at("kdim").i; - int vdim = captured_params.at("vdim").i; - - // (1,*,8,40) - int num_heads = captured_params.at("q_shape").ai[2]; - - bool q_bias = captured_params.at("q_bias").b; - bool k_bias = captured_params.at("k_bias").b; - bool v_bias = captured_params.at("v_bias").b; - bool out_bias = captured_params.at("out_bias").b; - bool bias = q_bias || k_bias || v_bias || out_bias; - - op->params["embed_dim"] = embed_dim; - op->params["kdim"] = kdim; - op->params["vdim"] = vdim; - - op->params["num_heads"] = num_heads; - op->params["batch_first"] = true; - op->params["add_zero_attn"] = false; - op->params["add_bias_kv"] = false; op->params["bias"] = bias; - op->attrs["q_proj_weight"] = captured_attrs.at("op_0.weight"); - op->attrs["k_proj_weight"] = captured_attrs.at("op_1.weight"); - op->attrs["v_proj_weight"] = captured_attrs.at("op_2.weight"); + if (same_qkv_dim) + { + // same qkv dim, merge into in_proj_weight + op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight") + captured_attrs.at("op_1.weight") + captured_attrs.at("op_2.weight"); + } + else + { + op->attrs["q_proj_weight"] = captured_attrs.at("op_0.weight"); + op->attrs["k_proj_weight"] = captured_attrs.at("op_1.weight"); + op->attrs["v_proj_weight"] = captured_attrs.at("op_2.weight"); + } + op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight"); if (bias) { op->attrs["in_proj_bias"] = Attribute(); - op->attrs["in_proj_bias"].type = op->attrs["q_proj_weight"].type; + op->attrs["in_proj_bias"].type = same_qkv_dim ? op->attrs["in_proj_weight"].type : op->attrs["q_proj_weight"].type; op->attrs["in_proj_bias"].shape = {embed_dim * 3}; // combine qkv bias std::vector in_proj_bias(embed_dim * 3); { float* in_proj_bias_ptr = (float*)in_proj_bias.data(); - if (q_bias) + if (qbias) { auto qb = captured_attrs.at("op_0.bias").get_float32_data(); memcpy(in_proj_bias_ptr, (const void*)qb.data(), embed_dim * sizeof(float)); @@ -536,7 +401,7 @@ pnnx.Output output 1 0 out memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); } in_proj_bias_ptr += embed_dim; - if (k_bias) + if (kbias) { auto kb = captured_attrs.at("op_1.bias").get_float32_data(); memcpy(in_proj_bias_ptr, (const void*)kb.data(), embed_dim * sizeof(float)); @@ -546,7 +411,7 @@ pnnx.Output output 1 0 out memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float)); } in_proj_bias_ptr += embed_dim; - if (v_bias) + if (vbias) { auto vb = captured_attrs.at("op_2.bias").get_float32_data(); memcpy(in_proj_bias_ptr, (const void*)vb.data(), embed_dim * sizeof(float)); @@ -558,7 +423,7 @@ pnnx.Output output 1 0 out } op->attrs["in_proj_bias"].set_float32_data(in_proj_bias); - if (out_bias) + if (outbias) { op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias"); } @@ -581,29 +446,40 @@ public: { return R"PNNXIR(7767517 24 23 -pnnx.Input input_q 0 1 q +pnnx.Input input_q 0 1 query pnnx.Input input_kv 0 1 kv -nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 kv 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 kv 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -pnnx.Expression op_3 1 1 33 35 expr=%expr -Tensor.reshape op_4 1 1 32 36 shape=%q_shape -Tensor.reshape op_5 1 1 35 37 shape=%kv_shape -Tensor.reshape op_6 1 1 34 38 shape=%kv_shape +nn.Linear op_0 1 1 query 32 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 kv 33 bias=%kbias in_features=%kvdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 kv 34 bias=%vbias in_features=%kvdim out_features=%embed_dim @bias @weight +pnnx.Expression op_3 1 1 33 35 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +Tensor.reshape op_4 1 1 32 36 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 35 37 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_6 1 1 34 38 shape=(%batch,%kvsize,%num_heads,%feat_per_head) torch.permute op_7 1 1 37 39 dims=(0,2,1,3) -Tensor.reshape op_8 1 1 39 40 shape=%kv_shape2 +Tensor.reshape op_8 1 1 39 40 shape=(%num_heads,%kvsize,%feat_per_head) torch.permute op_9 1 1 36 41 dims=(0,2,1,3) -Tensor.reshape op_10 1 1 41 42 shape=%q_shape2 +Tensor.reshape op_10 1 1 41 42 shape=(%num_heads,%qsize,%feat_per_head) torch.permute op_11 1 1 40 43 dims=(0,2,1) torch.matmul op_12 2 1 42 43 44 F.softmax op_13 1 1 44 45 dim=-1 torch.permute op_14 1 1 38 46 dims=(0,2,1,3) -Tensor.reshape op_15 1 1 46 47 shape=%kv_shape2 +Tensor.reshape op_15 1 1 46 47 shape=(%num_heads,%kvsize,%feat_per_head) torch.matmul op_16 2 1 45 47 48 -Tensor.reshape op_17 1 1 48 49 shape=%qkv_shape +Tensor.reshape op_17 1 1 48 49 shape=(%batch,%num_heads,%qsize,%feat_per_head) torch.permute op_18 1 1 49 50 dims=(0,2,1,3) -Tensor.reshape op_19 1 1 50 51 shape=%qkv_shape2 -nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_19 1 1 50 51 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 51 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 kv +nn.MultiheadAttention attention 2 1 query kv out embed_dim=%embed_dim kdim=%kvdim vdim=%kvdim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False pnnx.Output output 1 0 out )PNNXIR"; } @@ -617,29 +493,123 @@ public: return R"PNNXIR(7767517 22 21 pnnx.Input input 0 1 input -nn.Linear op_0 1 1 input 31 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 input 32 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 input 33 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.reshape op_3 1 1 31 35 shape=%q_shape -Tensor.reshape op_4 1 1 32 36 shape=%kv_shape -Tensor.reshape op_5 1 1 33 37 shape=%kv_shape +nn.Linear op_0 1 1 input 31 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 32 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 33 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 31 35 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 32 36 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 33 37 shape=(%batch,%size,%num_heads,%feat_per_head) torch.permute op_6 1 1 36 38 dims=(0,2,1,3) -Tensor.reshape op_7 1 1 38 39 shape=%kv_shape2 +Tensor.reshape op_7 1 1 38 39 shape=(%num_heads,%size,%feat_per_head) torch.permute op_8 1 1 35 40 dims=(0,2,1,3) -Tensor.reshape op_9 1 1 40 41 shape=%q_shape2 +Tensor.reshape op_9 1 1 40 41 shape=(%num_heads,%size,%feat_per_head) torch.einsum op_10 2 1 41 39 42 equation=ijl,ikl->ijk -pnnx.Expression op_11 1 1 42 43 expr=%expr +pnnx.Expression op_11 1 1 42 43 expr=mul(@0,%inv_sqrt_embed_dim_per_head) F.softmax op_12 1 1 43 44 dim=-1 torch.permute op_13 1 1 37 45 dims=(0,2,1,3) -Tensor.reshape op_14 1 1 45 46 shape=%kv_shape2 +Tensor.reshape op_14 1 1 45 46 shape=(%num_heads,%size,%feat_per_head) torch.einsum op_15 2 1 44 46 47 equation=ijl,ilk->ijk -Tensor.reshape op_16 1 1 47 48 shape=%qkv_shape +Tensor.reshape op_16 1 1 47 48 shape=(%batch,%num_heads,%size,%feat_per_head) torch.permute op_17 1 1 48 49 dims=(0,2,1,3) -Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape2 -nn.Linear out_proj 1 1 50 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_18 1 1 49 50 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 50 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_multiheadattention_pass_1_1 : public fuse_multiheadattention_pass_sameqkv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +19 18 +pnnx.Input input 0 1 input +nn.Linear op_0 1 1 input 47 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 48 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 49 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 47 50 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 48 51 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 49 52 shape=(%batch,%size,%num_heads,%feat_per_head) +torch.transpose op_6 1 1 51 53 dim0=1 dim1=2 +torch.permute op_7 1 1 53 54 dims=(0,1,3,2) +torch.transpose op_8 1 1 50 55 dim0=1 dim1=2 +torch.matmul op_9 2 1 55 54 56 +pnnx.Expression op_10 1 1 56 57 expr=div(@0,%sqrt_feat_per_head) +F.softmax op_11 1 1 57 58 dim=-1 +torch.transpose op_12 1 1 52 59 dim0=1 dim1=2 +torch.matmul op_13 2 1 58 59 60 +torch.transpose op_14 1 1 60 61 dim0=1 dim1=2 +Tensor.reshape op_15 1 1 61 62 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 62 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } + + bool match(const std::map& captured_params) const + { + const int embed_dim = captured_params.at("embed_dim").i; + const int num_heads = captured_params.at("num_heads").i; + const int feat_per_head = captured_params.at("feat_per_head").i; + const float sqrt_feat_per_head = captured_params.at("sqrt_feat_per_head").f; + + if (embed_dim != num_heads * feat_per_head) + return false; + + if (!NearlyEqual(sqrt_feat_per_head, sqrt(feat_per_head), 0.001)) + return false; + + return true; + } +}; + +class fuse_multiheadattention_pass_1_2 : public fuse_multiheadattention_pass_qkv +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +21 20 +pnnx.Input input_0 0 1 query #query=(%batch,%qsize,%embed_dim)f32 +pnnx.Input input_1 0 1 key #key=(%batch,%kvsize,%kdim)f32 +pnnx.Input input_2 0 1 value #value=(%batch,%kvsize,%kdim)f32 +nn.Linear op_0 1 1 query 47 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 key 48 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 value 49 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 47 50 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 48 51 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 49 52 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +torch.transpose op_6 1 1 51 53 dim0=1 dim1=2 +torch.permute op_7 1 1 53 54 dims=(0,1,3,2) +torch.transpose op_8 1 1 50 55 dim0=1 dim1=2 +torch.matmul op_9 2 1 55 54 56 +pnnx.Expression op_10 1 1 56 57 expr=div(@0,%sqrt_feat_per_head) +F.softmax op_11 1 1 57 58 dim=-1 +torch.transpose op_12 1 1 52 59 dim0=1 dim1=2 +torch.matmul op_13 2 1 58 59 60 +torch.transpose op_14 1 1 60 61 dim0=1 dim1=2 +Tensor.reshape op_15 1 1 61 62 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 62 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } + + bool match(const std::map& captured_params) const + { + const int embed_dim = captured_params.at("embed_dim").i; + const int num_heads = captured_params.at("num_heads").i; + const int feat_per_head = captured_params.at("feat_per_head").i; + const float sqrt_feat_per_head = captured_params.at("sqrt_feat_per_head").f; + + if (embed_dim != num_heads * feat_per_head) + return false; + + if (!NearlyEqual(sqrt_feat_per_head, sqrt(feat_per_head), 0.001)) + return false; + + return true; + } }; class fuse_multiheadattention_pass_2 : public fuse_multiheadattention_pass_qkv @@ -649,29 +619,29 @@ public: { return R"PNNXIR(7767517 24 23 -pnnx.Input input_q 0 1 q -pnnx.Input input_k 0 1 k -pnnx.Input input_v 0 1 v -nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 k 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 v 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.reshape op_3 1 1 32 36 shape=%q_shape -Tensor.reshape op_4 1 1 33 37 shape=%kv_shape -Tensor.reshape op_5 1 1 34 38 shape=%kv_shape +pnnx.Input input_q 0 1 query +pnnx.Input input_k 0 1 key +pnnx.Input input_v 0 1 value +nn.Linear op_0 1 1 query 32 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 key 33 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 value 34 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 32 36 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 33 37 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 34 38 shape=(%batch,%kvsize,%num_heads,%feat_per_head) torch.permute op_6 1 1 37 39 dims=(0,2,1,3) -Tensor.reshape op_7 1 1 39 40 shape=%kv_shape2 +Tensor.reshape op_7 1 1 39 40 shape=(%num_heads,%kvsize,%feat_per_head) torch.permute op_8 1 1 36 41 dims=(0,2,1,3) -Tensor.reshape op_9 1 1 41 42 shape=%q_shape2 +Tensor.reshape op_9 1 1 41 42 shape=(%num_heads,%qsize,%feat_per_head) torch.einsum op_10 2 1 42 40 43 equation=ijl,ikl->ijk -pnnx.Expression op_11 1 1 43 44 expr=%expr +pnnx.Expression op_11 1 1 43 44 expr=mul(@0,%inv_sqrt_embed_dim_per_head) F.softmax op_12 1 1 44 45 dim=-1 torch.permute op_13 1 1 38 46 dims=(0,2,1,3) -Tensor.reshape op_14 1 1 46 47 shape=%kv_shape2 +Tensor.reshape op_14 1 1 46 47 shape=(%num_heads,%kvsize,%feat_per_head) torch.einsum op_15 2 1 45 47 48 equation=ijl,ilk->ijk -Tensor.reshape op_16 1 1 48 49 shape=%qkv_shape +Tensor.reshape op_16 1 1 48 49 shape=(%batch,%num_heads,%qsize,%feat_per_head) torch.permute op_17 1 1 49 50 dims=(0,2,1,3) -Tensor.reshape op_18 1 1 50 51 shape=%qkv_shape2 -nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_18 1 1 50 51 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 51 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } @@ -684,28 +654,28 @@ public: { return R"PNNXIR(7767517 23 22 -pnnx.Input input_q 0 1 q +pnnx.Input input_q 0 1 query pnnx.Input input_kv 0 1 kv -nn.Linear op_0 1 1 q 32 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 kv 33 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 kv 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.reshape op_3 1 1 32 36 shape=%q_shape -Tensor.reshape op_4 1 1 33 37 shape=%kv_shape -Tensor.reshape op_5 1 1 34 38 shape=%kv_shape +nn.Linear op_0 1 1 query 32 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 kv 33 bias=%kbias in_features=%kvdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 kv 34 bias=%vbias in_features=%kvdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 32 36 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 33 37 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 34 38 shape=(%batch,%kvsize,%num_heads,%feat_per_head) torch.permute op_6 1 1 37 39 dims=(0,2,1,3) -Tensor.reshape op_7 1 1 39 40 shape=%kv_shape2 +Tensor.reshape op_7 1 1 39 40 shape=(%num_heads,%kvsize,%feat_per_head) torch.permute op_8 1 1 36 41 dims=(0,2,1,3) -Tensor.reshape op_9 1 1 41 42 shape=%q_shape2 +Tensor.reshape op_9 1 1 41 42 shape=(%num_heads,%qsize,%feat_per_head) torch.einsum op_10 2 1 42 40 43 equation=ijl,ikl->ijk -pnnx.Expression op_11 1 1 43 44 expr=%expr +pnnx.Expression op_11 1 1 43 44 expr=mul(@0,%inv_sqrt_embed_dim_per_head) F.softmax op_12 1 1 44 45 dim=-1 torch.permute op_13 1 1 38 46 dims=(0,2,1,3) -Tensor.reshape op_14 1 1 46 47 shape=%kv_shape2 +Tensor.reshape op_14 1 1 46 47 shape=(%num_heads,%kvsize,%feat_per_head) torch.einsum op_15 2 1 45 47 48 equation=ijl,ilk->ijk -Tensor.reshape op_16 1 1 48 49 shape=%qkv_shape +Tensor.reshape op_16 1 1 48 49 shape=(%batch,%num_heads,%qsize,%feat_per_head) torch.permute op_17 1 1 49 50 dims=(0,2,1,3) -Tensor.reshape op_18 1 1 50 51 shape=%qkv_shape2 -nn.Linear out_proj 1 1 51 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_18 1 1 50 51 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 51 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } @@ -719,75 +689,35 @@ public: return R"PNNXIR(7767517 23 22 pnnx.Input input 0 1 input -nn.Linear op_0 1 1 input 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 input 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 input 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.reshape op_3 1 1 33 36 shape=%q_shape -Tensor.reshape op_4 1 1 34 37 shape=%kv_shape -Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +nn.Linear op_0 1 1 input 33 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 34 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 35 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 34 37 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 35 38 shape=(%batch,%size,%num_heads,%feat_per_head) torch.permute op_6 1 1 36 39 dims=(0,2,1,3) -Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +Tensor.reshape op_7 1 1 39 40 shape=(%num_heads,%size,%feat_per_head) torch.permute op_8 1 1 37 41 dims=(0,2,1,3) -Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 -pnnx.Attribute op_10 0 1 43 @zeros +Tensor.reshape op_9 1 1 41 42 shape=(%num_heads,%size,%feat_per_head) +pnnx.Attribute op_10 0 1 43 @data torch.transpose op_11 1 1 42 44 dim0=-1 dim1=-2 -torch.baddbmm op_12 3 1 43 40 44 45 alpha=%alpha beta=0 +torch.baddbmm op_12 3 1 43 40 44 45 alpha=%inv_sqrt_embed_dim_per_head beta=0 F.softmax op_13 1 1 45 46 dim=-1 torch.permute op_14 1 1 38 47 dims=(0,2,1,3) -Tensor.reshape op_15 1 1 47 48 shape=%kv_shape2 +Tensor.reshape op_15 1 1 47 48 shape=(%num_heads,%size,%feat_per_head) torch.bmm op_16 2 1 46 48 49 -Tensor.reshape op_17 1 1 49 50 shape=%qkv_shape +Tensor.reshape op_17 1 1 49 50 shape=(%batch,%num_heads,%size,%feat_per_head) torch.permute op_18 1 1 50 51 dims=(0,2,1,3) -Tensor.reshape op_19 1 1 51 52 shape=%qkv_shape2 -nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_19 1 1 51 52 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 52 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } - bool match(const std::map& captured_params) const - { - // q_shape = (1,q,8,40) - // kv_shape = (1,kv,8,40) - // q_shape2 = (8,q,40) - // kv_shape2 = (8,kv,40) - // qkv_shape = (1,8,q,40) - // qkv_shape2 = (1,q,320) - const std::vector& q_shape = captured_params.at("q_shape").ai; - const std::vector& kv_shape = captured_params.at("kv_shape").ai; - const std::vector& q_shape2 = captured_params.at("q_shape2").ai; - const std::vector& kv_shape2 = captured_params.at("kv_shape2").ai; - const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; - const std::vector& qkv_shape2 = captured_params.at("qkv_shape2").ai; - if (q_shape.size() != 4 || kv_shape.size() != 4 || q_shape2.size() != 3 || kv_shape2.size() != 3 || qkv_shape.size() != 4 || qkv_shape2.size() != 3) - return false; - - const int batch_size = q_shape[0]; - const int q_size = q_shape[1]; - const int num_heads = q_shape[2]; - const int feat_per_head = q_shape[3]; - const int kv_size = kv_shape[1]; - if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size || qkv_shape2[0] != batch_size) - return false; - - if (q_shape2[1] != q_size || kv_shape2[1] != kv_size || qkv_shape[2] != q_size || qkv_shape2[1] != q_size) - return false; - - if (kv_shape[2] != num_heads || q_shape2[0] != num_heads || kv_shape2[0] != num_heads || qkv_shape[1] != num_heads) - return false; - - if (kv_shape[3] != feat_per_head || q_shape2[2] != feat_per_head || kv_shape2[2] != feat_per_head || qkv_shape[3] != feat_per_head || qkv_shape2[2] != feat_per_head * num_heads) - return false; - - const float inv_sqrt_embed_dim_per_head = captured_params.at("alpha").f; - const int embed_dim = captured_params.at("embed_dim").i; - if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) - return false; - - return true; - } + // TODO match data zero }; -class fuse_multiheadattention_pass_6 : public fuse_multiheadattention_pass_5 +class fuse_multiheadattention_pass_6 : public fuse_multiheadattention_pass_sameqkv { public: const char* match_pattern_graph() const @@ -795,31 +725,33 @@ public: return R"PNNXIR(7767517 24 23 pnnx.Input input 0 1 input -nn.Linear op_0 1 1 input 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 input 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 input 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.reshape op_3 1 1 33 36 shape=%q_shape -Tensor.reshape op_4 1 1 34 37 shape=%kv_shape -Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +nn.Linear op_0 1 1 input 33 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 34 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 35 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 34 37 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 35 38 shape=(%batch,%size,%num_heads,%feat_per_head) torch.permute op_6 1 1 36 39 dims=(0,2,1,3) -Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +Tensor.reshape op_7 1 1 39 40 shape=(%num_heads,%size,%feat_per_head) torch.permute op_8 1 1 37 41 dims=(0,2,1,3) -Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 +Tensor.reshape op_9 1 1 41 42 shape=(%num_heads,%size,%feat_per_head) pnnx.Expression op_10 2 1 40 42 43 expr=%expr_zero_shape torch.empty op_11 1 1 43 zeros torch.transpose op_12 1 1 42 44 dim0=-1 dim1=-2 -torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%alpha beta=0 +torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%inv_sqrt_embed_dim_per_head beta=0 F.softmax op_14 1 1 45 46 dim=-1 torch.permute op_15 1 1 38 47 dims=(0,2,1,3) -Tensor.reshape op_16 1 1 47 48 shape=%kv_shape2 +Tensor.reshape op_16 1 1 47 48 shape=(%num_heads,%size,%feat_per_head) torch.bmm op_17 2 1 46 48 49 -Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape +Tensor.reshape op_18 1 1 49 50 shape=(%batch,%num_heads,%size,%feat_per_head) torch.permute op_19 1 1 50 51 dims=(0,2,1,3) -Tensor.reshape op_20 1 1 51 52 shape=%qkv_shape2 -nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_20 1 1 51 52 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 52 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } + + // TODO match expr_zero_shape }; class fuse_multiheadattention_pass_7 : public fuse_multiheadattention_pass_qkv @@ -829,180 +761,144 @@ public: { return R"PNNXIR(7767517 25 24 -pnnx.Input input_q 0 1 q -pnnx.Input input_k 0 1 k -pnnx.Input input_v 0 1 v -nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 k 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 v 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.reshape op_3 1 1 33 36 shape=%q_shape -Tensor.reshape op_4 1 1 34 37 shape=%kv_shape -Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +pnnx.Input input_q 0 1 query +pnnx.Input input_k 0 1 key +pnnx.Input input_v 0 1 value +nn.Linear op_0 1 1 query 33 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 key 34 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 value 35 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 34 37 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 35 38 shape=(%batch,%kvsize,%num_heads,%feat_per_head) torch.permute op_6 1 1 36 39 dims=(0,2,1,3) -Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +Tensor.reshape op_7 1 1 39 40 shape=(%num_heads,%qsize,%feat_per_head) torch.permute op_8 1 1 37 41 dims=(0,2,1,3) -Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 -pnnx.Attribute op_10 0 1 43 @zeros +Tensor.reshape op_9 1 1 41 42 shape=(%num_heads,%kvsize,%feat_per_head) +pnnx.Attribute op_10 0 1 43 @data torch.transpose op_11 1 1 42 44 dim0=-1 dim1=-2 -torch.baddbmm op_12 3 1 43 40 44 45 alpha=%alpha beta=0 +torch.baddbmm op_12 3 1 43 40 44 45 alpha=%inv_sqrt_embed_dim_per_head beta=0 F.softmax op_13 1 1 45 46 dim=-1 torch.permute op_14 1 1 38 47 dims=(0,2,1,3) -Tensor.reshape op_15 1 1 47 48 shape=%kv_shape2 +Tensor.reshape op_15 1 1 47 48 shape=(%num_heads,%kvsize,%feat_per_head) torch.bmm op_16 2 1 46 48 49 -Tensor.reshape op_17 1 1 49 50 shape=%qkv_shape +Tensor.reshape op_17 1 1 49 50 shape=(%batch,%num_heads,%qsize,%feat_per_head) torch.permute op_18 1 1 50 51 dims=(0,2,1,3) -Tensor.reshape op_19 1 1 51 52 shape=%qkv_shape2 -nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_19 1 1 51 52 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 52 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } - bool match(const std::map& captured_params) const - { - // q_shape = (1,q,8,40) - // kv_shape = (1,kv,8,40) - // q_shape2 = (8,q,40) - // kv_shape2 = (8,kv,40) - // qkv_shape = (1,8,q,40) - // qkv_shape2 = (1,q,320) - const std::vector& q_shape = captured_params.at("q_shape").ai; - const std::vector& kv_shape = captured_params.at("kv_shape").ai; - const std::vector& q_shape2 = captured_params.at("q_shape2").ai; - const std::vector& kv_shape2 = captured_params.at("kv_shape2").ai; - const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; - const std::vector& qkv_shape2 = captured_params.at("qkv_shape2").ai; - if (q_shape.size() != 4 || kv_shape.size() != 4 || q_shape2.size() != 3 || kv_shape2.size() != 3 || qkv_shape.size() != 4 || qkv_shape2.size() != 3) - return false; - - const int batch_size = q_shape[0]; - const int q_size = q_shape[1]; - const int num_heads = q_shape[2]; - const int feat_per_head = q_shape[3]; - const int kv_size = kv_shape[1]; - if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size || qkv_shape2[0] != batch_size) - return false; - - if (q_shape2[1] != q_size || kv_shape2[1] != kv_size || qkv_shape[2] != q_size || qkv_shape2[1] != q_size) - return false; - - if (kv_shape[2] != num_heads || q_shape2[0] != num_heads || kv_shape2[0] != num_heads || qkv_shape[1] != num_heads) - return false; - - if (kv_shape[3] != feat_per_head || q_shape2[2] != feat_per_head || kv_shape2[2] != feat_per_head || qkv_shape[3] != feat_per_head || qkv_shape2[2] != feat_per_head * num_heads) - return false; - - const float inv_sqrt_embed_dim_per_head = captured_params.at("alpha").f; - const int embed_dim = captured_params.at("embed_dim").i; - if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(embed_dim / num_heads), 0.001)) - return false; - - return true; - } + // TODO match data zero }; -class fuse_multiheadattention_pass_8 : public fuse_multiheadattention_pass_7 +class fuse_multiheadattention_pass_8 : public fuse_multiheadattention_pass_q_samekv { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 24 23 -pnnx.Input input_q 0 1 q +pnnx.Input input_q 0 1 query pnnx.Input input_kv 0 1 kv -nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 kv 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 kv 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.reshape op_3 1 1 33 36 shape=%q_shape -Tensor.reshape op_4 1 1 34 37 shape=%kv_shape -Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +nn.Linear op_0 1 1 query 33 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 kv 34 bias=%kbias in_features=%kvdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 kv 35 bias=%vbias in_features=%kvdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 34 37 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 35 38 shape=(%batch,%kvsize,%num_heads,%feat_per_head) torch.permute op_6 1 1 36 39 dims=(0,2,1,3) -Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +Tensor.reshape op_7 1 1 39 40 shape=(%num_heads,%qsize,%feat_per_head) torch.permute op_8 1 1 37 41 dims=(0,2,1,3) -Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 -pnnx.Attribute op_10 0 1 43 @zeros +Tensor.reshape op_9 1 1 41 42 shape=(%num_heads,%kvsize,%feat_per_head) +pnnx.Attribute op_10 0 1 43 @data torch.transpose op_11 1 1 42 44 dim0=-1 dim1=-2 -torch.baddbmm op_12 3 1 43 40 44 45 alpha=%alpha beta=0 +torch.baddbmm op_12 3 1 43 40 44 45 alpha=%inv_sqrt_embed_dim_per_head beta=0 F.softmax op_13 1 1 45 46 dim=-1 torch.permute op_14 1 1 38 47 dims=(0,2,1,3) -Tensor.reshape op_15 1 1 47 48 shape=%kv_shape2 +Tensor.reshape op_15 1 1 47 48 shape=(%num_heads,%kvsize,%feat_per_head) torch.bmm op_16 2 1 46 48 49 -Tensor.reshape op_17 1 1 49 50 shape=%qkv_shape +Tensor.reshape op_17 1 1 49 50 shape=(%batch,%num_heads,%qsize,%feat_per_head) torch.permute op_18 1 1 50 51 dims=(0,2,1,3) -Tensor.reshape op_19 1 1 51 52 shape=%qkv_shape2 -nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_19 1 1 51 52 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 52 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } + + // TODO match data zero }; -class fuse_multiheadattention_pass_9 : public fuse_multiheadattention_pass_7 +class fuse_multiheadattention_pass_9 : public fuse_multiheadattention_pass_qkv { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 26 25 -pnnx.Input input_q 0 1 q -pnnx.Input input_k 0 1 k -pnnx.Input input_v 0 1 v -nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 k 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 v 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.reshape op_3 1 1 33 36 shape=%q_shape -Tensor.reshape op_4 1 1 34 37 shape=%kv_shape -Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +pnnx.Input input_q 0 1 query +pnnx.Input input_k 0 1 key +pnnx.Input input_v 0 1 value +nn.Linear op_0 1 1 query 33 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 key 34 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 value 35 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 34 37 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 35 38 shape=(%batch,%kvsize,%num_heads,%feat_per_head) torch.permute op_6 1 1 36 39 dims=(0,2,1,3) -Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +Tensor.reshape op_7 1 1 39 40 shape=(%num_heads,%qsize,%feat_per_head) torch.permute op_8 1 1 37 41 dims=(0,2,1,3) -Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 +Tensor.reshape op_9 1 1 41 42 shape=(%num_heads,%kvsize,%feat_per_head) pnnx.Expression op_10 1 1 40 43 expr=%expr_zero_shape torch.empty op_11 1 1 43 zeros torch.transpose op_12 1 1 42 44 dim0=-1 dim1=-2 -torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%alpha beta=0 +torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%inv_sqrt_embed_dim_per_head beta=0 F.softmax op_14 1 1 45 46 dim=-1 torch.permute op_15 1 1 38 47 dims=(0,2,1,3) -Tensor.reshape op_16 1 1 47 48 shape=%kv_shape2 +Tensor.reshape op_16 1 1 47 48 shape=(%num_heads,%kvsize,%feat_per_head) torch.bmm op_17 2 1 46 48 49 -Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape +Tensor.reshape op_18 1 1 49 50 shape=(%batch,%num_heads,%qsize,%feat_per_head) torch.permute op_19 1 1 50 51 dims=(0,2,1,3) -Tensor.reshape op_20 1 1 51 52 shape=%qkv_shape2 -nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_20 1 1 51 52 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 52 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } + + // TODO match expr_zero_shape }; -class fuse_multiheadattention_pass_10 : public fuse_multiheadattention_pass_7 +class fuse_multiheadattention_pass_10 : public fuse_multiheadattention_pass_q_samekv { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 25 24 -pnnx.Input input_q 0 1 q +pnnx.Input input_q 0 1 query pnnx.Input input_kv 0 1 kv -nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 kv 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 kv 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.reshape op_3 1 1 33 36 shape=%q_shape -Tensor.reshape op_4 1 1 34 37 shape=%kv_shape -Tensor.reshape op_5 1 1 35 38 shape=%kv_shape +nn.Linear op_0 1 1 query 33 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 kv 34 bias=%kbias in_features=%kvdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 kv 35 bias=%vbias in_features=%kvdim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 33 36 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 34 37 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 35 38 shape=(%batch,%kvsize,%num_heads,%feat_per_head) torch.permute op_6 1 1 36 39 dims=(0,2,1,3) -Tensor.reshape op_7 1 1 39 40 shape=%q_shape2 +Tensor.reshape op_7 1 1 39 40 shape=(%num_heads,%qsize,%feat_per_head) torch.permute op_8 1 1 37 41 dims=(0,2,1,3) -Tensor.reshape op_9 1 1 41 42 shape=%kv_shape2 +Tensor.reshape op_9 1 1 41 42 shape=(%num_heads,%kvsize,%feat_per_head) pnnx.Expression op_10 1 1 40 43 expr=%expr_zero_shape torch.empty op_11 1 1 43 zeros torch.transpose op_12 1 1 42 44 dim0=-1 dim1=-2 torch.baddbmm op_13 3 1 zeros 40 44 45 alpha=%alpha beta=0 F.softmax op_14 1 1 45 46 dim=-1 torch.permute op_15 1 1 38 47 dims=(0,2,1,3) -Tensor.reshape op_16 1 1 47 48 shape=%kv_shape2 +Tensor.reshape op_16 1 1 47 48 shape=(%num_heads,%kvsize,%feat_per_head) torch.bmm op_17 2 1 46 48 49 -Tensor.reshape op_18 1 1 49 50 shape=%qkv_shape +Tensor.reshape op_18 1 1 49 50 shape=(%batch,%num_heads,%qsize,%feat_per_head) torch.permute op_19 1 1 50 51 dims=(0,2,1,3) -Tensor.reshape op_20 1 1 51 52 shape=%qkv_shape2 -nn.Linear out_proj 1 1 52 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_20 1 1 51 52 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 52 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } @@ -1016,44 +912,30 @@ public: return R"PNNXIR(7767517 15 14 pnnx.Input input_0 0 1 input -nn.Linear op_0 1 1 input 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 input 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 input 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.view op_3 1 1 33 36 shape=%q_shape -Tensor.view op_4 1 1 34 37 shape=%kv_shape -Tensor.view op_5 1 1 35 38 shape=%kv_shape +nn.Linear op_0 1 1 input 33 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 34 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 35 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.view op_3 1 1 33 36 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.view op_4 1 1 34 37 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.view op_5 1 1 35 38 shape=(%batch,%size,%num_heads,%feat_per_head) torch.transpose op_6 1 1 38 39 dim0=1 dim1=2 torch.transpose op_7 1 1 37 40 dim0=1 dim1=2 torch.transpose op_8 1 1 36 41 dim0=1 dim1=2 F.scaled_dot_product_attention op_9 3 1 41 40 39 42 attn_mask=None dropout_p=0.000000e+00 is_causal=False torch.transpose op_10 1 1 42 43 dim0=1 dim1=2 -Tensor.reshape op_11 1 1 43 44 shape=%qkv_shape -nn.Linear out_proj 1 1 44 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_11 1 1 43 44 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 44 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } bool match(const std::map& captured_params) const { - // q_shape = (2,-1,8,40) - // kv_shape = (2,-1,8,40) - // qkv_shape = (2,-1,320) - const std::vector& q_shape = captured_params.at("q_shape").ai; - const std::vector& kv_shape = captured_params.at("kv_shape").ai; - const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; - if (q_shape.size() != 4 || kv_shape.size() != 4 || qkv_shape.size() != 3) - return false; - - const int batch_size = q_shape[0]; - const int num_heads = q_shape[2]; - const int feat_per_head = q_shape[3]; - if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size) - return false; - - if (kv_shape[2] != num_heads) - return false; + const int embed_dim = captured_params.at("embed_dim").i; + const int num_heads = captured_params.at("num_heads").i; + const int feat_per_head = captured_params.at("feat_per_head").i; - if (kv_shape[3] != feat_per_head || qkv_shape[2] != feat_per_head * num_heads) + if (embed_dim != num_heads * feat_per_head) return false; return true; @@ -1068,19 +950,19 @@ public: return R"PNNXIR(7767517 15 14 pnnx.Input input_0 0 1 input -nn.Linear op_0 1 1 input 14 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 input 15 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 input 16 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.reshape op_3 1 1 14 17 shape=%q_shape -Tensor.reshape op_4 1 1 15 18 shape=%kv_shape -Tensor.reshape op_5 1 1 16 19 shape=%kv_shape +nn.Linear op_0 1 1 input 14 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 15 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 16 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_3 1 1 14 17 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_4 1 1 15 18 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.reshape op_5 1 1 16 19 shape=(%batch,%size,%num_heads,%feat_per_head) torch.permute op_6 1 1 19 20 dims=(0,2,1,3) torch.permute op_7 1 1 18 21 dims=(0,2,1,3) torch.permute op_8 1 1 17 22 dims=(0,2,1,3) F.scaled_dot_product_attention op_9 3 1 22 21 20 23 attn_mask=None dropout_p=0.000000e+00 is_causal=False torch.permute op_10 1 1 23 24 dims=(0,2,1,3) -Tensor.reshape op_11 1 1 24 25 shape=%qkv_shape -nn.Linear out_proj 1 1 25 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_11 1 1 24 25 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 25 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } @@ -1093,78 +975,76 @@ public: { return R"PNNXIR(7767517 17 16 -pnnx.Input input_0 0 1 q -pnnx.Input input_1 0 1 k -pnnx.Input input_2 0 1 v -nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 k 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 v 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.view op_3 1 1 33 36 shape=%q_shape -Tensor.view op_4 1 1 34 37 shape=%kv_shape -Tensor.view op_5 1 1 35 38 shape=%kv_shape +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input input_2 0 1 value +nn.Linear op_0 1 1 query 33 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 key 34 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 value 35 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight +Tensor.view op_3 1 1 33 36 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.view op_4 1 1 34 37 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.view op_5 1 1 35 38 shape=(%batch,%kvsize,%num_heads,%feat_per_head) torch.transpose op_6 1 1 38 39 dim0=1 dim1=2 torch.transpose op_7 1 1 37 40 dim0=1 dim1=2 torch.transpose op_8 1 1 36 41 dim0=1 dim1=2 F.scaled_dot_product_attention op_9 3 1 41 40 39 42 attn_mask=None dropout_p=0.000000e+00 is_causal=False torch.transpose op_10 1 1 42 43 dim0=1 dim1=2 -Tensor.reshape op_11 1 1 43 44 shape=%qkv_shape -nn.Linear out_proj 1 1 44 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_11 1 1 43 44 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 44 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } bool match(const std::map& captured_params) const { - // q_shape = (2,-1,8,40) - // kv_shape = (2,-1,8,40) - // qkv_shape = (2,-1,320) - const std::vector& q_shape = captured_params.at("q_shape").ai; - const std::vector& kv_shape = captured_params.at("kv_shape").ai; - const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; - if (q_shape.size() != 4 || kv_shape.size() != 4 || qkv_shape.size() != 3) - return false; - - const int batch_size = q_shape[0]; - const int num_heads = q_shape[2]; - const int feat_per_head = q_shape[3]; - if (kv_shape[0] != batch_size || qkv_shape[0] != batch_size) - return false; - - if (kv_shape[2] != num_heads) - return false; + const int embed_dim = captured_params.at("embed_dim").i; + const int num_heads = captured_params.at("num_heads").i; + const int feat_per_head = captured_params.at("feat_per_head").i; - if (kv_shape[3] != feat_per_head || qkv_shape[2] != feat_per_head * num_heads) + if (embed_dim != num_heads * feat_per_head) return false; return true; } }; -class fuse_multiheadattention_pass_14 : public fuse_multiheadattention_pass_12 +class fuse_multiheadattention_pass_14 : public fuse_multiheadattention_pass_q_samekv { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 16 15 -pnnx.Input input_0 0 1 q +pnnx.Input input_0 0 1 query pnnx.Input input_1 0 1 kv -nn.Linear op_0 1 1 q 33 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 kv 34 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 kv 35 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.view op_3 1 1 33 36 shape=%q_shape -Tensor.view op_4 1 1 34 37 shape=%kv_shape -Tensor.view op_5 1 1 35 38 shape=%kv_shape +nn.Linear op_0 1 1 query 33 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 kv 34 bias=%kbias in_features=%kvdim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 kv 35 bias=%vbias in_features=%kvdim out_features=%embed_dim @bias @weight +Tensor.view op_3 1 1 33 36 shape=(%batch,%qsize,%num_heads,%feat_per_head) +Tensor.view op_4 1 1 34 37 shape=(%batch,%kvsize,%num_heads,%feat_per_head) +Tensor.view op_5 1 1 35 38 shape=(%batch,%kvsize,%num_heads,%feat_per_head) torch.transpose op_6 1 1 38 39 dim0=1 dim1=2 torch.transpose op_7 1 1 37 40 dim0=1 dim1=2 torch.transpose op_8 1 1 36 41 dim0=1 dim1=2 F.scaled_dot_product_attention op_9 3 1 41 40 39 42 attn_mask=None dropout_p=0.000000e+00 is_causal=False torch.transpose op_10 1 1 42 43 dim0=1 dim1=2 -Tensor.reshape op_11 1 1 43 44 shape=%qkv_shape -nn.Linear out_proj 1 1 44 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_11 1 1 43 44 shape=(%batch,%qsize,%embed_dim) +nn.Linear out_proj 1 1 44 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } + + bool match(const std::map& captured_params) const + { + const int embed_dim = captured_params.at("embed_dim").i; + const int num_heads = captured_params.at("num_heads").i; + const int feat_per_head = captured_params.at("feat_per_head").i; + + if (embed_dim != num_heads * feat_per_head) + return false; + + return true; + } }; class fuse_multiheadattention_pass_15 : public fuse_multiheadattention_pass_sameqkv @@ -1175,27 +1055,27 @@ public: return R"PNNXIR(7767517 23 22 pnnx.Input input 0 1 input -nn.Linear op_0 1 1 input 2 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 input 4 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 input 6 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -pnnx.Expression op_3 1 1 2 3 expr=%expr -Tensor.view op_4 1 1 3 8 shape=%q_shape -Tensor.view op_5 1 1 4 5 shape=%kv_shape -Tensor.view op_6 1 1 6 7 shape=%kv_shape +nn.Linear op_0 1 1 input 2 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 4 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 6 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Expression op_3 1 1 2 3 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +Tensor.view op_4 1 1 3 8 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.view op_5 1 1 4 5 shape=(%batch,-1,%num_heads,%feat_per_head) +Tensor.view op_6 1 1 6 7 shape=(%batch,-1,%num_heads,%feat_per_head) torch.transpose op_7 1 1 8 9 dim0=1 dim1=2 torch.transpose op_8 1 1 5 10 dim0=1 dim1=2 torch.transpose op_9 1 1 7 11 dim0=1 dim1=2 -Tensor.reshape op_10 1 1 9 14 shape=%q_shape2 -Tensor.reshape op_11 1 1 10 12 shape=%kv_shape2 -Tensor.reshape op_12 1 1 11 17 shape=%kv_shape2 +Tensor.reshape op_10 1 1 9 14 shape=(%num_heads,-1,%feat_per_head) +Tensor.reshape op_11 1 1 10 12 shape=(%num_heads,-1,%feat_per_head) +Tensor.reshape op_12 1 1 11 17 shape=(%num_heads,-1,%feat_per_head) torch.transpose op_13 1 1 12 13 dim0=1 dim1=2 torch.bmm op_14 2 1 14 13 15 F.softmax op_15 1 1 15 16 dim=-1 torch.bmm op_16 2 1 16 17 18 -Tensor.view op_17 1 1 18 19 shape=%qkv_shape +Tensor.view op_17 1 1 18 19 shape=(%batch,%num_heads,%size,%feat_per_head) torch.transpose op_18 1 1 19 20 dim0=1 dim1=2 -Tensor.reshape op_19 1 1 20 21 shape=%qkv_shape2 -nn.Linear out_proj 1 1 21 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_19 1 1 20 21 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } @@ -1208,81 +1088,54 @@ public: { return R"PNNXIR(7767517 27 26 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 attn_mask -nn.Linear op_0 1 1 input 3 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 input 5 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 input 7 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -pnnx.Expression op_3 1 1 3 4 expr=%expr -Tensor.view op_4 1 1 4 9 shape=%q_shape -Tensor.view op_5 1 1 5 6 shape=%kv_shape -Tensor.view op_6 1 1 7 8 shape=%kv_shape +pnnx.Input input 0 1 input +nn.Linear op_0 1 1 input 3 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 5 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 7 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Expression op_3 1 1 3 4 expr=mul(@0,%inv_sqrt_embed_dim_per_head) +Tensor.view op_4 1 1 4 9 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.view op_5 1 1 5 6 shape=(%batch,-1,%num_heads,%feat_per_head) +Tensor.view op_6 1 1 7 8 shape=(%batch,-1,%num_heads,%feat_per_head) torch.transpose op_7 1 1 9 10 dim0=1 dim1=2 torch.transpose op_8 1 1 6 11 dim0=1 dim1=2 torch.transpose op_9 1 1 8 12 dim0=1 dim1=2 -Tensor.reshape op_10 1 1 10 15 shape=%q_shape2 -Tensor.reshape op_11 1 1 11 13 shape=%kv_shape2 -Tensor.reshape op_12 1 1 12 21 shape=%kv_shape2 +Tensor.reshape op_10 1 1 10 15 shape=(%num_heads,-1,%feat_per_head) +Tensor.reshape op_11 1 1 11 13 shape=(%num_heads,-1,%feat_per_head) +Tensor.reshape op_12 1 1 12 21 shape=(%num_heads,-1,%feat_per_head) torch.transpose op_13 1 1 13 14 dim0=1 dim1=2 torch.bmm op_14 2 1 15 14 16 -Tensor.view op_15 1 1 16 17 shape=%qk_shape -pnnx.Expression op_16 2 1 17 attn_mask 18 expr=%expr2 -Tensor.view op_17 1 1 18 19 shape=%qk_shape2 +Tensor.view op_15 1 1 16 17 shape=(%batch,%num_heads,%size,%size) +pnnx.Attribute attn_mask 0 1 attn_mask @data=(1,1,%size,%size)f32 +pnnx.Expression op_16 2 1 17 attn_mask 18 expr=add(@0,@1) +Tensor.view op_17 1 1 18 19 shape=(%num_heads,%size,%size) F.softmax op_18 1 1 19 20 dim=-1 torch.bmm op_19 2 1 20 21 22 -Tensor.view op_20 1 1 22 23 shape=%qkv_shape +Tensor.view op_20 1 1 22 23 shape=(%batch,%num_heads,%size,%feat_per_head) torch.transpose op_21 1 1 23 24 dim0=1 dim1=2 -Tensor.reshape op_22 1 1 24 25 shape=%qkv_shape2 -nn.Linear out_proj 1 1 25 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_22 1 1 24 25 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 25 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } - bool match(const std::map& captured_params) const + const char* replace_pattern_graph() const { - bool matched = fuse_multiheadattention_pass_sameqkv::match(captured_params); - if (!matched) - return false; - - if (captured_params.at("expr2").s != "add(@0,@1)") - return false; - - // (1,4,20,20) - // (4,20,20) - const std::vector& qk_shape = captured_params.at("qk_shape").ai; - const std::vector& qk_shape2 = captured_params.at("qk_shape2").ai; - if (qk_shape.size() != 4 || qk_shape2.size() != 3) - return false; - - if (qk_shape[0] != 1 || qk_shape[1] != qk_shape2[0] || qk_shape[2] != qk_shape2[1] || qk_shape[3] != qk_shape2[2]) - return false; - - return true; - } - - bool match(const std::map& matched_operators) const - { - const Operator* op_16 = matched_operators.at("op_16"); - - // support constant attention mask only atm - Operand* attn_mask = op_16->inputs[1]; - if (attn_mask->consumers.size() > 1 || attn_mask->producer->type != "pnnx.Attribute") - return false; - - return true; + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute attn_mask 0 1 attn_mask @data=%attn_mask.data +nn.MultiheadAttention attention 2 1 input attn_mask out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask +pnnx.Output output 1 0 out +)PNNXIR"; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - fuse_multiheadattention_pass_sameqkv::write(op, captured_params, captured_attrs); + fuse_multiheadattention_pass_sameqkv::write(ops, captured_params, captured_attrs); - Operand* attn_mask = op->inputs[1]; - Operator* op_attr = attn_mask->producer; + const int size = captured_params.at("size").i; - // hack attn_mask shape - attn_mask->shape = std::vector{attn_mask->shape[2], attn_mask->shape[3]}; - const std::string key = op_attr->attrs.begin()->first; - op_attr->attrs[key].shape = attn_mask->shape; + ops.at("attn_mask")->attrs["data"].shape = {size, size}; } }; @@ -1294,24 +1147,35 @@ public: return R"PNNXIR(7767517 20 19 pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 attn_mask -nn.Linear op_0 1 1 input 31 bias=%q_bias in_features=%embed_dim out_features=%embed_dim @bias @weight -nn.Linear op_1 1 1 input 32 bias=%k_bias in_features=%kdim out_features=%embed_dim @bias @weight -nn.Linear op_2 1 1 input 34 bias=%v_bias in_features=%vdim out_features=%embed_dim @bias @weight -Tensor.view op_3 1 1 31 36 shape=%q_shape -Tensor.view op_4 1 1 32 33 shape=%kv_shape -Tensor.view op_5 1 1 34 35 shape=%kv_shape +nn.Linear op_0 1 1 input 31 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 32 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 34 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.view op_3 1 1 31 36 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.view op_4 1 1 32 33 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.view op_5 1 1 34 35 shape=(%batch,%size,%num_heads,%feat_per_head) torch.permute op_6 1 1 36 38 dims=(0,2,1,3) torch.permute op_7 1 1 33 37 dims=(0,2,1,3) torch.permute op_8 1 1 35 43 dims=(0,2,1,3) torch.transpose op_9 1 1 37 39 dim0=-1 dim1=-2 torch.matmul op_10 2 1 38 39 40 -pnnx.Expression op_11 2 1 40 attn_mask 41 expr=%expr +pnnx.Attribute attn_mask 0 1 attn_mask @data=(1,1,1,%size)f32 +pnnx.Expression op_11 2 1 40 attn_mask 41 expr=add(div(@0,%sqrt_feat_per_head),@1) F.softmax op_12 1 1 41 42 dim=-1 torch.matmul op_13 2 1 42 43 44 torch.permute op_14 1 1 44 45 dims=(0,2,1,3) -Tensor.reshape op_15 1 1 45 46 shape=%qkv_shape -nn.Linear out_proj 1 1 46 out bias=%out_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_15 1 1 45 46 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 46 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Attribute attn_mask 0 1 attn_mask @data=%attn_mask.data +nn.MultiheadAttention attention 2 1 input attn_mask out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask pnnx.Output output 1 0 out )PNNXIR"; } @@ -1319,71 +1183,37 @@ pnnx.Output output 1 0 out bool match(const std::map& captured_params) const { const int embed_dim = captured_params.at("embed_dim").i; + const int num_heads = captured_params.at("num_heads").i; + const int feat_per_head = captured_params.at("feat_per_head").i; + const float sqrt_feat_per_head = captured_params.at("sqrt_feat_per_head").f; - // q_shape = (1,77,16,64) - // kv_shape = (1,77,16,64) - // qkv_shape = (1,77,1024) - const std::vector& q_shape = captured_params.at("q_shape").ai; - const std::vector& kv_shape = captured_params.at("kv_shape").ai; - const std::vector& qkv_shape = captured_params.at("qkv_shape").ai; - if (q_shape.size() != 4 || kv_shape.size() != 4 || qkv_shape.size() != 3) - return false; - - const int batch_size = q_shape[0]; - const int q_size = q_shape[1]; - const int num_heads = q_shape[2]; - const int feat_per_head = q_shape[3]; - - if (kv_shape[0] != batch_size || kv_shape[2] != num_heads || kv_shape[3] != feat_per_head) - return false; - - if (qkv_shape[0] != batch_size || qkv_shape[1] != q_size || qkv_shape[2] != feat_per_head * num_heads) - return false; - - // add(div(@0,8.000000e+00),@1) - const std::string& expr = captured_params.at("expr").s; - float sqrt_embed_dim_per_head = 0.f; - int nscan = sscanf(expr.c_str(), "add(div(@0,%f),@1)", &sqrt_embed_dim_per_head); - if (nscan != 1) + if (embed_dim != num_heads * feat_per_head) return false; - if (!NearlyEqual(sqrt_embed_dim_per_head, sqrt(embed_dim / num_heads), 0.001)) + if (!NearlyEqual(sqrt_feat_per_head, sqrt(feat_per_head), 0.001)) return false; return true; } - bool match(const std::map& matched_operators) const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - const Operator* op_11 = matched_operators.at("op_11"); + fuse_multiheadattention_pass_sameqkv::write(ops, captured_params, captured_attrs); - // support constant attention mask only atm - Operand* attn_mask = op_11->inputs[1]; - if (attn_mask->consumers.size() > 1 || attn_mask->producer->type != "pnnx.Attribute") - return false; + const int size = captured_params.at("size").i; - return true; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - fuse_multiheadattention_pass_sameqkv::write(op, captured_params, captured_attrs); + Operator* op_attr = ops.at("attn_mask"); - Operand* attn_mask = op->inputs[1]; - Operator* op_attr = attn_mask->producer; - - int q_size = op->inputs[0]->shape[1]; + fprintf(stderr, "op_attr->attrs[data] type %d\n", op_attr->attrs["data"].type); // hack attn_mask shape - attn_mask->shape = std::vector{q_size, attn_mask->shape[3]}; - const std::string key = op_attr->attrs.begin()->first; - op_attr->attrs[key].shape = attn_mask->shape; + op_attr->attrs["data"].shape = {size, size}; // hack attn_mask value - std::vector& data = op_attr->attrs[key].data; + std::vector& data = op_attr->attrs["data"].data; size_t len = data.size(); - data.resize(len * q_size); - for (int i = 1; i < q_size; i++) + data.resize(len * size); + for (int i = 1; i < size; i++) { memcpy(&data[len * i], &data[0], len); } @@ -1398,64 +1228,50 @@ public: return R"PNNXIR(7767517 17 16 pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 attn_mask -nn.Linear op_0 1 1 input 8 bias=%qkv_bias in_features=%embed_dim out_features=%qkv_out_features @bias @weight -Tensor.reshape op_1 1 1 8 9 shape=%shape +nn.Linear op_0 1 1 input 8 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight +Tensor.reshape op_1 1 1 8 9 shape=(%batch,%size,3,%num_heads,%feat_per_head) torch.permute op_2 1 1 9 10 dims=(2,0,3,1,4) torch.unbind op_3 1 3 10 11 12 13 dim=0 -pnnx.Expression op_4 1 1 11 14 expr=%expr +pnnx.Expression op_4 1 1 11 14 expr=mul(@0,%inv_sqrt_embed_dim_per_head) torch.transpose op_5 1 1 12 15 dim0=-2 dim1=-1 torch.matmul op_6 2 1 14 15 16 -pnnx.Expression op_7 2 1 16 attn_mask 18 expr=%expr2 +pnnx.Attribute attn_mask 0 1 attn_mask @data=(1,%num_heads,%size,%size)f32 +pnnx.Expression op_7 2 1 16 attn_mask 18 expr=add(@0,@1) F.softmax op_8 1 1 18 19 dim=-1 torch.matmul op_9 2 1 19 13 20 torch.transpose op_10 1 1 20 21 dim0=1 dim1=2 -Tensor.reshape op_11 1 1 21 22 shape=%shape2 -nn.Linear out_proj 1 1 22 out bias=%out_proj_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_11 1 1 21 22 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 22 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } - bool match(const std::map& captured_params) const + const char* replace_pattern_graph() const { - bool matched = fuse_multiheadattention_pass::match(captured_params); - if (!matched) - return false; - - if (captured_params.at("expr2").s != "add(@0,@1)") - return false; - - return true; - } - - bool match(const std::map& matched_operators) const - { - const Operator* op_7 = matched_operators.at("op_7"); - - // support constant attention mask only atm - Operand* attn_mask = op_7->inputs[1]; - if (attn_mask->consumers.size() > 1 || attn_mask->producer->type != "pnnx.Attribute") - return false; - - return true; + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Attribute attn_mask 0 1 attn_mask @data=%attn_mask.data +nn.MultiheadAttention attention 2 1 input attn_mask out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask +pnnx.Output output 1 0 out +)PNNXIR"; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - fuse_multiheadattention_pass::write(op, captured_params, captured_attrs); + fuse_multiheadattention_pass::write(ops, captured_params, captured_attrs); - Operand* attn_mask = op->inputs[1]; - Operator* op_attr = attn_mask->producer; + const int batch = captured_params.at("batch").i; + const int size = captured_params.at("size").i; + const int num_heads = captured_params.at("num_heads").i; - int batch = op->inputs[0]->shape[0]; + Operator* op_attr = ops.at("attn_mask"); // hack attn_mask shape - attn_mask->shape = std::vector{batch * attn_mask->shape[1], attn_mask->shape[2], attn_mask->shape[3]}; - const std::string key = op_attr->attrs.begin()->first; - op_attr->attrs[key].shape = attn_mask->shape; + op_attr->attrs["data"].shape = {batch * num_heads, size, size}; // hack attn_mask value - std::vector& data = op_attr->attrs[key].data; + std::vector& data = op_attr->attrs["data"].data; size_t len = data.size(); data.resize(len * batch); for (int i = 1; i < batch; i++) @@ -1473,97 +1289,54 @@ public: return R"PNNXIR(7767517 20 19 pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 attn_mask -nn.Linear op_0 1 1 input 25 bias=%qkv_bias in_features=%embed_dim out_features=%qkv_out_features @bias @weight -Tensor.reshape op_1 1 1 25 26 shape=%shape +nn.Linear op_0 1 1 input 25 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight +Tensor.reshape op_1 1 1 25 26 shape=(%batch,%size,3,%num_heads,%feat_per_head) torch.permute op_2 1 1 26 27 dims=(2,0,3,1,4) torch.unbind op_3 1 3 27 28 29 30 dim=0 -pnnx.Expression op_4 1 1 28 31 expr=%expr +pnnx.Expression op_4 1 1 28 31 expr=mul(@0,%inv_sqrt_embed_dim_per_head) torch.transpose op_5 1 1 29 32 dim0=-2 dim1=-1 torch.matmul op_6 2 1 31 32 33 -pnnx.Expression op_7 2 1 33 attn_mask 35 expr=%expr2 -Tensor.view op_8 1 1 35 36 shape=%shapep -pnnx.Attribute op_9 0 1 37 @mask2 -pnnx.Expression op_10 2 1 36 37 38 expr=%expr2 -Tensor.view op_11 1 1 38 39 shape=%shapeq +pnnx.Attribute attn_mask 0 1 attn_mask @data=(1,%num_heads,%size,%size)f32 +pnnx.Expression op_7 2 1 33 attn_mask 35 expr=add(@0,@1) +Tensor.view op_8 1 1 35 36 shape=(1,%batch,%num_heads,%size,%size) +pnnx.Attribute op_9 0 1 37 @data=(1,%batch,1,%size,%size)f32 +pnnx.Expression op_10 2 1 36 37 38 expr=add(@0,@1) +Tensor.view op_11 1 1 38 39 shape=(-1,%num_heads,%size,%size) F.softmax op_12 1 1 39 40 dim=-1 torch.matmul op_13 2 1 40 30 41 torch.transpose op_14 1 1 41 42 dim0=1 dim1=2 -Tensor.reshape op_15 1 1 42 43 shape=%shape2 -nn.Linear out_proj 1 1 43 out bias=%out_proj_bias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.reshape op_15 1 1 42 43 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 43 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight pnnx.Output output 1 0 out )PNNXIR"; } - bool match(const std::map& captured_params) const + const char* replace_pattern_graph() const { - bool matched = fuse_multiheadattention_pass::match(captured_params); - if (!matched) - return false; - - if (captured_params.at("expr2").s != "add(@0,@1)") - return false; - - // (1,64,3,49,49) - // (-1,3,49,49) - const std::vector& shapep = captured_params.at("shapep").ai; - const std::vector& shapeq = captured_params.at("shapeq").ai; - if (shapep.size() != 5 || shapeq.size() != 4) - return false; - - if (shapep[0] != 1 || (shapep[1] != shapeq[0] && shapeq[0] != -1) || shapep[2] != shapeq[1] || shapep[3] != shapeq[2] || shapep[4] != shapeq[3]) - return false; - - return true; - } - - bool match(const std::map& matched_operators) const - { - const Operator* op_7 = matched_operators.at("op_7"); - const Operator* op_10 = matched_operators.at("op_10"); - - // support constant attention mask only atm - Operand* attn_mask = op_7->inputs[1]; - if (attn_mask->consumers.size() > 1 || attn_mask->producer->type != "pnnx.Attribute") - return false; - - // @mask2=(1,64,1,49,49)f32 - Operand* attn_mask2 = op_10->inputs[1]; - if (attn_mask2->shape.size() != 5) - return false; - - if (attn_mask2->shape[0] != 1 || attn_mask2->shape[2] != 1) - return false; - - return true; + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Attribute attn_mask 0 1 attn_mask @data=%attn_mask.data +nn.MultiheadAttention attention 2 1 input attn_mask out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask +pnnx.Output output 1 0 out +)PNNXIR"; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - fuse_multiheadattention_pass::write(op, captured_params, captured_attrs); + fuse_multiheadattention_pass::write(ops, captured_params, captured_attrs); - int num_heads = captured_params.at("shape").ai[captured_params.at("shape").ai.size() - 2]; - - Operand* attn_mask = op->inputs[1]; - Operator* op_attr = attn_mask->producer; - - // @mask2=(1,64,1,49,49)f32 - Attribute mask2; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 5) == "op_9.") - mask2 = x.second; - } + const int batch = captured_params.at("batch").i; + const int size = captured_params.at("size").i; + const int num_heads = captured_params.at("num_heads").i; - int batch = op->inputs[0]->shape[0]; + Operator* op_attr = ops.at("attn_mask"); // hack attn_mask shape - attn_mask->shape = std::vector{batch * attn_mask->shape[1], attn_mask->shape[2], attn_mask->shape[3]}; - const std::string key = op_attr->attrs.begin()->first; - op_attr->attrs[key].shape = attn_mask->shape; + op_attr->attrs["data"].shape = {batch * num_heads, size, size}; // hack attn_mask value - std::vector& data = op_attr->attrs[key].data; + std::vector& data = op_attr->attrs["data"].data; size_t len = data.size(); data.resize(len * batch); for (int i = 1; i < batch; i++) @@ -1573,7 +1346,8 @@ pnnx.Output output 1 0 out // add mask2 { - auto maskdata = op_attr->attrs[key].get_float32_data(); + auto mask2 = captured_attrs.at("op_9.data"); + auto maskdata = op_attr->attrs["data"].get_float32_data(); const int ls = mask2.shape[3] * mask2.shape[4]; for (int i = 0; i < batch; i++) @@ -1589,7 +1363,7 @@ pnnx.Output output 1 0 out } } - op_attr->attrs[key].set_float32_data(maskdata); + op_attr->attrs["data"].set_float32_data(maskdata); } } }; @@ -1603,6 +1377,8 @@ void fuse_multiheadattention(Graph& graph) fuse_multiheadattention_pass_qkv c; fuse_multiheadattention_pass_q_samekv d; fuse_multiheadattention_pass_1 b1; + fuse_multiheadattention_pass_1_1 b11; + fuse_multiheadattention_pass_1_2 b12; fuse_multiheadattention_pass_2 c1; fuse_multiheadattention_pass_3 d1; fuse_multiheadattention_pass_5 e; @@ -1628,6 +1404,8 @@ void fuse_multiheadattention(Graph& graph) pnnx_graph_rewrite(graph, &c, opindex); pnnx_graph_rewrite(graph, &d, opindex); pnnx_graph_rewrite(graph, &b1, opindex); + pnnx_graph_rewrite(graph, &b11, opindex); + pnnx_graph_rewrite(graph, &b12, opindex); pnnx_graph_rewrite(graph, &c1, opindex); pnnx_graph_rewrite(graph, &d1, opindex); pnnx_graph_rewrite(graph, &e, opindex); diff --git a/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp b/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp new file mode 100644 index 000000000..ab5eb886e --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp @@ -0,0 +1,92 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_scaled_dot_product_attention.h" + +#include "pass_level2.h" + +#include +#include + +#include + +namespace pnnx { + +static bool NearlyEqual(float a, float b, float epsilon) +{ + if (a == b) + return true; + + float diff = (float)fabs(a - b); + if (diff <= epsilon) + return true; + + // relative error + return diff < epsilon * std::max(fabs(a), fabs(b)); +} + +class fuse_scaled_dot_product_attention_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 query #query=(%batch,%num_heads,%qsize,%feat_per_head)f32 +pnnx.Input input_1 0 1 key #key=(%batch,%num_heads,%kvsize,%feat_per_head)f32 +pnnx.Input input_2 0 1 value #value=(%batch,%num_heads,%kvsize,%feat_per_head)f32 +torch.permute op_0 1 1 key 59 dims=(0,1,3,2) +torch.matmul op_1 2 1 query 59 61 +pnnx.Expression op_2 1 1 61 62 expr=div(@0,%sqrt_embed_dim_per_head) +F.softmax op_3 1 1 62 63 dim=-1 +torch.matmul op_4 2 1 63 value out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input input_2 0 1 value +F.scaled_dot_product_attention op_0 3 1 query key value out attn_mask=None dropout_p=0.0 is_causal=False +pnnx.Output output 1 0 out +)PNNXIR"; + } + + bool match(const std::map& captured_params) const + { + const int feat_per_head = captured_params.at("feat_per_head").i; + const float sqrt_embed_dim_per_head = captured_params.at("sqrt_embed_dim_per_head").f; + + if (!NearlyEqual(sqrt_embed_dim_per_head, sqrt(feat_per_head), 0.001)) + return false; + + return true; + } +}; + +void fuse_scaled_dot_product_attention(Graph& graph) +{ +#if TORCH_VERSION_MAJOR >= 2 + fuse_scaled_dot_product_attention_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +#endif +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.h b/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.h new file mode 100644 index 000000000..0eb13015c --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_scaled_dot_product_attention(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp index 169ee17a2..fd0071ac9 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp @@ -29,21 +29,21 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_mean 0 1 running_mean @qwq -pnnx.Attribute op_var 0 1 running_var @qwq +pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 +pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.BatchNorm1d"; - } - - const char* name_str() const - { - return "batchnorm"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.BatchNorm1d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=False @running_mean=%op_mean.data @running_var=%op_var.data +pnnx.Output output 1 0 out +)PNNXIR"; } bool match(const std::map& matched_operators) const @@ -51,26 +51,6 @@ pnnx.Output output 1 0 out size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); return input_rank == 2 || input_rank == 3; } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute running_mean; - Attribute running_var; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 8) == "op_mean.") - running_mean = x.second; - if (x.first.substr(0, 7) == "op_var.") - running_var = x.second; - } - - op->params["num_features"] = running_mean.shape[0]; - op->params["eps"] = captured_params.at("eps"); - op->params["affine"] = false; - - op->attrs["running_mean"] = running_mean; - op->attrs["running_var"] = running_var; - } }; class fuse_static_Fbatchnorm_pass_1d_1 : public GraphRewriterPass @@ -81,23 +61,23 @@ public: return R"PNNXIR(7767517 7 6 pnnx.Input input 0 1 input -pnnx.Attribute op_mean 0 1 running_mean @qwq -pnnx.Attribute op_var 0 1 running_var @qwq -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 +pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 +pnnx.Attribute op_weight 0 1 weight @data +pnnx.Attribute op_bias 0 1 bias @data F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.BatchNorm1d"; - } - - const char* name_str() const - { - return "batchnorm"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.BatchNorm1d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=True @running_mean=%op_mean.data @running_var=%op_var.data @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } bool match(const std::map& matched_operators) const @@ -105,34 +85,6 @@ pnnx.Output output 1 0 out size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); return input_rank == 2 || input_rank == 3; } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute running_mean; - Attribute running_var; - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 8) == "op_mean.") - running_mean = x.second; - if (x.first.substr(0, 7) == "op_var.") - running_var = x.second; - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["num_features"] = running_mean.shape[0]; - op->params["eps"] = captured_params.at("eps"); - op->params["affine"] = true; - - op->attrs["running_mean"] = running_mean; - op->attrs["running_var"] = running_var; - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; - } }; class fuse_static_Fbatchnorm_pass_2d : public GraphRewriterPass @@ -143,47 +95,21 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_mean 0 1 running_mean @qwq -pnnx.Attribute op_var 0 1 running_var @qwq -F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps +pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 +pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 +F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps #input=(?,?,?,?)f32 pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.BatchNorm2d"; - } - - const char* name_str() const - { - return "batchnorm"; - } - - bool match(const std::map& matched_operators) const - { - size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); - return input_rank == 4; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + const char* replace_pattern_graph() const { - Attribute running_mean; - Attribute running_var; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 8) == "op_mean.") - running_mean = x.second; - if (x.first.substr(0, 7) == "op_var.") - running_var = x.second; - } - - op->params["num_features"] = running_mean.shape[0]; - op->params["eps"] = captured_params.at("eps"); - op->params["affine"] = false; - - op->attrs["running_mean"] = running_mean; - op->attrs["running_var"] = running_var; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.BatchNorm2d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=False @running_mean=%op_mean.data @running_var=%op_var.data +pnnx.Output output 1 0 out +)PNNXIR"; } }; @@ -195,57 +121,23 @@ public: return R"PNNXIR(7767517 7 6 pnnx.Input input 0 1 input -pnnx.Attribute op_mean 0 1 running_mean @qwq -pnnx.Attribute op_var 0 1 running_var @qwq -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps +pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 +pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 +pnnx.Attribute op_weight 0 1 weight @data +pnnx.Attribute op_bias 0 1 bias @data +F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps #input=(?,?,?,?)f32 pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.BatchNorm2d"; - } - - const char* name_str() const - { - return "batchnorm"; - } - - bool match(const std::map& matched_operators) const - { - size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); - return input_rank == 4; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + const char* replace_pattern_graph() const { - Attribute running_mean; - Attribute running_var; - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 8) == "op_mean.") - running_mean = x.second; - if (x.first.substr(0, 7) == "op_var.") - running_var = x.second; - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["num_features"] = running_mean.shape[0]; - op->params["eps"] = captured_params.at("eps"); - op->params["affine"] = true; - - op->attrs["running_mean"] = running_mean; - op->attrs["running_var"] = running_var; - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.BatchNorm2d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=True @running_mean=%op_mean.data @running_var=%op_var.data @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } }; @@ -257,47 +149,21 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_mean 0 1 running_mean @qwq -pnnx.Attribute op_var 0 1 running_var @qwq -F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps +pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 +pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 +F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps #input=(?,?,?,?,?)f32 pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.BatchNorm3d"; - } - - const char* name_str() const - { - return "batchnorm"; - } - - bool match(const std::map& matched_operators) const - { - size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); - return input_rank == 5; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + const char* replace_pattern_graph() const { - Attribute running_mean; - Attribute running_var; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 8) == "op_mean.") - running_mean = x.second; - if (x.first.substr(0, 7) == "op_var.") - running_var = x.second; - } - - op->params["num_features"] = running_mean.shape[0]; - op->params["eps"] = captured_params.at("eps"); - op->params["affine"] = false; - - op->attrs["running_mean"] = running_mean; - op->attrs["running_var"] = running_var; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.BatchNorm3d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=False @running_mean=%op_mean.data @running_var=%op_var.data +pnnx.Output output 1 0 out +)PNNXIR"; } }; @@ -309,57 +175,23 @@ public: return R"PNNXIR(7767517 7 6 pnnx.Input input 0 1 input -pnnx.Attribute op_mean 0 1 running_mean @qwq -pnnx.Attribute op_var 0 1 running_var @qwq -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps +pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 +pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 +pnnx.Attribute op_weight 0 1 weight @data +pnnx.Attribute op_bias 0 1 bias @data +F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps #input=(?,?,?,?,?)f32 pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.BatchNorm3d"; - } - - const char* name_str() const - { - return "batchnorm"; - } - - bool match(const std::map& matched_operators) const - { - size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); - return input_rank == 5; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + const char* replace_pattern_graph() const { - Attribute running_mean; - Attribute running_var; - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 8) == "op_mean.") - running_mean = x.second; - if (x.first.substr(0, 7) == "op_var.") - running_var = x.second; - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["num_features"] = running_mean.shape[0]; - op->params["eps"] = captured_params.at("eps"); - op->params["affine"] = true; - - op->attrs["running_mean"] = running_mean; - op->attrs["running_var"] = running_var; - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.BatchNorm3d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=True @running_mean=%op_mean.data @running_var=%op_var.data @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } }; diff --git a/tools/pnnx/src/pass_level5/fuse_static_conv.cpp b/tools/pnnx/src/pass_level5/fuse_static_conv.cpp index 6e29bcaac..4dda5006d 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_conv.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_conv.cpp @@ -29,42 +29,30 @@ public: return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kw)f32 F.conv1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Conv1d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv1d conv1d 1 1 input out out_channels=%out_channels kernel_size=(%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data +pnnx.Output output 1 0 out +)PNNXIR"; } - const char* name_str() const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - return "conv1d"; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; - op->params["out_channels"] = weight.shape[0]; - op->params["kernel_size"] = std::vector{weight.shape[2]}; - op->params["padding_mode"] = std::string("zeros"); - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["groups"] = captured_params.at("groups"); - op->params["bias"] = false; - - op->attrs["weight"] = weight; + const int in_channels_per_group = captured_params.at("in_channels_per_group").i; + const int groups = captured_params.at("groups").i; + + ops.at("conv1d")->params["in_channels"] = in_channels_per_group * groups; } }; @@ -76,47 +64,31 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kw)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 F.conv1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Conv1d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv1d conv1d 1 1 input out out_channels=%out_channels kernel_size=(%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } - const char* name_str() const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - return "conv1d"; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; - op->params["out_channels"] = weight.shape[0]; - op->params["kernel_size"] = std::vector{weight.shape[2]}; - op->params["padding_mode"] = std::string("zeros"); - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["groups"] = captured_params.at("groups"); - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + const int in_channels_per_group = captured_params.at("in_channels_per_group").i; + const int groups = captured_params.at("groups").i; + + ops.at("conv1d")->params["in_channels"] = in_channels_per_group * groups; } }; @@ -128,71 +100,32 @@ public: return R"PNNXIR(7767517 6 5 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kw)f32 +pnnx.Attribute op_bias 0 1 bias @data=(1,%out_channels,1)f32 F.conv1d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups -pnnx.Expression op_1 2 1 a bias out expr=%expr +pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1) pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Conv1d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv1d conv1d 1 1 input out out_channels=%out_channels kernel_size=(%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } - const char* name_str() const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - return "conv1d"; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); - bool match(const std::map& captured_params, const std::map& captured_attrs) const - { - const std::string& expr = captured_params.at("expr").s; - if (expr != "add(@0,@1)") - return false; - - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - int out_channels = weight.shape[0]; - if (bias.shape != std::vector{1, out_channels, 1}) - return false; - - return true; - } + const int in_channels_per_group = captured_params.at("in_channels_per_group").i; + const int groups = captured_params.at("groups").i; - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; - op->params["out_channels"] = weight.shape[0]; - op->params["kernel_size"] = std::vector{weight.shape[2]}; - op->params["padding_mode"] = std::string("zeros"); - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["groups"] = captured_params.at("groups"); - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + ops.at("conv1d")->params["in_channels"] = in_channels_per_group * groups; } }; @@ -204,42 +137,30 @@ public: return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kh,%kw)f32 F.conv2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Conv2d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv2d conv2d 1 1 input out out_channels=%out_channels kernel_size=(%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data +pnnx.Output output 1 0 out +)PNNXIR"; } - const char* name_str() const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - return "conv2d"; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; - op->params["out_channels"] = weight.shape[0]; - op->params["kernel_size"] = std::vector{weight.shape[2], weight.shape[3]}; - op->params["padding_mode"] = std::string("zeros"); - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["groups"] = captured_params.at("groups"); - op->params["bias"] = false; - - op->attrs["weight"] = weight; + const int in_channels_per_group = captured_params.at("in_channels_per_group").i; + const int groups = captured_params.at("groups").i; + + ops.at("conv2d")->params["in_channels"] = in_channels_per_group * groups; } }; @@ -251,47 +172,31 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kh,%kw)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 F.conv2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Conv2d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv2d conv2d 1 1 input out out_channels=%out_channels kernel_size=(%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } - const char* name_str() const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - return "conv2d"; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; - op->params["out_channels"] = weight.shape[0]; - op->params["kernel_size"] = std::vector{weight.shape[2], weight.shape[3]}; - op->params["padding_mode"] = std::string("zeros"); - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["groups"] = captured_params.at("groups"); - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + const int in_channels_per_group = captured_params.at("in_channels_per_group").i; + const int groups = captured_params.at("groups").i; + + ops.at("conv2d")->params["in_channels"] = in_channels_per_group * groups; } }; @@ -303,71 +208,32 @@ public: return R"PNNXIR(7767517 6 5 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kh,%kw)f32 +pnnx.Attribute op_bias 0 1 bias @data=(1,%out_channels,1,1)f32 F.conv2d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups -pnnx.Expression op_1 2 1 a bias out expr=%expr +pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1) pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Conv2d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv2d conv2d 1 1 input out out_channels=%out_channels kernel_size=(%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } - const char* name_str() const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - return "conv2d"; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); - bool match(const std::map& captured_params, const std::map& captured_attrs) const - { - const std::string& expr = captured_params.at("expr").s; - if (expr != "add(@0,@1)") - return false; - - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - int out_channels = weight.shape[0]; - if (bias.shape != std::vector{1, out_channels, 1, 1}) - return false; - - return true; - } + const int in_channels_per_group = captured_params.at("in_channels_per_group").i; + const int groups = captured_params.at("groups").i; - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; - op->params["out_channels"] = weight.shape[0]; - op->params["kernel_size"] = std::vector{weight.shape[2], weight.shape[3]}; - op->params["padding_mode"] = std::string("zeros"); - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["groups"] = captured_params.at("groups"); - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + ops.at("conv2d")->params["in_channels"] = in_channels_per_group * groups; } }; @@ -379,42 +245,30 @@ public: return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kd,%kh,%kw)f32 F.conv3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Conv3d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv3d conv3d 1 1 input out out_channels=%out_channels kernel_size=(%kd,%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data +pnnx.Output output 1 0 out +)PNNXIR"; } - const char* name_str() const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - return "conv3d"; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; - op->params["out_channels"] = weight.shape[0]; - op->params["kernel_size"] = std::vector{weight.shape[2], weight.shape[3], weight.shape[4]}; - op->params["padding_mode"] = std::string("zeros"); - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["groups"] = captured_params.at("groups"); - op->params["bias"] = false; - - op->attrs["weight"] = weight; + const int in_channels_per_group = captured_params.at("in_channels_per_group").i; + const int groups = captured_params.at("groups").i; + + ops.at("conv3d")->params["in_channels"] = in_channels_per_group * groups; } }; @@ -426,47 +280,31 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kd,%kh,%kw)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 F.conv3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Conv3d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv3d conv3d 1 1 input out out_channels=%out_channels kernel_size=(%kd,%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } - const char* name_str() const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - return "conv3d"; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; - op->params["out_channels"] = weight.shape[0]; - op->params["kernel_size"] = std::vector{weight.shape[2], weight.shape[3], weight.shape[4]}; - op->params["padding_mode"] = std::string("zeros"); - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["groups"] = captured_params.at("groups"); - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + const int in_channels_per_group = captured_params.at("in_channels_per_group").i; + const int groups = captured_params.at("groups").i; + + ops.at("conv3d")->params["in_channels"] = in_channels_per_group * groups; } }; @@ -478,71 +316,32 @@ public: return R"PNNXIR(7767517 6 5 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kd,%kh,%kw)f32 +pnnx.Attribute op_bias 0 1 bias @data=(1,%out_channels,1,1,1)f32 F.conv3d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups -pnnx.Expression op_1 2 1 a bias out expr=%expr +pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1) pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Conv3d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv3d conv3d 1 1 input out out_channels=%out_channels kernel_size=(%kd,%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } - const char* name_str() const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - return "conv3d"; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); - bool match(const std::map& captured_params, const std::map& captured_attrs) const - { - const std::string& expr = captured_params.at("expr").s; - if (expr != "add(@0,@1)") - return false; - - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - int out_channels = weight.shape[0]; - if (bias.shape != std::vector{1, out_channels, 1, 1, 1}) - return false; - - return true; - } + const int in_channels_per_group = captured_params.at("in_channels_per_group").i; + const int groups = captured_params.at("groups").i; - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; - op->params["out_channels"] = weight.shape[0]; - op->params["kernel_size"] = std::vector{weight.shape[2], weight.shape[3], weight.shape[4]}; - op->params["padding_mode"] = std::string("zeros"); - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["groups"] = captured_params.at("groups"); - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + ops.at("conv3d")->params["in_channels"] = in_channels_per_group * groups; } }; diff --git a/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp b/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp index 6f6e16495..5d6aa66f3 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp @@ -29,44 +29,30 @@ public: return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kw)f32 F.conv_transpose1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.ConvTranspose1d"; - } - - const char* name_str() const - { - return "conv_transpose1d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ConvTranspose1d conv_transpose1d 1 1 input out in_channels=%in_channels kernel_size=(%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data +pnnx.Output output 1 0 out +)PNNXIR"; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); + const int out_channels_per_group = captured_params.at("out_channels_per_group").i; const int groups = captured_params.at("groups").i; - op->params["groups"] = groups; - op->params["in_channels"] = weight.shape[0]; - op->params["out_channels"] = weight.shape[1] * groups; - op->params["kernel_size"] = Parameter{weight.shape[2]}; - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["output_padding"] = captured_params.at("output_padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["bias"] = false; - - op->attrs["weight"] = weight; + ops.at("conv_transpose1d")->params["out_channels"] = out_channels_per_group * groups; } }; @@ -78,49 +64,33 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kw)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 F.conv_transpose1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.ConvTranspose1d"; - } - - const char* name_str() const + const char* replace_pattern_graph() const { - return "conv_transpose1d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ConvTranspose1d conv_transpose1d 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=(%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& captured_params) const { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - + const int out_channels_per_group = captured_params.at("out_channels_per_group").i; + const int out_channels = captured_params.at("out_channels").i; const int groups = captured_params.at("groups").i; - op->params["groups"] = groups; - op->params["in_channels"] = weight.shape[0]; - op->params["out_channels"] = weight.shape[1] * groups; - op->params["kernel_size"] = Parameter{weight.shape[2]}; - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["output_padding"] = captured_params.at("output_padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + if (out_channels != out_channels_per_group * groups) + return false; + + return true; } }; @@ -132,44 +102,30 @@ public: return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kh,%kw)f32 F.conv_transpose2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.ConvTranspose2d"; - } - - const char* name_str() const + const char* replace_pattern_graph() const { - return "conv_transpose2d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ConvTranspose2d conv_transpose2d 1 1 input out in_channels=%in_channels kernel_size=(%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data +pnnx.Output output 1 0 out +)PNNXIR"; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); + const int out_channels_per_group = captured_params.at("out_channels_per_group").i; const int groups = captured_params.at("groups").i; - op->params["groups"] = groups; - op->params["in_channels"] = weight.shape[0]; - op->params["out_channels"] = weight.shape[1] * groups; - op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3]}; - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["output_padding"] = captured_params.at("output_padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["bias"] = false; - - op->attrs["weight"] = weight; + ops.at("conv_transpose2d")->params["out_channels"] = out_channels_per_group * groups; } }; @@ -181,49 +137,33 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kh,%kw)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 F.conv_transpose2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.ConvTranspose2d"; - } - - const char* name_str() const + const char* replace_pattern_graph() const { - return "conv_transpose2d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ConvTranspose2d conv_transpose2d 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=(%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& captured_params) const { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - + const int out_channels_per_group = captured_params.at("out_channels_per_group").i; + const int out_channels = captured_params.at("out_channels").i; const int groups = captured_params.at("groups").i; - op->params["groups"] = groups; - op->params["in_channels"] = weight.shape[0]; - op->params["out_channels"] = weight.shape[1] * groups; - op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3]}; - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["output_padding"] = captured_params.at("output_padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + if (out_channels != out_channels_per_group * groups) + return false; + + return true; } }; @@ -235,44 +175,30 @@ public: return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kd,%kh,%kw)f32 F.conv_transpose3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.ConvTranspose3d"; - } - - const char* name_str() const + const char* replace_pattern_graph() const { - return "conv_transpose3d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ConvTranspose3d conv_transpose3d 1 1 input out in_channels=%in_channels kernel_size=(%kd,%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data +pnnx.Output output 1 0 out +)PNNXIR"; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); + const int out_channels_per_group = captured_params.at("out_channels_per_group").i; const int groups = captured_params.at("groups").i; - op->params["groups"] = groups; - op->params["in_channels"] = weight.shape[0]; - op->params["out_channels"] = weight.shape[1] * groups; - op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3], weight.shape[4]}; - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["output_padding"] = captured_params.at("output_padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["bias"] = false; - - op->attrs["weight"] = weight; + ops.at("conv_transpose3d")->params["out_channels"] = out_channels_per_group * groups; } }; @@ -284,49 +210,33 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kd,%kh,%kw)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 F.conv_transpose3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.ConvTranspose3d"; - } - - const char* name_str() const + const char* replace_pattern_graph() const { - return "conv_transpose3d"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ConvTranspose3d conv_transpose3d 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=(%kd,%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& captured_params) const { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - + const int out_channels_per_group = captured_params.at("out_channels_per_group").i; + const int out_channels = captured_params.at("out_channels").i; const int groups = captured_params.at("groups").i; - op->params["groups"] = groups; - op->params["in_channels"] = weight.shape[0]; - op->params["out_channels"] = weight.shape[1] * groups; - op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3], weight.shape[4]}; - op->params["stride"] = captured_params.at("stride"); - op->params["padding"] = captured_params.at("padding"); - op->params["output_padding"] = captured_params.at("output_padding"); - op->params["dilation"] = captured_params.at("dilation"); - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + if (out_channels != out_channels_per_group * groups) + return false; + + return true; } }; diff --git a/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp index 203168e25..da0d6112b 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp @@ -29,42 +29,21 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%num_channels)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%num_channels)f32 F.group_norm op_0 3 1 input weight bias out num_groups=%num_groups eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.GroupNorm"; - } - - const char* name_str() const - { - return "group_norm"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["num_channels"] = weight.shape[0]; - op->params["num_groups"] = captured_params.at("num_groups"); - op->params["eps"] = captured_params.at("eps"); - op->params["affine"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.GroupNorm group_norm 1 1 input out num_channels=%num_channels num_groups=%num_groups eps=%eps affine=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } }; diff --git a/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp index f612178c9..c9a1c7a7c 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp @@ -29,21 +29,21 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%num_features)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%num_features)f32 F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.InstanceNorm1d"; - } - - const char* name_str() const - { - return "instance_norm"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.InstanceNorm1d instance_norm 1 1 input out num_features=%num_features eps=%eps affine=True track_running_stats=False @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } bool match(const std::map& matched_operators) const @@ -51,27 +51,6 @@ pnnx.Output output 1 0 out size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); return input_rank == 2 || input_rank == 3; } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["num_features"] = weight.shape[0]; - op->params["eps"] = captured_params.at("eps"); - op->params["affine"] = true; - op->params["track_running_stats"] = false; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; - } }; class fuse_static_Finstancenorm_pass_2d : public GraphRewriterPass @@ -82,48 +61,21 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps +pnnx.Attribute op_weight 0 1 weight @data=(%num_features)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%num_features)f32 +F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps #input=(?,?,?,?)f32 pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.InstanceNorm1d"; - } - - const char* name_str() const - { - return "instance_norm"; - } - - bool match(const std::map& matched_operators) const - { - size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); - return input_rank == 4; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + const char* replace_pattern_graph() const { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["num_features"] = weight.shape[0]; - op->params["eps"] = captured_params.at("eps"); - op->params["affine"] = true; - op->params["track_running_stats"] = false; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.InstanceNorm2d instance_norm 1 1 input out num_features=%num_features eps=%eps affine=True track_running_stats=False @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } }; @@ -135,48 +87,21 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps +pnnx.Attribute op_weight 0 1 weight @data=(%num_features)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%num_features)f32 +F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps #input=(?,?,?,?,?)f32 pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.InstanceNorm1d"; - } - - const char* name_str() const - { - return "instance_norm"; - } - - bool match(const std::map& matched_operators) const - { - size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); - return input_rank == 5; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + const char* replace_pattern_graph() const { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["num_features"] = weight.shape[0]; - op->params["eps"] = captured_params.at("eps"); - op->params["affine"] = true; - op->params["track_running_stats"] = false; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.InstanceNorm3d instance_norm 1 1 input out num_features=%num_features eps=%eps affine=True track_running_stats=False @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } }; diff --git a/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp index d6c494f08..0b1f0dc41 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp @@ -29,41 +29,21 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data +pnnx.Attribute op_bias 0 1 bias @data F.layer_norm op_0 3 1 input weight bias out normalized_shape=%normalized_shape eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.LayerNorm"; - } - - const char* name_str() const - { - return "layer_norm"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["normalized_shape"] = captured_params.at("normalized_shape"); - op->params["eps"] = captured_params.at("eps"); - op->params["elementwise_affine"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.LayerNorm layer_norm 1 1 input out normalized_shape=%normalized_shape eps=%eps elementwise_affine=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } }; diff --git a/tools/pnnx/src/pass_level5/fuse_static_linear.cpp b/tools/pnnx/src/pass_level5/fuse_static_linear.cpp index a49665b97..7396142d4 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_linear.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_linear.cpp @@ -29,36 +29,20 @@ public: return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_features,%in_features)f32 F.linear op_0 2 1 input weight out bias=None pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Linear"; - } - - const char* name_str() const - { - return "linear"; - } - - void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["in_features"] = weight.shape[1]; - op->params["out_features"] = weight.shape[0]; - op->params["bias"] = false; - - op->attrs["weight"] = weight; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Linear linear 1 1 input out in_features=%in_features out_features=%out_features bias=False @weight=%op_weight.data +pnnx.Output output 1 0 out +)PNNXIR"; } }; @@ -70,41 +54,21 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data=(%out_features,%in_features)f32 +pnnx.Attribute op_bias 0 1 bias @data=(%out_features)f32 F.linear op_0 3 1 input weight bias out pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const + const char* replace_pattern_graph() const { - return "nn.Linear"; - } - - const char* name_str() const - { - return "linear"; - } - - void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["in_features"] = weight.shape[1]; - op->params["out_features"] = weight.shape[0]; - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Linear linear 1 1 input out in_features=%in_features out_features=%out_features bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } }; @@ -116,65 +80,31 @@ public: return R"PNNXIR(7767517 6 5 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.linear op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups -pnnx.Expression op_1 2 1 a bias out expr=%expr +pnnx.Attribute op_weight 0 1 weight @data=(%out_features,%in_features)f32 +pnnx.Attribute op_bias 0 1 bias @data=(1,%out_features,1)f32 +F.linear op_0 2 1 input weight a +pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1) pnnx.Output output 1 0 out )PNNXIR"; } - const char* type_str() const - { - return "nn.Linear"; - } - - const char* name_str() const + const char* replace_pattern_graph() const { - return "linear"; + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Linear linear 1 1 input out in_features=%in_features out_features=%out_features bias=True @weight=%op_weight.data @bias=%op_bias.data +pnnx.Output output 1 0 out +)PNNXIR"; } - bool match(const std::map& captured_params, const std::map& captured_attrs) const + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const { - const std::string& expr = captured_params.at("expr").s; - if (expr != "add(@0,@1)") - return false; - - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - int out_channels = weight.shape[0]; - if (bias.shape != std::vector{1, out_channels, 1}) - return false; - - return true; - } + GraphRewriterPass::write(ops, captured_params, captured_attrs); - void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["in_features"] = weight.shape[1]; - op->params["out_features"] = weight.shape[0]; - op->params["bias"] = true; - - op->attrs["weight"] = weight; - op->attrs["bias"] = bias; + // fix bias shape + const int out_features = captured_params.at("out_features").i; + ops.at("linear")->attrs.at("bias").shape = {out_features}; } }; diff --git a/tools/pnnx/src/pass_ncnn/F_batch_norm.cpp b/tools/pnnx/src/pass_ncnn/F_batch_norm.cpp index 90fd2c661..68971c647 100644 --- a/tools/pnnx/src/pass_ncnn/F_batch_norm.cpp +++ b/tools/pnnx/src/pass_ncnn/F_batch_norm.cpp @@ -26,8 +26,8 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_mean 0 1 running_mean @qwq -pnnx.Attribute op_var 0 1 running_var @qwq +pnnx.Attribute op_mean 0 1 running_mean @data +pnnx.Attribute op_var 0 1 running_var @data F.batch_norm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps pnnx.Output output 1 0 out )PNNXIR"; @@ -45,15 +45,8 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const { - Attribute running_mean; - Attribute running_var; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 8) == "op_mean.") - running_mean = x.second; - if (x.first.substr(0, 7) == "op_var.") - running_var = x.second; - } + Attribute running_mean = captured_attrs.at("op_mean.data"); + Attribute running_var = captured_attrs.at("op_var.data"); op->params["0"] = running_mean.shape[0]; op->params["1"] = captured_params.at("eps"); @@ -77,10 +70,10 @@ public: return R"PNNXIR(7767517 7 6 pnnx.Input input 0 1 input -pnnx.Attribute op_mean 0 1 running_mean @qwq -pnnx.Attribute op_var 0 1 running_var @qwq -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_mean 0 1 running_mean @data +pnnx.Attribute op_var 0 1 running_var @data +pnnx.Attribute op_weight 0 1 weight @data +pnnx.Attribute op_bias 0 1 bias @data F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps pnnx.Output output 1 0 out )PNNXIR"; @@ -98,21 +91,10 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const { - Attribute running_mean; - Attribute running_var; - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 8) == "op_mean.") - running_mean = x.second; - if (x.first.substr(0, 7) == "op_var.") - running_var = x.second; - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } + Attribute running_mean = captured_attrs.at("op_mean.data"); + Attribute running_var = captured_attrs.at("op_var.data"); + Attribute weight = captured_attrs.at("op_weight.data"); + Attribute bias = captured_attrs.at("op_bias.data"); op->params["0"] = running_mean.shape[0]; op->params["1"] = captured_params.at("eps"); diff --git a/tools/pnnx/src/pass_ncnn/F_embedding.cpp b/tools/pnnx/src/pass_ncnn/F_embedding.cpp index ad2978533..708f69c33 100644 --- a/tools/pnnx/src/pass_ncnn/F_embedding.cpp +++ b/tools/pnnx/src/pass_ncnn/F_embedding.cpp @@ -26,7 +26,7 @@ public: return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_weight 0 1 weight @data F.embedding op_0 2 1 input weight out scale_grad_by_freq=False sparse=False pnnx.Output output 1 0 out )PNNXIR"; @@ -44,12 +44,7 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } + Attribute weight = captured_attrs.at("op_weight.data"); op->params["0"] = weight.shape[1]; op->params["1"] = weight.shape[0]; diff --git a/tools/pnnx/src/pass_ncnn/F_instance_norm.cpp b/tools/pnnx/src/pass_ncnn/F_instance_norm.cpp index 003b49bc7..36d71afa6 100644 --- a/tools/pnnx/src/pass_ncnn/F_instance_norm.cpp +++ b/tools/pnnx/src/pass_ncnn/F_instance_norm.cpp @@ -67,8 +67,8 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq +pnnx.Attribute op_weight 0 1 weight @data +pnnx.Attribute op_bias 0 1 bias @data F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps pnnx.Output output 1 0 out )PNNXIR"; @@ -86,15 +86,8 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } + Attribute weight = captured_attrs.at("op_weight.data"); + Attribute bias = captured_attrs.at("op_bias.data"); op->params["0"] = weight.shape[0]; op->params["1"] = captured_params.at("eps"); diff --git a/tools/pnnx/src/pass_ncnn/F_prelu.cpp b/tools/pnnx/src/pass_ncnn/F_prelu.cpp index 14ae60922..b9036d3f1 100644 --- a/tools/pnnx/src/pass_ncnn/F_prelu.cpp +++ b/tools/pnnx/src/pass_ncnn/F_prelu.cpp @@ -26,7 +26,7 @@ public: return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_weight 0 1 weight @data F.prelu op_0 2 1 input weight out pnnx.Output output 1 0 out )PNNXIR"; @@ -44,12 +44,7 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } + Attribute weight = captured_attrs.at("op_weight.data"); op->params["0"] = weight.shape[0]; diff --git a/tools/pnnx/src/pass_ncnn/torch_addmm.cpp b/tools/pnnx/src/pass_ncnn/torch_addmm.cpp index 5a5df0771..ed97aa37e 100644 --- a/tools/pnnx/src/pass_ncnn/torch_addmm.cpp +++ b/tools/pnnx/src/pass_ncnn/torch_addmm.cpp @@ -26,8 +26,8 @@ public: return R"PNNXIR(7767517 5 4 pnnx.Input input_0 0 1 mat1 -pnnx.Attribute op_bias 0 1 bias @qwq -pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @data +pnnx.Attribute op_weight 0 1 weight @data torch.addmm op_0 3 1 bias mat1 weight out alpha=%alpha beta=%beta pnnx.Output output 1 0 out )PNNXIR"; @@ -69,15 +69,8 @@ pnnx.Output output 1 0 out if (alpha != 1.f || beta != 1.f) return false; - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } + Attribute weight = captured_attrs.at("op_weight.data"); + Attribute bias = captured_attrs.at("op_bias.data"); if (weight.shape.size() != 2 || bias.shape.size() != 1) return false; @@ -90,15 +83,8 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } + Attribute weight = captured_attrs.at("op_weight.data"); + Attribute bias = captured_attrs.at("op_bias.data"); // transpose weight inch-outch to outch-inch const int inch = weight.shape[0]; diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index f91ee66a6..8330e3f55 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -213,6 +213,7 @@ pnnx_add_test(torch_ones) pnnx_add_test(torch_ones_like) pnnx_add_test(torch_permute) pnnx_add_test(torch_prod) +pnnx_add_test(torch_repeat_interleave) pnnx_add_test(torch_scatter_add) pnnx_add_test(torch_sum) pnnx_add_test(torch_split) @@ -298,8 +299,10 @@ pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d) pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d) pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d) pnnx_add_test(pnnx_fuse_input_unpack) +pnnx_add_test(pnnx_fuse_layernorm) pnnx_add_test(pnnx_fuse_linear_batchnorm1d) pnnx_add_test(pnnx_fuse_multiheadattention) +pnnx_add_test(pnnx_fuse_scaled_dot_product_attention) pnnx_add_test(pnnx_fuse_select_to_unbind) pnnx_add_test(pnnx_fuse_slice_to_tensor_split) pnnx_add_test(pnnx_fuse_adjacent_reshape) diff --git a/tools/pnnx/tests/test_pnnx_fuse_layernorm.py b/tools/pnnx/tests/test_pnnx_fuse_layernorm.py new file mode 100644 index 000000000..d8f788ac5 --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_layernorm.py @@ -0,0 +1,70 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.ln_0 = LayerNorm2d(64) + + def forward(self, x): + x = self.ln_0(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 64, 16, 16) + + a0 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_pnnx_fuse_layernorm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_layernorm.pt inputshape=[1,64,16,16]") + + # pnnx inference + import test_pnnx_fuse_layernorm_pnnx + b0 = test_pnnx_fuse_layernorm_pnnx.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_pnnx_fuse_scaled_dot_product_attention.py b/tools/pnnx/tests/test_pnnx_fuse_scaled_dot_product_attention.py new file mode 100644 index 000000000..b130cddea --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_scaled_dot_product_attention.py @@ -0,0 +1,127 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +from einops import rearrange +from typing import Any, Optional, Tuple, Union +from torch import Tensor +import math + +class sam_Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.attention_0 = sam_Attention(embedding_dim=64, num_heads=4, downsample_rate=1) + self.attention_1 = sam_Attention(embedding_dim=64, num_heads=4, downsample_rate=2) + + def forward(self, q, k, v): + a = self.attention_0(q, k, v) + b = self.attention_1(q, k, v) + + return a, b + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 24, 64) + y = torch.rand(1, 24, 64) + z = torch.rand(1, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_pnnx_fuse_scaled_dot_product_attention.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_scaled_dot_product_attention.pt inputshape=[1,24,64],[1,24,64],[1,24,64]") + + # pnnx inference + import test_pnnx_fuse_scaled_dot_product_attention_pnnx + b = test_pnnx_fuse_scaled_dot_product_attention_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_repeat_interleave.py b/tools/pnnx/tests/test_torch_repeat_interleave.py new file mode 100644 index 000000000..b67418bf8 --- /dev/null +++ b/tools/pnnx/tests/test_torch_repeat_interleave.py @@ -0,0 +1,65 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.repeat_interleave(x, 2) + y = torch.repeat_interleave(y, 3, dim=1) + if version.parse(torch.__version__) >= version.parse('1.10'): + z = torch.repeat_interleave(z, torch.tensor([2, 1, 3]), dim=0, output_size=6) + else: + z = torch.repeat_interleave(z, torch.tensor([2, 1, 3]), dim=0) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3) + y = torch.rand(4, 5) + z = torch.rand(3, 7, 8) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_repeat_interleave.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_repeat_interleave.pt inputshape=[3],[4,5],[3,7,8]") + + # pnnx inference + import test_torch_repeat_interleave_pnnx + b = test_torch_repeat_interleave_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)