| @@ -242,6 +242,7 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/torch_randn.cpp | |||
| pass_level2/torch_randn_like.cpp | |||
| pass_level2/torch_real.cpp | |||
| pass_level2/torch_repeat_interleave.cpp | |||
| pass_level2/torch_roll.cpp | |||
| pass_level2/torch_scatter_add.cpp | |||
| pass_level2/torch_split.cpp | |||
| @@ -332,7 +333,9 @@ set(pnnx_pass_level5_SRCS | |||
| pass_level5/fuse_pad_conv1d.cpp | |||
| pass_level5/fuse_pad_conv2d.cpp | |||
| pass_level5/fuse_pixel_unshuffle.cpp | |||
| pass_level5/fuse_layernorm.cpp | |||
| pass_level5/fuse_multiheadattention.cpp | |||
| pass_level5/fuse_scaled_dot_product_attention.cpp | |||
| pass_level5/fuse_select_to_unbind.cpp | |||
| pass_level5/fuse_slice_copy.cpp | |||
| pass_level5/fuse_slice_indices.cpp | |||
| @@ -588,6 +588,14 @@ Attribute operator+(const Attribute& a, const Attribute& b) | |||
| Parameter Parameter::parse_from_string(const std::string& value) | |||
| { | |||
| if (value.find('%') != std::string::npos) | |||
| { | |||
| Parameter p; | |||
| p.type = 4; | |||
| p.s = value; | |||
| return p; | |||
| } | |||
| Parameter p; | |||
| p.type = 0; | |||
| @@ -659,6 +667,96 @@ Parameter Parameter::parse_from_string(const std::string& value) | |||
| return p; | |||
| } | |||
| std::string Parameter::encode_to_string(const Parameter& param) | |||
| { | |||
| if (param.type == 0) | |||
| { | |||
| return std::string("None"); | |||
| } | |||
| if (param.type == 1) | |||
| { | |||
| if (param.b) | |||
| return std::string("True"); | |||
| else | |||
| return std::string("False"); | |||
| } | |||
| if (param.type == 2) | |||
| { | |||
| return std::to_string(param.i); | |||
| } | |||
| if (param.type == 3) | |||
| { | |||
| char buf[64]; | |||
| sprintf(buf, "%e", param.f); | |||
| return std::string(buf); | |||
| } | |||
| if (param.type == 4) | |||
| { | |||
| return param.s; | |||
| } | |||
| if (param.type == 5) | |||
| { | |||
| std::string s("("); | |||
| for (size_t i = 0; i < param.ai.size(); i++) | |||
| { | |||
| s += std::to_string(param.ai[i]); | |||
| if (i + 1 != param.ai.size()) | |||
| s += std::string(","); | |||
| } | |||
| s += std::string(")"); | |||
| return s; | |||
| } | |||
| if (param.type == 6) | |||
| { | |||
| std::string s("("); | |||
| for (size_t i = 0; i < param.af.size(); i++) | |||
| { | |||
| char buf[64]; | |||
| sprintf(buf, "%e", param.af[i]); | |||
| s += std::string(buf); | |||
| if (i + 1 != param.af.size()) | |||
| s += std::string(","); | |||
| } | |||
| s += std::string(")"); | |||
| return s; | |||
| } | |||
| if (param.type == 7) | |||
| { | |||
| std::string s("("); | |||
| for (size_t i = 0; i < param.as.size(); i++) | |||
| { | |||
| s += param.as[i]; | |||
| if (i + 1 != param.as.size()) | |||
| s += std::string(","); | |||
| } | |||
| s += std::string(")"); | |||
| return s; | |||
| } | |||
| if (param.type == 10) | |||
| { | |||
| char buf[128]; | |||
| sprintf(buf, "%e+%ej", param.c.real(), param.c.imag()); | |||
| return std::string(buf); | |||
| } | |||
| if (param.type == 11) | |||
| { | |||
| std::string s("("); | |||
| for (size_t i = 0; i < param.ac.size(); i++) | |||
| { | |||
| char buf[128]; | |||
| sprintf(buf, "%e+%ej", param.ac[i].real(), param.ac[i].imag()); | |||
| s += std::string(buf); | |||
| if (i + 1 != param.ac.size()) | |||
| s += std::string(","); | |||
| } | |||
| s += std::string(")"); | |||
| return s; | |||
| } | |||
| fprintf(stderr, "unknown parameter type %d\n", param.type); | |||
| return std::string(); | |||
| } | |||
| Graph::Graph() | |||
| { | |||
| } | |||
| @@ -752,6 +850,14 @@ static void load_shape(Operator* op, const std::string& key, const std::string& | |||
| { | |||
| operand->shape.push_back(-1); | |||
| } | |||
| else if (elem[0] == '%') | |||
| { | |||
| // encode %abc as symbolic tag | |||
| operand->shape.push_back(-233); | |||
| int index = operand->shape.size() - 1; | |||
| std::string key = elem.substr(1); | |||
| operand->params[std::string("__shape_") + std::to_string(index)] = key; | |||
| } | |||
| else | |||
| { | |||
| int i = std::stoi(elem); | |||
| @@ -965,77 +1071,8 @@ int Graph::save(const std::string& parampath, const std::string& binpath) | |||
| fprintf(paramfp, " %s=", it.first.c_str()); | |||
| const Parameter& param = it.second; | |||
| if (param.type == 0) | |||
| { | |||
| fprintf(paramfp, "None"); | |||
| } | |||
| if (param.type == 1) | |||
| { | |||
| if (param.b) | |||
| fprintf(paramfp, "True"); | |||
| else | |||
| fprintf(paramfp, "False"); | |||
| } | |||
| if (param.type == 2) | |||
| { | |||
| fprintf(paramfp, "%d", param.i); | |||
| } | |||
| if (param.type == 3) | |||
| { | |||
| fprintf(paramfp, "%e", param.f); | |||
| } | |||
| if (param.type == 4) | |||
| { | |||
| fprintf(paramfp, "%s", param.s.c_str()); | |||
| } | |||
| if (param.type == 5) | |||
| { | |||
| fprintf(paramfp, "("); | |||
| for (size_t i = 0; i < param.ai.size(); i++) | |||
| { | |||
| fprintf(paramfp, "%d", param.ai[i]); | |||
| if (i + 1 != param.ai.size()) | |||
| fprintf(paramfp, ","); | |||
| } | |||
| fprintf(paramfp, ")"); | |||
| } | |||
| if (param.type == 6) | |||
| { | |||
| fprintf(paramfp, "("); | |||
| for (size_t i = 0; i < param.af.size(); i++) | |||
| { | |||
| fprintf(paramfp, "%e", param.af[i]); | |||
| if (i + 1 != param.af.size()) | |||
| fprintf(paramfp, ","); | |||
| } | |||
| fprintf(paramfp, ")"); | |||
| } | |||
| if (param.type == 7) | |||
| { | |||
| fprintf(paramfp, "("); | |||
| for (size_t i = 0; i < param.as.size(); i++) | |||
| { | |||
| fprintf(paramfp, "%s", param.as[i].c_str()); | |||
| if (i + 1 != param.as.size()) | |||
| fprintf(paramfp, ","); | |||
| } | |||
| fprintf(paramfp, ")"); | |||
| } | |||
| if (param.type == 10) | |||
| { | |||
| fprintf(paramfp, "%e+%ej", param.c.real(), param.c.imag()); | |||
| } | |||
| if (param.type == 11) | |||
| { | |||
| fprintf(paramfp, "("); | |||
| for (size_t i = 0; i < param.ac.size(); i++) | |||
| { | |||
| fprintf(paramfp, "%e+%ej", param.ac[i].real(), param.ac[i].imag()); | |||
| if (i + 1 != param.ac.size()) | |||
| fprintf(paramfp, ","); | |||
| } | |||
| fprintf(paramfp, ")"); | |||
| } | |||
| std::string s = Parameter::encode_to_string(param); | |||
| fprintf(paramfp, "%s", s.c_str()); | |||
| } | |||
| for (const auto& it : op->attrs) | |||
| @@ -2638,11 +2675,62 @@ int Graph::parse(const std::string& param) | |||
| { | |||
| // attribute | |||
| // load_attribute(op, key.substr(1), value, szr); | |||
| op->attrs[key.substr(1)] = Attribute(); | |||
| Attribute& attr = op->attrs[key.substr(1)]; | |||
| attr.type = 0; | |||
| if (value.empty()) | |||
| continue; | |||
| if (value[0] == '%') | |||
| { | |||
| // @data=%op1.data | |||
| attr.data = std::vector<char>(value.begin(), value.end()); | |||
| } | |||
| if (value[0] == '(') | |||
| { | |||
| // @data=(1,%c,?,4)f32 | |||
| // type | |||
| std::string typestr = value.substr(value.find_last_of(')') + 1); | |||
| attr.type = string_to_type(typestr.c_str()); | |||
| // shape | |||
| std::string lc = value.substr(1, value.find_last_of(')') - 1); | |||
| std::istringstream lcss(lc); | |||
| attr.shape.clear(); | |||
| while (!lcss.eof()) | |||
| { | |||
| std::string elem; | |||
| std::getline(lcss, elem, ','); | |||
| if (elem == "?") | |||
| { | |||
| attr.shape.push_back(-1); | |||
| } | |||
| else if (elem[0] == '%') | |||
| { | |||
| // encode %abc as symbolic tag | |||
| attr.shape.push_back(-233); | |||
| int index = attr.shape.size() - 1; | |||
| std::string key = elem.substr(1); | |||
| attr.params[std::string("__shape_") + std::to_string(index)] = key; | |||
| } | |||
| else | |||
| { | |||
| int i = std::stoi(elem); | |||
| attr.shape.push_back(i); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| else if (key[0] == '$') | |||
| { | |||
| // operand input key | |||
| // load_input_key(op, key.substr(1), value); | |||
| load_input_key(op, key.substr(1), value); | |||
| } | |||
| else if (key[0] == '#') | |||
| { | |||
| @@ -171,6 +171,7 @@ public: | |||
| #endif // BUILD_PNNX | |||
| static Parameter parse_from_string(const std::string& value); | |||
| static std::string encode_to_string(const Parameter& param); | |||
| // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others 10=c 11=ac | |||
| int type; | |||
| @@ -217,6 +218,8 @@ public: | |||
| std::vector<int> shape; | |||
| std::vector<char> data; | |||
| std::map<std::string, Parameter> params; | |||
| }; | |||
| bool operator==(const Attribute& lhs, const Attribute& rhs); | |||
| @@ -246,6 +249,7 @@ private: | |||
| friend class Graph; | |||
| Operand() | |||
| { | |||
| type = 0; | |||
| } | |||
| }; | |||
| @@ -131,7 +131,7 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit | |||
| // sub_mod.dump(true, true, true); | |||
| op->attrs[name] = sub_mod.attr(name).toTensor(); | |||
| op->attrs["data"] = sub_mod.attr(name).toTensor(); | |||
| } | |||
| } | |||
| else if (n->kind() == c10::prim::Constant) // || n->kind() == c10::prim::ListConstruct) | |||
| @@ -165,7 +165,7 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit | |||
| op->params.erase("value"); | |||
| op->attrs[name] = n->t(torch::jit::attr::value); | |||
| op->attrs["data"] = n->t(torch::jit::attr::value); | |||
| } | |||
| } | |||
| else if (n->kind() == c10::prim::CallMethod) | |||
| @@ -24,6 +24,17 @@ GraphRewriterPass::~GraphRewriterPass() | |||
| { | |||
| } | |||
| const char* GraphRewriterPass::replace_pattern_graph() const | |||
| { | |||
| return 0; | |||
| } | |||
| const char* GraphRewriterPass::type_str() const | |||
| { | |||
| fprintf(stderr, "GraphRewriterPass type_str() should be implemented\n"); | |||
| return "unk"; | |||
| } | |||
| const char* GraphRewriterPass::name_str() const | |||
| { | |||
| return type_str(); | |||
| @@ -46,15 +57,96 @@ bool GraphRewriterPass::match(const std::map<std::string, const Operator*>& /*ma | |||
| void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| for (auto x : captured_params) | |||
| if (replace_pattern_graph() == 0) | |||
| { | |||
| for (auto x : captured_params) | |||
| { | |||
| op->params[x.first] = x.second; | |||
| } | |||
| return; | |||
| } | |||
| for (auto x : op->params) | |||
| { | |||
| op->params[x.first] = x.second; | |||
| if (x.second.type != 4) | |||
| continue; | |||
| std::string str = x.second.s; | |||
| if (str.find('%') == std::string::npos) | |||
| continue; | |||
| // search % token and replace with captured | |||
| size_t pos = str.find('%'); | |||
| while (pos != std::string::npos) | |||
| { | |||
| // %xyz | |||
| char buf[256]; | |||
| sscanf(str.c_str() + pos + 1, "%255[^][,() ]", buf); | |||
| std::string key(buf); | |||
| if (captured_params.find(key) == captured_params.end()) | |||
| { | |||
| fprintf(stderr, "replace pattern param %%%s missing captured\n", key.c_str()); | |||
| return; | |||
| } | |||
| // replace %xyz with encoded_str | |||
| std::string encoded_str = Parameter::encode_to_string(captured_params.at(key)); | |||
| str.replace(pos, key.size() + 1, encoded_str); | |||
| pos = str.find('%', pos + 1); | |||
| } | |||
| op->params[x.first] = Parameter::parse_from_string(str); | |||
| } | |||
| } | |||
| void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const | |||
| void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| write(op, captured_params); | |||
| for (auto x : op->attrs) | |||
| { | |||
| if (x.second.type != 0) | |||
| continue; | |||
| std::string key((const char*)x.second.data.data()); | |||
| if (key.empty()) | |||
| continue; | |||
| op->attrs[x.first] = captured_attrs.at(key); | |||
| } | |||
| } | |||
| void GraphRewriterPass::write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| for (auto x : ops) | |||
| { | |||
| Operator* op = x.second; | |||
| write(op, captured_params); | |||
| } | |||
| } | |||
| void GraphRewriterPass::write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| write(ops, captured_params); | |||
| for (auto x : ops) | |||
| { | |||
| Operator* op = x.second; | |||
| for (auto x : op->attrs) | |||
| { | |||
| if (x.second.type != 0) | |||
| continue; | |||
| std::string key(x.second.data.begin(), x.second.data.end()); | |||
| if (key.empty() || key[0] != '%') | |||
| continue; | |||
| op->attrs[x.first] = captured_attrs.at(key.substr(1)); | |||
| } | |||
| } | |||
| } | |||
| static std::map<int, std::vector<const GraphRewriterPass*> > g_global_pnnx_graph_rewriter_passes; | |||
| @@ -75,6 +167,126 @@ GraphRewriterPassRegister::~GraphRewriterPassRegister() | |||
| delete pass; | |||
| } | |||
// Tell whether an expression token names an operand argument, i.e. it is
// '@' followed by one or more decimal digits (for example "@0" or "@12").
static bool token_is_argument(const std::string& t)
{
    // needs the '@' marker plus at least one digit
    if (t.size() < 2 || t[0] != '@')
        return false;

    // everything after the marker must be a digit
    return t.find_first_not_of("0123456789", 1) == std::string::npos;
}
| static bool match_expression(const Operator* a, const Operator* b, std::map<std::string, Parameter>& captured_params) | |||
| { | |||
| if (a->params.size() != 1 || a->params.find("expr") == a->params.end()) | |||
| return false; | |||
| if (b->params.size() != 1 || b->params.find("expr") == b->params.end()) | |||
| return false; | |||
| const std::string& a_expr = a->params.at("expr").s; | |||
| const std::string& b_expr = b->params.at("expr").s; | |||
| if (a_expr == b_expr) | |||
| return true; | |||
| // split into tokens | |||
| std::vector<std::string> a_tokens; | |||
| std::vector<std::string> b_tokens; | |||
| { | |||
| std::string t; | |||
| for (size_t i = 0; i < a_expr.size(); i++) | |||
| { | |||
| char ch = a_expr[i]; | |||
| if (ch == '[') // list | |||
| { | |||
| t += ch; | |||
| a_tokens.push_back(t); | |||
| t.clear(); | |||
| } | |||
| else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') | |||
| { | |||
| if (!t.empty()) | |||
| { | |||
| a_tokens.push_back(t); | |||
| t.clear(); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| t += ch; | |||
| } | |||
| } | |||
| if (!t.empty()) | |||
| { | |||
| a_tokens.push_back(t); | |||
| } | |||
| } | |||
| { | |||
| std::string t; | |||
| for (size_t i = 0; i < b_expr.size(); i++) | |||
| { | |||
| char ch = b_expr[i]; | |||
| if (ch == '[') // list | |||
| { | |||
| t += ch; | |||
| b_tokens.push_back(t); | |||
| t.clear(); | |||
| } | |||
| else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') | |||
| { | |||
| if (!t.empty()) | |||
| { | |||
| b_tokens.push_back(t); | |||
| t.clear(); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| t += ch; | |||
| } | |||
| } | |||
| if (!t.empty()) | |||
| { | |||
| b_tokens.push_back(t); | |||
| } | |||
| } | |||
| if (a_tokens.size() != b_tokens.size()) | |||
| return false; | |||
| // capture values | |||
| for (size_t i = 0; i < a_tokens.size(); i++) | |||
| { | |||
| const std::string& at = a_tokens[i]; | |||
| const std::string& bt = b_tokens[i]; | |||
| if (at == bt) | |||
| continue; | |||
| if (bt[0] != '%') | |||
| return false; | |||
| if (token_is_argument(at)) | |||
| return false; | |||
| std::string key = bt.substr(1); | |||
| captured_params[key] = Parameter::parse_from_string(at); | |||
| } | |||
| return true; | |||
| } | |||
| static bool match_parameter(const Parameter& a, const Parameter& b, std::map<std::string, Parameter>& captured_params) | |||
| { | |||
| if (b.type == 4 && b.s[0] == '%') | |||
| @@ -97,6 +309,77 @@ static bool match_parameter(const Parameter& a, const Parameter& b, std::map<std | |||
| return true; | |||
| } | |||
| if (b.type == 4 && (b.s[0] == '(' || b.s[0] == '[') && b.s.find('%') != std::string::npos) | |||
| { | |||
| // list with pattern | |||
| if (a.type != 5 && a.type != 6 && a.type != 7) | |||
| return false; | |||
| std::string lc = b.s.substr(1, b.s.size() - 2); | |||
| std::istringstream lcss(lc); | |||
| size_t i = 0; | |||
| while (!lcss.eof()) | |||
| { | |||
| std::string elem; | |||
| std::getline(lcss, elem, ','); | |||
| if (elem[0] == '%') | |||
| { | |||
| std::string key = elem.substr(1); | |||
| if (captured_params.find(key) != captured_params.end()) | |||
| { | |||
| // match previous captured parameter | |||
| if (a.type == 5 && captured_params.at(key).i != a.ai[i]) | |||
| return false; | |||
| if (a.type == 6 && captured_params.at(key).f != a.af[i]) | |||
| return false; | |||
| if (a.type == 7 && captured_params.at(key).s != a.as[i]) | |||
| return false; | |||
| } | |||
| // captured parameter | |||
| if (a.type == 5) | |||
| captured_params[key] = a.ai[i]; | |||
| if (a.type == 6) | |||
| captured_params[key] = a.af[i]; | |||
| if (a.type == 7) | |||
| captured_params[key] = a.as[i]; | |||
| } | |||
| else if ((elem[0] != '-' && (elem[0] < '0' || elem[0] > '9')) || (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9'))) | |||
| { | |||
| // string | |||
| if (a.type != 7) | |||
| return false; | |||
| if (a.as[i] != elem) | |||
| return false; | |||
| } | |||
| else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos) | |||
| { | |||
| // float | |||
| if (a.type != 6) | |||
| return false; | |||
| if (a.af[i] != std::stof(elem)) | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| // integer | |||
| if (a.type != 5) | |||
| return false; | |||
| if (a.ai[i] != std::stoi(elem)) | |||
| return false; | |||
| } | |||
| i++; | |||
| } | |||
| return true; | |||
| } | |||
| if (a.type != b.type) | |||
| { | |||
| if (a.type == 2 && b.type == 3) | |||
| @@ -174,6 +457,76 @@ static bool match_parameter(const Parameter& a, const Parameter& b, std::map<std | |||
| return false; | |||
| } | |||
| static bool match_attribute(const Attribute& a, const Attribute& b, std::map<std::string, Parameter>& captured_params, const std::string& attrname, std::map<std::string, Attribute>& captured_attrs) | |||
| { | |||
| // @data | |||
| // @data=(1,2,3,4)f32 | |||
| // @data=%op1.data | |||
| if (b.type == 0) | |||
| { | |||
| std::string bs(b.data.begin(), b.data.end()); | |||
| if (bs.empty()) | |||
| { | |||
| // capture any shape | |||
| captured_attrs[attrname] = a; | |||
| return true; | |||
| } | |||
| if (bs[0] == '%') | |||
| { | |||
| // the captured replace | |||
| return true; | |||
| } | |||
| fprintf(stderr, "malformed attribute pattern %s\n", bs.c_str()); | |||
| return false; | |||
| } | |||
| const std::vector<int>& a_shape = a.shape; | |||
| const std::vector<int>& b_shape = b.shape; | |||
| if (b_shape.empty()) | |||
| return false; | |||
| if (a_shape.empty()) | |||
| return false; | |||
| if (a_shape.size() != b_shape.size()) | |||
| return false; | |||
| for (size_t j = 0; j < a_shape.size(); j++) | |||
| { | |||
| int ai = a_shape[j]; | |||
| int bi = b_shape[j]; | |||
| if (ai == bi) | |||
| continue; | |||
| if (bi == -1) | |||
| continue; | |||
| if (bi > 0) | |||
| return false; | |||
| if (bi != -233) | |||
| return false; | |||
| std::string key = b.params.at(std::string("__shape_") + std::to_string(j)).s; | |||
| if (captured_params.find(key) != captured_params.end()) | |||
| { | |||
| // match previous captured parameter | |||
| if (captured_params.at(key).i != ai) | |||
| return false; | |||
| } | |||
| // captured parameter | |||
| captured_params[key] = ai; | |||
| } | |||
| captured_attrs[attrname] = a; | |||
| return true; | |||
| } | |||
| static bool match_operator(const Operator* a, const Operator* b, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs) | |||
| { | |||
| if (a->type != b->type) | |||
| @@ -197,6 +550,11 @@ static bool match_operator(const Operator* a, const Operator* b, std::map<std::s | |||
| captured_params[b->name + '.' + pkey] = pp; | |||
| } | |||
| } | |||
| else if (a->type == "pnnx.Expression") | |||
| { | |||
| if (!match_expression(a, b, captured_params)) | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| if (a->params.size() != b->params.size()) | |||
| @@ -215,13 +573,120 @@ static bool match_operator(const Operator* a, const Operator* b, std::map<std::s | |||
| } | |||
| } | |||
| // match shapes | |||
| for (size_t i = 0; i < a->inputs.size(); i++) | |||
| { | |||
| int a_type = a->inputs[i]->type; | |||
| int b_type = b->inputs[i]->type; | |||
| if (b_type != 0 && a_type != b_type) | |||
| return false; | |||
| const std::vector<int>& a_shape = a->inputs[i]->shape; | |||
| const std::vector<int>& b_shape = b->inputs[i]->shape; | |||
| if (b_shape.empty()) | |||
| continue; | |||
| if (a_shape.empty()) | |||
| return false; | |||
| if (a_shape.size() != b_shape.size()) | |||
| return false; | |||
| for (size_t j = 0; j < a_shape.size(); j++) | |||
| { | |||
| int ai = a_shape[j]; | |||
| int bi = b_shape[j]; | |||
| if (ai == bi) | |||
| continue; | |||
| if (bi == -1) | |||
| continue; | |||
| if (bi > 0) | |||
| return false; | |||
| if (bi != -233) | |||
| return false; | |||
| std::string key = b->inputs[i]->params.at(std::string("__shape_") + std::to_string(j)).s; | |||
| if (captured_params.find(key) != captured_params.end()) | |||
| { | |||
| // match previous captured parameter | |||
| if (captured_params.at(key).i != ai) | |||
| return false; | |||
| } | |||
| // captured parameter | |||
| captured_params[key] = ai; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < a->outputs.size(); i++) | |||
| { | |||
| int a_type = a->outputs[i]->type; | |||
| int b_type = b->outputs[i]->type; | |||
| if (b_type != 0 && a_type != b_type) | |||
| return false; | |||
| const std::vector<int>& a_shape = a->outputs[i]->shape; | |||
| const std::vector<int>& b_shape = b->outputs[i]->shape; | |||
| if (b_shape.empty()) | |||
| continue; | |||
| if (a_shape.empty()) | |||
| return false; | |||
| if (a_shape.size() != b_shape.size()) | |||
| return false; | |||
| for (size_t j = 0; j < a_shape.size(); j++) | |||
| { | |||
| int ai = a_shape[j]; | |||
| int bi = b_shape[j]; | |||
| if (ai == bi) | |||
| continue; | |||
| if (bi == -1) | |||
| continue; | |||
| if (bi > 0) | |||
| return false; | |||
| if (bi != -233) | |||
| return false; | |||
| std::string key = b->outputs[i]->params.at(std::string("__shape_") + std::to_string(j)).s; | |||
| if (captured_params.find(key) != captured_params.end()) | |||
| { | |||
| // match previous captured parameter | |||
| if (captured_params.at(key).i != ai) | |||
| return false; | |||
| } | |||
| // captured parameter | |||
| captured_params[key] = ai; | |||
| } | |||
| } | |||
| for (const auto& p : a->attrs) | |||
| { | |||
| const std::string& akey = p.first; | |||
| const Attribute& aa = p.second; | |||
| // capture all attributes | |||
| captured_attrs[b->name + '.' + akey] = aa; | |||
| std::string attrname = b->name + '.' + akey; | |||
| if (b->attrs.find(akey) == b->attrs.end()) | |||
| { | |||
| // capture all attributes | |||
| captured_attrs[attrname] = aa; | |||
| } | |||
| else | |||
| { | |||
| if (!match_attribute(aa, b->attrs.at(akey), captured_params, attrname, captured_attrs)) | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| @@ -484,27 +949,111 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde | |||
| cur = graph.ops[cur_index]; | |||
| } | |||
| Operator* op = graph.new_operator_before(pass->type_str(), std::string(pass->name_str()), cur); | |||
| for (const auto& k : pattern_graph_inputs) | |||
| if (pass->replace_pattern_graph() == 0) | |||
| { | |||
| Operand* r = (Operand*)matched_inputs.at(k); | |||
| r->consumers.push_back(op); | |||
| op->inputs.push_back(r); | |||
| // insert single | |||
| Operator* op = graph.new_operator_before(pass->type_str(), std::string(pass->name_str()), cur); | |||
| op->inputnames.push_back(k); | |||
| } | |||
| for (const auto& k : pattern_graph_inputs) | |||
| { | |||
| Operand* r = (Operand*)matched_inputs.at(k); | |||
| r->consumers.push_back(op); | |||
| op->inputs.push_back(r); | |||
| for (const auto& k : pattern_graph_outputs) | |||
| { | |||
| Operand* r = (Operand*)matched_outputs.at(k); | |||
| r->producer = op; | |||
| op->outputs.push_back(r); | |||
| op->inputnames.push_back(k); | |||
| } | |||
| for (const auto& k : pattern_graph_outputs) | |||
| { | |||
| Operand* r = (Operand*)matched_outputs.at(k); | |||
| r->producer = op; | |||
| op->outputs.push_back(r); | |||
| } | |||
| pass->write(op, captured_params, captured_attrs); | |||
| new_ops.push_back(op); | |||
| } | |||
| else | |||
| { | |||
| // insert multiple | |||
| Graph replace_graph; | |||
| replace_graph.parse(pass->replace_pattern_graph()); | |||
| // move operators and operands from replace_graph to graph except input and output | |||
| std::map<std::string, Operator*> ops; | |||
| for (size_t i = 0; i < replace_graph.ops.size(); i++) | |||
| { | |||
| Operator* op = replace_graph.ops[i]; | |||
| if (op->type == "pnnx.Input" || op->type == "pnnx.Output") | |||
| continue; | |||
| graph.ops.insert(std::find(graph.ops.begin(), graph.ops.end(), cur), op); | |||
| replace_graph.ops[i] = 0; | |||
| ops[op->name] = op; | |||
| } | |||
| for (size_t i = 0; i < replace_graph.operands.size(); i++) | |||
| { | |||
| Operand* r = replace_graph.operands[i]; | |||
| if (r->producer->type == "pnnx.Input" || (r->consumers.size() == 1 && r->consumers[0]->type == "pnnx.Output")) | |||
| continue; | |||
| graph.operands.push_back(r); | |||
| replace_graph.operands[i] = 0; | |||
| } | |||
| replace_graph.ops.erase(std::remove(replace_graph.ops.begin(), replace_graph.ops.end(), (Operator*)0), replace_graph.ops.end()); | |||
| replace_graph.operands.erase(std::remove(replace_graph.operands.begin(), replace_graph.operands.end(), (Operand*)0), replace_graph.operands.end()); | |||
| pass->write(op, captured_params, captured_attrs); | |||
| for (size_t i = 0; i < pattern_graph_inputs.size(); i++) | |||
| { | |||
| const std::string& k = pattern_graph_inputs[i]; | |||
| Operand* r = (Operand*)matched_inputs.at(k); | |||
| const Operand* rr = replace_graph.get_operand(k); | |||
| for (auto x : rr->consumers) | |||
| { | |||
| r->consumers.push_back(x); | |||
| new_ops.push_back(op); | |||
| x->inputnames.resize(x->inputs.size()); | |||
| for (size_t j = 0; j < x->inputs.size(); j++) | |||
| { | |||
| if (x->inputs[j]->name == k) | |||
| { | |||
| x->inputs[j] = r; | |||
| x->inputnames[j] = k; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| for (size_t i = 0; i < pattern_graph_outputs.size(); i++) | |||
| { | |||
| const std::string& k = pattern_graph_outputs[i]; | |||
| Operand* r = (Operand*)matched_outputs.at(k); | |||
| const Operand* rr = replace_graph.get_operand(k); | |||
| r->producer = rr->producer; | |||
| for (size_t j = 0; j < r->producer->outputs.size(); j++) | |||
| { | |||
| if (r->producer->outputs[j]->name == k) | |||
| { | |||
| r->producer->outputs[j] = r; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| pass->write(ops, captured_params, captured_attrs); | |||
| for (auto x : ops) | |||
| { | |||
| new_ops.push_back(x.second); | |||
| } | |||
| } | |||
| } | |||
| // assign new op name number | |||
| @@ -26,7 +26,9 @@ public: | |||
| virtual const char* match_pattern_graph() const = 0; | |||
| virtual const char* type_str() const = 0; | |||
| virtual const char* replace_pattern_graph() const; | |||
| virtual const char* type_str() const; | |||
| virtual const char* name_str() const; | |||
| @@ -39,6 +41,10 @@ public: | |||
| virtual void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const; | |||
| virtual void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const; | |||
| virtual void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params) const; | |||
| virtual void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const; | |||
| }; | |||
| class GraphRewriterPassRegister | |||
| @@ -0,0 +1,68 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| #include <torch/csrc/api/include/torch/torch.h> | |||
| namespace pnnx { | |||
// Graph rewriter pass for aten::repeat_interleave called with four inputs:
// input, repeats, dim and an output_size produced by a prim::Constant whose
// value is unconstrained (value=*). The pattern covers both the constant and
// the aten node, and the rewritten operator type is "torch.repeat_interleave".
| class torch_repeat_interleave : public GraphRewriterPass | |||
| { | |||
| public: | |||
// Pattern text: 6 operators and 5 operands (input, repeats, dim,
// output_size, out).
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 6 5 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 repeats | |||
| pnnx.Input input_2 0 1 dim | |||
| prim::Constant op_0 0 1 output_size value=* | |||
| aten::repeat_interleave op_1 4 1 input repeats dim output_size out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
// Operator type given to the replacement operator.
| const char* type_str() const | |||
| { | |||
| return "torch.repeat_interleave"; | |||
| } | |||
| }; | |||
// NOTE(review): the second macro argument (20) is presumably the pass
// ordering/priority; confirm against the macro definition.
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_repeat_interleave, 20) | |||
// Graph rewriter pass for the three-input form of aten::repeat_interleave
// (input, repeats, dim; no output_size operand). The rewritten operator type
// is "torch.repeat_interleave", the same as the four-input variant.
| class torch_repeat_interleave_1 : public GraphRewriterPass | |||
| { | |||
| public: | |||
// Pattern text: 5 operators and 4 operands (input, repeats, dim, out).
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 repeats | |||
| pnnx.Input input_2 0 1 dim | |||
| aten::repeat_interleave op_0 3 1 input repeats dim out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
// Operator type given to the replacement operator.
| const char* type_str() const | |||
| { | |||
| return "torch.repeat_interleave"; | |||
| } | |||
| }; | |||
// NOTE(review): the second macro argument (20) is presumably the pass
// ordering/priority; confirm against the macro definition.
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_repeat_interleave_1, 20) | |||
| } // namespace pnnx | |||
| @@ -35,10 +35,12 @@ | |||
| #include "pass_level5/fuse_convtranspose1d_batchnorm1d.h" | |||
| #include "pass_level5/fuse_convtranspose2d_batchnorm2d.h" | |||
| #include "pass_level5/fuse_contiguous_view.h" | |||
| #include "pass_level5/fuse_layernorm.h" | |||
| #include "pass_level5/fuse_linear_batchnorm1d.h" | |||
| #include "pass_level5/fuse_multiheadattention.h" | |||
| #include "pass_level5/fuse_pad_conv1d.h" | |||
| #include "pass_level5/fuse_pad_conv2d.h" | |||
| #include "pass_level5/fuse_scaled_dot_product_attention.h" | |||
| #include "pass_level5/fuse_select_to_unbind.h" | |||
| #include "pass_level5/fuse_slice_copy.h" | |||
| #include "pass_level5/fuse_slice_indices.h" | |||
| @@ -124,7 +126,9 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons | |||
| eliminate_reshape_shape_expression(g); | |||
| fuse_channel_shuffle(g); | |||
| fuse_layernorm(g); | |||
| fuse_multiheadattention(g); | |||
| fuse_scaled_dot_product_attention(g); | |||
| fuse_index_expression(g); | |||
| @@ -43,9 +43,9 @@ void fold_constants(Graph& graph, const std::set<std::string>& foldable_constant | |||
| // replace producer with attribute | |||
| Operator* op_new = graph.new_operator_before("pnnx.Attribute", std::string("pnnx_fold_") + name, op); | |||
| op_new->attrs[std::string("pnnx_fold_") + name] = Attribute(); | |||
| op_new->attrs["data"] = Attribute(); | |||
| Attribute& t2 = op_new->attrs[std::string("pnnx_fold_") + name]; | |||
| Attribute& t2 = op_new->attrs["data"]; | |||
| t2.type = operand->type; | |||
| t2.shape = operand->shape; | |||
| size_t size = zip.get_file_size(name); | |||
| @@ -26,9 +26,9 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 7 6 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Expression op_0 1 1 input 13 expr=%expr | |||
| pnnx.Expression op_0 1 1 input 13 expr=[int(size(@0,0)),2,int(floor_divide(size(@0,1),%groups)),int(size(@0,2)),int(size(@0,3))] | |||
| Tensor.view op_1 2 1 input 13 14 | |||
| pnnx.Expression op_2 1 1 input 15 expr=%expr2 | |||
| pnnx.Expression op_2 1 1 input 15 expr=[int(size(@0,0)),-1,int(size(@0,2)),int(size(@0,3))] | |||
| torch.transpose op_3 1 1 14 16 dim0=1 dim1=2 | |||
| Tensor.reshape op_4 2 1 16 15 out | |||
| pnnx.Output output 1 0 out | |||
| @@ -44,32 +44,6 @@ pnnx.Output output 1 0 out | |||
| { | |||
| return "channelshuffle"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const std::string& expr = captured_params.at("expr").s; | |||
| const std::string& expr2 = captured_params.at("expr2").s; | |||
| if (expr2 != "[int(size(@0,0)),-1,int(size(@0,2)),int(size(@0,3))]") | |||
| return false; | |||
| int groups; | |||
| int nscan = sscanf(expr.c_str(), "[int(size(@0,0)),2,int(floor_divide(size(@0,1),%d)),int(size(@0,2)),int(size(@0,3))]", &groups); | |||
| if (nscan != 1) | |||
| return false; | |||
| return true; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const std::string& expr = captured_params.at("expr").s; | |||
| int groups; | |||
| sscanf(expr.c_str(), "[int(size(@0,0)),2,int(floor_divide(size(@0,1),%d)),int(size(@0,2)),int(size(@0,3))]", &groups); | |||
| op->params["groups"] = groups; | |||
| } | |||
| }; | |||
| class fuse_channel_shuffle_pass_1 : public GraphRewriterPass | |||
| @@ -80,9 +54,9 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| Tensor.view op_0 1 1 input 13 shape=%shape | |||
| Tensor.view op_0 1 1 input 13 shape=(%batch,%groups,%channels_per_group,%h,%w) | |||
| torch.transpose op_1 1 1 13 14 dim0=1 dim1=2 | |||
| Tensor.reshape op_2 1 1 14 out shape=%shape2 | |||
| Tensor.reshape op_2 1 1 14 out shape=(%batch,-1,%h,%w) | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| @@ -97,26 +71,9 @@ pnnx.Output output 1 0 out | |||
| return "channelshuffle"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| // (1,2,58,28,28) | |||
| // (1,-1,28,28) | |||
| const std::vector<int>& shape = captured_params.at("shape").ai; | |||
| const std::vector<int>& shape2 = captured_params.at("shape2").ai; | |||
| if (shape[0] != 1 || shape2[0] != 1 || shape2[1] != -1 || shape2[2] != shape[3] || shape2[3] != shape[4]) | |||
| return false; | |||
| return true; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const std::vector<int>& shape = captured_params.at("shape").ai; | |||
| int groups = shape[1]; | |||
| op->params["groups"] = groups; | |||
| op->params["groups"] = captured_params.at("groups"); | |||
| } | |||
| }; | |||
| @@ -0,0 +1,87 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fuse_layernorm.h" | |||
| #include "pass_level2.h" | |||
| #include <math.h> | |||
| #include <string.h> | |||
| #include <torch/csrc/api/include/torch/torch.h> | |||
| namespace pnnx { | |||
| class fuse_layernorm_pass : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 8 7 | |||
| pnnx.Input input 0 1 input #input=(1,%c,?,?)f32 | |||
| pnnx.Attribute op_0 0 1 weight @data #weight=(%c,1,1)f32 | |||
| pnnx.Attribute op_1 0 1 bias @data #bias=(%c,1,1)f32 | |||
| torch.mean op_2 1 1 input mean dim=(1) keepdim=True | |||
| pnnx.Expression op_3 2 1 input mean 2 expr=pow(sub(@0,@1),2) | |||
| torch.mean op_4 1 1 2 var dim=(1) keepdim=True | |||
| pnnx.Expression op_5 5 1 weight input mean var bias out expr=add(mul(@0,div(sub(@1,@2),sqrt(add(@3,%eps)))),@4) | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| #if TORCH_VERSION_MAJOR >= 2 || TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 9 | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| torch.permute op_0 1 1 input a dims=(0,2,3,1) | |||
| nn.LayerNorm op_1 1 1 a b elementwise_affine=True eps=%eps normalized_shape=(%c) @weight=%op_0.data @bias=%op_1.data | |||
| torch.permute op_2 1 1 b out dims=(0,3,1,2) | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| #else | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| Tensor.permute op_0 1 1 input a dims=(0,2,3,1) | |||
| nn.LayerNorm op_1 1 1 a b elementwise_affine=True eps=%eps normalized_shape=(%c) @weight=%op_0.data @bias=%op_1.data | |||
| Tensor.permute op_2 1 1 b out dims=(0,3,1,2) | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| #endif | |||
| } | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| // fix weight bias shape from (c,1,1) to (c) | |||
| const int c = captured_params.at("c").i; | |||
| Operator* op_1 = ops.at("op_1"); | |||
| op_1->attrs["weight"].shape = {c}; | |||
| op_1->attrs["bias"].shape = {c}; | |||
| } | |||
| }; | |||
| void fuse_layernorm(Graph& graph) | |||
| { | |||
| fuse_layernorm_pass a; | |||
| int opindex = 0; | |||
| pnnx_graph_rewrite(graph, &a, opindex); | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,21 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| void fuse_layernorm(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,92 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fuse_scaled_dot_product_attention.h" | |||
| #include "pass_level2.h" | |||
| #include <math.h> | |||
| #include <string.h> | |||
| #include <torch/csrc/api/include/torch/torch.h> | |||
| namespace pnnx { | |||
// Approximate float comparison.
//
// Returns true when a and b are exactly equal, when their absolute difference
// is at most epsilon, or when that difference is strictly smaller than epsilon
// scaled by the larger of the two magnitudes (relative error).
// Any comparison involving NaN returns false.
//
// Only <math.h> is required: the float overloads are used throughout, so no
// float -> double promotion takes place and <algorithm> is not needed.
static bool NearlyEqual(float a, float b, float epsilon)
{
    // exact match, which also covers two infinities of the same sign
    if (a == b)
        return true;

    const float diff = fabsf(a - b);

    // absolute error
    if (diff <= epsilon)
        return true;

    // relative error, scaled by the larger magnitude
    const float abs_a = fabsf(a);
    const float abs_b = fabsf(b);
    const float largest = abs_a > abs_b ? abs_a : abs_b;
    return diff < epsilon * largest;
}
| class fuse_scaled_dot_product_attention_pass : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 9 8 | |||
| pnnx.Input input_0 0 1 query #query=(%batch,%num_heads,%qsize,%feat_per_head)f32 | |||
| pnnx.Input input_1 0 1 key #key=(%batch,%num_heads,%kvsize,%feat_per_head)f32 | |||
| pnnx.Input input_2 0 1 value #value=(%batch,%num_heads,%kvsize,%feat_per_head)f32 | |||
| torch.permute op_0 1 1 key 59 dims=(0,1,3,2) | |||
| torch.matmul op_1 2 1 query 59 61 | |||
| pnnx.Expression op_2 1 1 61 62 expr=div(@0,%sqrt_embed_dim_per_head) | |||
| F.softmax op_3 1 1 62 63 dim=-1 | |||
| torch.matmul op_4 2 1 63 value out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 query | |||
| pnnx.Input input_1 0 1 key | |||
| pnnx.Input input_2 0 1 value | |||
| F.scaled_dot_product_attention op_0 3 1 query key value out attn_mask=None dropout_p=0.0 is_causal=False | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const int feat_per_head = captured_params.at("feat_per_head").i; | |||
| const float sqrt_embed_dim_per_head = captured_params.at("sqrt_embed_dim_per_head").f; | |||
| if (!NearlyEqual(sqrt_embed_dim_per_head, sqrt(feat_per_head), 0.001)) | |||
| return false; | |||
| return true; | |||
| } | |||
| }; | |||
| void fuse_scaled_dot_product_attention(Graph& graph) | |||
| { | |||
| #if TORCH_VERSION_MAJOR >= 2 | |||
| fuse_scaled_dot_product_attention_pass a; | |||
| int opindex = 0; | |||
| pnnx_graph_rewrite(graph, &a, opindex); | |||
| #endif | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,21 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| void fuse_scaled_dot_product_attention(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -29,21 +29,21 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_mean 0 1 running_mean @qwq | |||
| pnnx.Attribute op_var 0 1 running_var @qwq | |||
| pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 | |||
| pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 | |||
| F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.BatchNorm1d"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "batchnorm"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.BatchNorm1d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=False @running_mean=%op_mean.data @running_var=%op_var.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| bool match(const std::map<std::string, const Operator*>& matched_operators) const | |||
| @@ -51,26 +51,6 @@ pnnx.Output output 1 0 out | |||
| size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); | |||
| return input_rank == 2 || input_rank == 3; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute running_mean; | |||
| Attribute running_var; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 8) == "op_mean.") | |||
| running_mean = x.second; | |||
| if (x.first.substr(0, 7) == "op_var.") | |||
| running_var = x.second; | |||
| } | |||
| op->params["num_features"] = running_mean.shape[0]; | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["affine"] = false; | |||
| op->attrs["running_mean"] = running_mean; | |||
| op->attrs["running_var"] = running_var; | |||
| } | |||
| }; | |||
| class fuse_static_Fbatchnorm_pass_1d_1 : public GraphRewriterPass | |||
| @@ -81,23 +61,23 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 7 6 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_mean 0 1 running_mean @qwq | |||
| pnnx.Attribute op_var 0 1 running_var @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 | |||
| pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 | |||
| pnnx.Attribute op_weight 0 1 weight @data | |||
| pnnx.Attribute op_bias 0 1 bias @data | |||
| F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.BatchNorm1d"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "batchnorm"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.BatchNorm1d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=True @running_mean=%op_mean.data @running_var=%op_var.data @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| bool match(const std::map<std::string, const Operator*>& matched_operators) const | |||
| @@ -105,34 +85,6 @@ pnnx.Output output 1 0 out | |||
| size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); | |||
| return input_rank == 2 || input_rank == 3; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute running_mean; | |||
| Attribute running_var; | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 8) == "op_mean.") | |||
| running_mean = x.second; | |||
| if (x.first.substr(0, 7) == "op_var.") | |||
| running_var = x.second; | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["num_features"] = running_mean.shape[0]; | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["affine"] = true; | |||
| op->attrs["running_mean"] = running_mean; | |||
| op->attrs["running_var"] = running_var; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| } | |||
| }; | |||
| class fuse_static_Fbatchnorm_pass_2d : public GraphRewriterPass | |||
| @@ -143,47 +95,21 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_mean 0 1 running_mean @qwq | |||
| pnnx.Attribute op_var 0 1 running_var @qwq | |||
| F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps | |||
| pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 | |||
| pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 | |||
| F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps #input=(?,?,?,?)f32 | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.BatchNorm2d"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "batchnorm"; | |||
| } | |||
| bool match(const std::map<std::string, const Operator*>& matched_operators) const | |||
| { | |||
| size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); | |||
| return input_rank == 4; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| Attribute running_mean; | |||
| Attribute running_var; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 8) == "op_mean.") | |||
| running_mean = x.second; | |||
| if (x.first.substr(0, 7) == "op_var.") | |||
| running_var = x.second; | |||
| } | |||
| op->params["num_features"] = running_mean.shape[0]; | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["affine"] = false; | |||
| op->attrs["running_mean"] = running_mean; | |||
| op->attrs["running_var"] = running_var; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.BatchNorm2d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=False @running_mean=%op_mean.data @running_var=%op_var.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| @@ -195,57 +121,23 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 7 6 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_mean 0 1 running_mean @qwq | |||
| pnnx.Attribute op_var 0 1 running_var @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps | |||
| pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 | |||
| pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 | |||
| pnnx.Attribute op_weight 0 1 weight @data | |||
| pnnx.Attribute op_bias 0 1 bias @data | |||
| F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps #input=(?,?,?,?)f32 | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.BatchNorm2d"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "batchnorm"; | |||
| } | |||
| bool match(const std::map<std::string, const Operator*>& matched_operators) const | |||
| { | |||
| size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); | |||
| return input_rank == 4; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| Attribute running_mean; | |||
| Attribute running_var; | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 8) == "op_mean.") | |||
| running_mean = x.second; | |||
| if (x.first.substr(0, 7) == "op_var.") | |||
| running_var = x.second; | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["num_features"] = running_mean.shape[0]; | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["affine"] = true; | |||
| op->attrs["running_mean"] = running_mean; | |||
| op->attrs["running_var"] = running_var; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.BatchNorm2d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=True @running_mean=%op_mean.data @running_var=%op_var.data @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| @@ -257,47 +149,21 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_mean 0 1 running_mean @qwq | |||
| pnnx.Attribute op_var 0 1 running_var @qwq | |||
| F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps | |||
| pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 | |||
| pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 | |||
| F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps #input=(?,?,?,?,?)f32 | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.BatchNorm3d"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "batchnorm"; | |||
| } | |||
| bool match(const std::map<std::string, const Operator*>& matched_operators) const | |||
| { | |||
| size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); | |||
| return input_rank == 5; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| Attribute running_mean; | |||
| Attribute running_var; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 8) == "op_mean.") | |||
| running_mean = x.second; | |||
| if (x.first.substr(0, 7) == "op_var.") | |||
| running_var = x.second; | |||
| } | |||
| op->params["num_features"] = running_mean.shape[0]; | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["affine"] = false; | |||
| op->attrs["running_mean"] = running_mean; | |||
| op->attrs["running_var"] = running_var; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.BatchNorm3d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=False @running_mean=%op_mean.data @running_var=%op_var.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| @@ -309,57 +175,23 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 7 6 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_mean 0 1 running_mean @qwq | |||
| pnnx.Attribute op_var 0 1 running_var @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps | |||
| pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32 | |||
| pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32 | |||
| pnnx.Attribute op_weight 0 1 weight @data | |||
| pnnx.Attribute op_bias 0 1 bias @data | |||
| F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps #input=(?,?,?,?,?)f32 | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.BatchNorm3d"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "batchnorm"; | |||
| } | |||
| bool match(const std::map<std::string, const Operator*>& matched_operators) const | |||
| { | |||
| size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); | |||
| return input_rank == 5; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| Attribute running_mean; | |||
| Attribute running_var; | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 8) == "op_mean.") | |||
| running_mean = x.second; | |||
| if (x.first.substr(0, 7) == "op_var.") | |||
| running_var = x.second; | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["num_features"] = running_mean.shape[0]; | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["affine"] = true; | |||
| op->attrs["running_mean"] = running_mean; | |||
| op->attrs["running_var"] = running_var; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.BatchNorm3d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=True @running_mean=%op_mean.data @running_var=%op_var.data @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| @@ -29,42 +29,30 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kw)f32 | |||
| F.conv1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Conv1d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Conv1d conv1d 1 1 input out out_channels=%out_channels kernel_size=(%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* name_str() const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| return "conv1d"; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| } | |||
| op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; | |||
| op->params["out_channels"] = weight.shape[0]; | |||
| op->params["kernel_size"] = std::vector<int>{weight.shape[2]}; | |||
| op->params["padding_mode"] = std::string("zeros"); | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["groups"] = captured_params.at("groups"); | |||
| op->params["bias"] = false; | |||
| op->attrs["weight"] = weight; | |||
| const int in_channels_per_group = captured_params.at("in_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| ops.at("conv1d")->params["in_channels"] = in_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -76,47 +64,31 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kw)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 | |||
| F.conv1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Conv1d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Conv1d conv1d 1 1 input out out_channels=%out_channels kernel_size=(%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* name_str() const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| return "conv1d"; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; | |||
| op->params["out_channels"] = weight.shape[0]; | |||
| op->params["kernel_size"] = std::vector<int>{weight.shape[2]}; | |||
| op->params["padding_mode"] = std::string("zeros"); | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["groups"] = captured_params.at("groups"); | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| const int in_channels_per_group = captured_params.at("in_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| ops.at("conv1d")->params["in_channels"] = in_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -128,71 +100,32 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 6 5 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kw)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(1,%out_channels,1)f32 | |||
| F.conv1d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups | |||
| pnnx.Expression op_1 2 1 a bias out expr=%expr | |||
| pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1) | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Conv1d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Conv1d conv1d 1 1 input out out_channels=%out_channels kernel_size=(%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* name_str() const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| return "conv1d"; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::string& expr = captured_params.at("expr").s; | |||
| if (expr != "add(@0,@1)") | |||
| return false; | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| int out_channels = weight.shape[0]; | |||
| if (bias.shape != std::vector<int>{1, out_channels, 1}) | |||
| return false; | |||
| return true; | |||
| } | |||
| const int in_channels_per_group = captured_params.at("in_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; | |||
| op->params["out_channels"] = weight.shape[0]; | |||
| op->params["kernel_size"] = std::vector<int>{weight.shape[2]}; | |||
| op->params["padding_mode"] = std::string("zeros"); | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["groups"] = captured_params.at("groups"); | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| ops.at("conv1d")->params["in_channels"] = in_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -204,42 +137,30 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kh,%kw)f32 | |||
| F.conv2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Conv2d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Conv2d conv2d 1 1 input out out_channels=%out_channels kernel_size=(%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* name_str() const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| return "conv2d"; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| } | |||
| op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; | |||
| op->params["out_channels"] = weight.shape[0]; | |||
| op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3]}; | |||
| op->params["padding_mode"] = std::string("zeros"); | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["groups"] = captured_params.at("groups"); | |||
| op->params["bias"] = false; | |||
| op->attrs["weight"] = weight; | |||
| const int in_channels_per_group = captured_params.at("in_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| ops.at("conv2d")->params["in_channels"] = in_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -251,47 +172,31 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kh,%kw)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 | |||
| F.conv2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Conv2d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Conv2d conv2d 1 1 input out out_channels=%out_channels kernel_size=(%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* name_str() const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| return "conv2d"; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; | |||
| op->params["out_channels"] = weight.shape[0]; | |||
| op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3]}; | |||
| op->params["padding_mode"] = std::string("zeros"); | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["groups"] = captured_params.at("groups"); | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| const int in_channels_per_group = captured_params.at("in_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| ops.at("conv2d")->params["in_channels"] = in_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -303,71 +208,32 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 6 5 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kh,%kw)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(1,%out_channels,1,1)f32 | |||
| F.conv2d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups | |||
| pnnx.Expression op_1 2 1 a bias out expr=%expr | |||
| pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1) | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Conv2d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Conv2d conv2d 1 1 input out out_channels=%out_channels kernel_size=(%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* name_str() const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| return "conv2d"; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::string& expr = captured_params.at("expr").s; | |||
| if (expr != "add(@0,@1)") | |||
| return false; | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| int out_channels = weight.shape[0]; | |||
| if (bias.shape != std::vector<int>{1, out_channels, 1, 1}) | |||
| return false; | |||
| return true; | |||
| } | |||
| const int in_channels_per_group = captured_params.at("in_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; | |||
| op->params["out_channels"] = weight.shape[0]; | |||
| op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3]}; | |||
| op->params["padding_mode"] = std::string("zeros"); | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["groups"] = captured_params.at("groups"); | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| ops.at("conv2d")->params["in_channels"] = in_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -379,42 +245,30 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kd,%kh,%kw)f32 | |||
| F.conv3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Conv3d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Conv3d conv3d 1 1 input out out_channels=%out_channels kernel_size=(%kd,%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* name_str() const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| return "conv3d"; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| } | |||
| op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; | |||
| op->params["out_channels"] = weight.shape[0]; | |||
| op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3], weight.shape[4]}; | |||
| op->params["padding_mode"] = std::string("zeros"); | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["groups"] = captured_params.at("groups"); | |||
| op->params["bias"] = false; | |||
| op->attrs["weight"] = weight; | |||
| const int in_channels_per_group = captured_params.at("in_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| ops.at("conv3d")->params["in_channels"] = in_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -426,47 +280,31 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kd,%kh,%kw)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 | |||
| F.conv3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Conv3d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Conv3d conv3d 1 1 input out out_channels=%out_channels kernel_size=(%kd,%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* name_str() const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| return "conv3d"; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; | |||
| op->params["out_channels"] = weight.shape[0]; | |||
| op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3], weight.shape[4]}; | |||
| op->params["padding_mode"] = std::string("zeros"); | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["groups"] = captured_params.at("groups"); | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| const int in_channels_per_group = captured_params.at("in_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| ops.at("conv3d")->params["in_channels"] = in_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -478,71 +316,32 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 6 5 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kd,%kh,%kw)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(1,%out_channels,1,1,1)f32 | |||
| F.conv3d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups | |||
| pnnx.Expression op_1 2 1 a bias out expr=%expr | |||
| pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1) | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Conv3d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Conv3d conv3d 1 1 input out out_channels=%out_channels kernel_size=(%kd,%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* name_str() const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| return "conv3d"; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::string& expr = captured_params.at("expr").s; | |||
| if (expr != "add(@0,@1)") | |||
| return false; | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| int out_channels = weight.shape[0]; | |||
| if (bias.shape != std::vector<int>{1, out_channels, 1, 1, 1}) | |||
| return false; | |||
| return true; | |||
| } | |||
| const int in_channels_per_group = captured_params.at("in_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; | |||
| op->params["out_channels"] = weight.shape[0]; | |||
| op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3], weight.shape[4]}; | |||
| op->params["padding_mode"] = std::string("zeros"); | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["groups"] = captured_params.at("groups"); | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| ops.at("conv3d")->params["in_channels"] = in_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -29,44 +29,30 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kw)f32 | |||
| F.conv_transpose1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.ConvTranspose1d"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "conv_transpose1d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.ConvTranspose1d conv_transpose1d 1 1 input out in_channels=%in_channels kernel_size=(%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| const int out_channels_per_group = captured_params.at("out_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| op->params["groups"] = groups; | |||
| op->params["in_channels"] = weight.shape[0]; | |||
| op->params["out_channels"] = weight.shape[1] * groups; | |||
| op->params["kernel_size"] = Parameter{weight.shape[2]}; | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["output_padding"] = captured_params.at("output_padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["bias"] = false; | |||
| op->attrs["weight"] = weight; | |||
| ops.at("conv_transpose1d")->params["out_channels"] = out_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -78,49 +64,33 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kw)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 | |||
| F.conv_transpose1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.ConvTranspose1d"; | |||
| } | |||
| const char* name_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "conv_transpose1d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.ConvTranspose1d conv_transpose1d 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=(%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| const int out_channels_per_group = captured_params.at("out_channels_per_group").i; | |||
| const int out_channels = captured_params.at("out_channels").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| op->params["groups"] = groups; | |||
| op->params["in_channels"] = weight.shape[0]; | |||
| op->params["out_channels"] = weight.shape[1] * groups; | |||
| op->params["kernel_size"] = Parameter{weight.shape[2]}; | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["output_padding"] = captured_params.at("output_padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| if (out_channels != out_channels_per_group * groups) | |||
| return false; | |||
| return true; | |||
| } | |||
| }; | |||
| @@ -132,44 +102,30 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kh,%kw)f32 | |||
| F.conv_transpose2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.ConvTranspose2d"; | |||
| } | |||
| const char* name_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "conv_transpose2d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.ConvTranspose2d conv_transpose2d 1 1 input out in_channels=%in_channels kernel_size=(%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| const int out_channels_per_group = captured_params.at("out_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| op->params["groups"] = groups; | |||
| op->params["in_channels"] = weight.shape[0]; | |||
| op->params["out_channels"] = weight.shape[1] * groups; | |||
| op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3]}; | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["output_padding"] = captured_params.at("output_padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["bias"] = false; | |||
| op->attrs["weight"] = weight; | |||
| ops.at("conv_transpose2d")->params["out_channels"] = out_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -181,49 +137,33 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kh,%kw)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 | |||
| F.conv_transpose2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.ConvTranspose2d"; | |||
| } | |||
| const char* name_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "conv_transpose2d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.ConvTranspose2d conv_transpose2d 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=(%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| const int out_channels_per_group = captured_params.at("out_channels_per_group").i; | |||
| const int out_channels = captured_params.at("out_channels").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| op->params["groups"] = groups; | |||
| op->params["in_channels"] = weight.shape[0]; | |||
| op->params["out_channels"] = weight.shape[1] * groups; | |||
| op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3]}; | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["output_padding"] = captured_params.at("output_padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| if (out_channels != out_channels_per_group * groups) | |||
| return false; | |||
| return true; | |||
| } | |||
| }; | |||
| @@ -235,44 +175,30 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kd,%kh,%kw)f32 | |||
| F.conv_transpose3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.ConvTranspose3d"; | |||
| } | |||
| const char* name_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "conv_transpose3d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.ConvTranspose3d conv_transpose3d 1 1 input out in_channels=%in_channels kernel_size=(%kd,%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| const int out_channels_per_group = captured_params.at("out_channels_per_group").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| op->params["groups"] = groups; | |||
| op->params["in_channels"] = weight.shape[0]; | |||
| op->params["out_channels"] = weight.shape[1] * groups; | |||
| op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3], weight.shape[4]}; | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["output_padding"] = captured_params.at("output_padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["bias"] = false; | |||
| op->attrs["weight"] = weight; | |||
| ops.at("conv_transpose3d")->params["out_channels"] = out_channels_per_group * groups; | |||
| } | |||
| }; | |||
| @@ -284,49 +210,33 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kd,%kh,%kw)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32 | |||
| F.conv_transpose3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.ConvTranspose3d"; | |||
| } | |||
| const char* name_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "conv_transpose3d"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.ConvTranspose3d conv_transpose3d 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=(%kd,%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| const int out_channels_per_group = captured_params.at("out_channels_per_group").i; | |||
| const int out_channels = captured_params.at("out_channels").i; | |||
| const int groups = captured_params.at("groups").i; | |||
| op->params["groups"] = groups; | |||
| op->params["in_channels"] = weight.shape[0]; | |||
| op->params["out_channels"] = weight.shape[1] * groups; | |||
| op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3], weight.shape[4]}; | |||
| op->params["stride"] = captured_params.at("stride"); | |||
| op->params["padding"] = captured_params.at("padding"); | |||
| op->params["output_padding"] = captured_params.at("output_padding"); | |||
| op->params["dilation"] = captured_params.at("dilation"); | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| if (out_channels != out_channels_per_group * groups) | |||
| return false; | |||
| return true; | |||
| } | |||
| }; | |||
| @@ -29,42 +29,21 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%num_channels)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%num_channels)f32 | |||
| F.group_norm op_0 3 1 input weight bias out num_groups=%num_groups eps=%eps | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.GroupNorm"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "group_norm"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["num_channels"] = weight.shape[0]; | |||
| op->params["num_groups"] = captured_params.at("num_groups"); | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["affine"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.GroupNorm group_norm 1 1 input out num_channels=%num_channels num_groups=%num_groups eps=%eps affine=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| @@ -29,21 +29,21 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%num_features)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%num_features)f32 | |||
| F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.InstanceNorm1d"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "instance_norm"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.InstanceNorm1d instance_norm 1 1 input out num_features=%num_features eps=%eps affine=True track_running_stats=False @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| bool match(const std::map<std::string, const Operator*>& matched_operators) const | |||
| @@ -51,27 +51,6 @@ pnnx.Output output 1 0 out | |||
| size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); | |||
| return input_rank == 2 || input_rank == 3; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["num_features"] = weight.shape[0]; | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["affine"] = true; | |||
| op->params["track_running_stats"] = false; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| } | |||
| }; | |||
| class fuse_static_Finstancenorm_pass_2d : public GraphRewriterPass | |||
| @@ -82,48 +61,21 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%num_features)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%num_features)f32 | |||
| F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps #input=(?,?,?,?)f32 | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.InstanceNorm1d"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "instance_norm"; | |||
| } | |||
| bool match(const std::map<std::string, const Operator*>& matched_operators) const | |||
| { | |||
| size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); | |||
| return input_rank == 4; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["num_features"] = weight.shape[0]; | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["affine"] = true; | |||
| op->params["track_running_stats"] = false; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.InstanceNorm2d instance_norm 1 1 input out num_features=%num_features eps=%eps affine=True track_running_stats=False @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| @@ -135,48 +87,21 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%num_features)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%num_features)f32 | |||
| F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps #input=(?,?,?,?,?)f32 | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.InstanceNorm1d"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "instance_norm"; | |||
| } | |||
| bool match(const std::map<std::string, const Operator*>& matched_operators) const | |||
| { | |||
| size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); | |||
| return input_rank == 5; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["num_features"] = weight.shape[0]; | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["affine"] = true; | |||
| op->params["track_running_stats"] = false; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.InstanceNorm3d instance_norm 1 1 input out num_features=%num_features eps=%eps affine=True track_running_stats=False @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| @@ -29,41 +29,21 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data | |||
| pnnx.Attribute op_bias 0 1 bias @data | |||
| F.layer_norm op_0 3 1 input weight bias out normalized_shape=%normalized_shape eps=%eps | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.LayerNorm"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "layer_norm"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["normalized_shape"] = captured_params.at("normalized_shape"); | |||
| op->params["eps"] = captured_params.at("eps"); | |||
| op->params["elementwise_affine"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.LayerNorm layer_norm 1 1 input out normalized_shape=%normalized_shape eps=%eps elementwise_affine=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| @@ -29,36 +29,20 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_features,%in_features)f32 | |||
| F.linear op_0 2 1 input weight out bias=None | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Linear"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "linear"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| } | |||
| op->params["in_features"] = weight.shape[1]; | |||
| op->params["out_features"] = weight.shape[0]; | |||
| op->params["bias"] = false; | |||
| op->attrs["weight"] = weight; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Linear linear 1 1 input out in_features=%in_features out_features=%out_features bias=False @weight=%op_weight.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| @@ -70,41 +54,21 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_features,%in_features)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(%out_features)f32 | |||
| F.linear op_0 3 1 input weight bias out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "nn.Linear"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "linear"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["in_features"] = weight.shape[1]; | |||
| op->params["out_features"] = weight.shape[0]; | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Linear linear 1 1 input out in_features=%in_features out_features=%out_features bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| @@ -116,65 +80,31 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 6 5 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| F.linear op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups | |||
| pnnx.Expression op_1 2 1 a bias out expr=%expr | |||
| pnnx.Attribute op_weight 0 1 weight @data=(%out_features,%in_features)f32 | |||
| pnnx.Attribute op_bias 0 1 bias @data=(1,%out_features,1)f32 | |||
| F.linear op_0 2 1 input weight a | |||
| pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1) | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "nn.Linear"; | |||
| } | |||
| const char* name_str() const | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return "linear"; | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.Linear linear 1 1 input out in_features=%in_features out_features=%out_features bias=True @weight=%op_weight.data @bias=%op_bias.data | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::string& expr = captured_params.at("expr").s; | |||
| if (expr != "add(@0,@1)") | |||
| return false; | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| int out_channels = weight.shape[0]; | |||
| if (bias.shape != std::vector<int>{1, out_channels, 1}) | |||
| return false; | |||
| return true; | |||
| } | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| op->params["in_features"] = weight.shape[1]; | |||
| op->params["out_features"] = weight.shape[0]; | |||
| op->params["bias"] = true; | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = bias; | |||
| // fix bias shape | |||
| const int out_features = captured_params.at("out_features").i; | |||
| ops.at("linear")->attrs.at("bias").shape = {out_features}; | |||
| } | |||
| }; | |||
| @@ -26,8 +26,8 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_mean 0 1 running_mean @qwq | |||
| pnnx.Attribute op_var 0 1 running_var @qwq | |||
| pnnx.Attribute op_mean 0 1 running_mean @data | |||
| pnnx.Attribute op_var 0 1 running_var @data | |||
| F.batch_norm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| @@ -45,15 +45,8 @@ pnnx.Output output 1 0 out | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute running_mean; | |||
| Attribute running_var; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 8) == "op_mean.") | |||
| running_mean = x.second; | |||
| if (x.first.substr(0, 7) == "op_var.") | |||
| running_var = x.second; | |||
| } | |||
| Attribute running_mean = captured_attrs.at("op_mean.data"); | |||
| Attribute running_var = captured_attrs.at("op_var.data"); | |||
| op->params["0"] = running_mean.shape[0]; | |||
| op->params["1"] = captured_params.at("eps"); | |||
| @@ -77,10 +70,10 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 7 6 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_mean 0 1 running_mean @qwq | |||
| pnnx.Attribute op_var 0 1 running_var @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_mean 0 1 running_mean @data | |||
| pnnx.Attribute op_var 0 1 running_var @data | |||
| pnnx.Attribute op_weight 0 1 weight @data | |||
| pnnx.Attribute op_bias 0 1 bias @data | |||
| F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| @@ -98,21 +91,10 @@ pnnx.Output output 1 0 out | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute running_mean; | |||
| Attribute running_var; | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 8) == "op_mean.") | |||
| running_mean = x.second; | |||
| if (x.first.substr(0, 7) == "op_var.") | |||
| running_var = x.second; | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| Attribute running_mean = captured_attrs.at("op_mean.data"); | |||
| Attribute running_var = captured_attrs.at("op_var.data"); | |||
| Attribute weight = captured_attrs.at("op_weight.data"); | |||
| Attribute bias = captured_attrs.at("op_bias.data"); | |||
| op->params["0"] = running_mean.shape[0]; | |||
| op->params["1"] = captured_params.at("eps"); | |||
| @@ -26,7 +26,7 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data | |||
| F.embedding op_0 2 1 input weight out scale_grad_by_freq=False sparse=False | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| @@ -44,12 +44,7 @@ pnnx.Output output 1 0 out | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| } | |||
| Attribute weight = captured_attrs.at("op_weight.data"); | |||
| op->params["0"] = weight.shape[1]; | |||
| op->params["1"] = weight.shape[0]; | |||
| @@ -67,8 +67,8 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data | |||
| pnnx.Attribute op_bias 0 1 bias @data | |||
| F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| @@ -86,15 +86,8 @@ pnnx.Output output 1 0 out | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| Attribute weight = captured_attrs.at("op_weight.data"); | |||
| Attribute bias = captured_attrs.at("op_bias.data"); | |||
| op->params["0"] = weight.shape[0]; | |||
| op->params["1"] = captured_params.at("eps"); | |||
| @@ -26,7 +26,7 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @data | |||
| F.prelu op_0 2 1 input weight out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| @@ -44,12 +44,7 @@ pnnx.Output output 1 0 out | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| } | |||
| Attribute weight = captured_attrs.at("op_weight.data"); | |||
| op->params["0"] = weight.shape[0]; | |||
| @@ -26,8 +26,8 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 mat1 | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| pnnx.Attribute op_bias 0 1 bias @data | |||
| pnnx.Attribute op_weight 0 1 weight @data | |||
| torch.addmm op_0 3 1 bias mat1 weight out alpha=%alpha beta=%beta | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| @@ -69,15 +69,8 @@ pnnx.Output output 1 0 out | |||
| if (alpha != 1.f || beta != 1.f) | |||
| return false; | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| Attribute weight = captured_attrs.at("op_weight.data"); | |||
| Attribute bias = captured_attrs.at("op_bias.data"); | |||
| if (weight.shape.size() != 2 || bias.shape.size() != 1) | |||
| return false; | |||
| @@ -90,15 +83,8 @@ pnnx.Output output 1 0 out | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| Attribute weight = captured_attrs.at("op_weight.data"); | |||
| Attribute bias = captured_attrs.at("op_bias.data"); | |||
| // transpose weight inch-outch to outch-inch | |||
| const int inch = weight.shape[0]; | |||
| @@ -213,6 +213,7 @@ pnnx_add_test(torch_ones) | |||
| pnnx_add_test(torch_ones_like) | |||
| pnnx_add_test(torch_permute) | |||
| pnnx_add_test(torch_prod) | |||
| pnnx_add_test(torch_repeat_interleave) | |||
| pnnx_add_test(torch_scatter_add) | |||
| pnnx_add_test(torch_sum) | |||
| pnnx_add_test(torch_split) | |||
| @@ -298,8 +299,10 @@ pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d) | |||
| pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d) | |||
| pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d) | |||
| pnnx_add_test(pnnx_fuse_input_unpack) | |||
| pnnx_add_test(pnnx_fuse_layernorm) | |||
| pnnx_add_test(pnnx_fuse_linear_batchnorm1d) | |||
| pnnx_add_test(pnnx_fuse_multiheadattention) | |||
| pnnx_add_test(pnnx_fuse_scaled_dot_product_attention) | |||
| pnnx_add_test(pnnx_fuse_select_to_unbind) | |||
| pnnx_add_test(pnnx_fuse_slice_to_tensor_split) | |||
| pnnx_add_test(pnnx_fuse_adjacent_reshape) | |||
| @@ -0,0 +1,70 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class LayerNorm2d(nn.Module):
    """Layer normalization over dim 1, written out with primitive tensor ops.

    The sequence of operations in ``forward`` is deliberately left exactly as
    it was: this module is traced and handed to the converter by ``test()``,
    so the recorded op order is part of the fixture.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Learnable per-channel scale and shift; ones/zeros make the affine
        # step an identity until the parameters are changed.
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along dim 1, then apply the per-channel affine."""
        channel_mean = x.mean(1, keepdim=True)
        channel_var = (x - channel_mean).pow(2).mean(1, keepdim=True)
        normalized = (x - channel_mean) / torch.sqrt(channel_var + self.eps)
        # [:, None, None] appends two unit axes to the 1-D parameters so they
        # broadcast against the trailing dims of ``normalized``.
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
class Model(nn.Module):
    """Smallest possible network: a single 64-channel ``LayerNorm2d``."""

    def __init__(self):
        super().__init__()
        self.ln_0 = LayerNorm2d(64)

    def forward(self, x):
        """Pass ``x`` through the layer-norm submodule and return the result."""
        return self.ln_0(x)
def test():
    """Trace the model, convert it with pnnx and compare both outputs.

    Returns:
        bool: True when the output of the generated pnnx module matches the
        eager PyTorch output within 1e-4 relative/absolute tolerance. False
        when the outputs differ or when the pnnx command exits with a
        non-zero status.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 64, 16, 16)

    a0 = net(x)

    # export torchscript
    mod = torch.jit.trace(net, x)
    mod.save("test_pnnx_fuse_layernorm.pt")

    # torchscript to pnnx
    import os
    status = os.system("../src/pnnx test_pnnx_fuse_layernorm.pt inputshape=[1,64,16,16]")
    if status != 0:
        # The converter failed; importing the generated module now could pick
        # up a stale file left behind by an earlier run and report a false pass.
        return False

    # pnnx inference
    import test_pnnx_fuse_layernorm_pnnx
    b0 = test_pnnx_fuse_layernorm_pnnx.test_inference()

    return torch.allclose(a0, b0, 1e-4, 1e-4)
if __name__ == "__main__":
    # Process exit status: 0 when test() succeeds, 1 otherwise.
    # SystemExit is raised directly instead of calling exit(), which is a
    # helper installed by the site module for interactive use and is not
    # guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,127 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from packaging import version | |||
| from einops import rearrange | |||
| from typing import Any, Optional, Tuple, Union | |||
| from torch import Tensor | |||
| import math | |||
class sam_Attention(nn.Module):
    """
    Multi-head attention whose query/key/value projections may reduce the
    embedding width to ``embedding_dim // downsample_rate`` before attending;
    the output projection maps back to ``embedding_dim``.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        # Layers are created in q, k, v, out order.
        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Split the last axis into heads: B x N_tokens x C -> B x N_heads x N_tokens x C_per_head."""
        batch, tokens, channels = x.shape
        per_head = x.reshape(batch, tokens, num_heads, channels // num_heads)
        return per_head.transpose(1, 2)

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Inverse of ``_separate_heads``: B x N_heads x N_tokens x C_per_head -> B x N_tokens x C."""
        batch, heads, tokens, head_dim = x.shape
        token_major = x.transpose(1, 2)
        return token_major.reshape(batch, tokens, heads * head_dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Project q/k/v, run scaled dot-product attention per head, project the result."""
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention scores, scaled by sqrt of the per-head width and
        # normalised over the key axis.
        _, _, _, head_dim = q.shape
        scores = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        scores = scores / math.sqrt(head_dim)
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values, heads merged, then the output projection
        return self.out_proj(self._recombine_heads(weights @ v))
class Model(nn.Module):
    """Test network that runs two sam_Attention layers on the same inputs."""

    def __init__(self):
        super().__init__()

        # Both layers use embedding_dim=64 and num_heads=4; they differ only in
        # downsample_rate (1 versus 2).
        self.attention_0 = sam_Attention(embedding_dim=64, num_heads=4, downsample_rate=1)
        self.attention_1 = sam_Attention(embedding_dim=64, num_heads=4, downsample_rate=2)

    def forward(self, q, k, v):
        # The same (q, k, v) triple goes through each layer; both results are returned.
        out_full = self.attention_0(q, k, v)
        out_half = self.attention_1(q, k, v)
        return out_full, out_half
def test():
    """Compare eager PyTorch outputs of Model with the pnnx-converted model.

    Traces Model with TorchScript, saves it, runs the ../src/pnnx executable on
    the saved file, then imports test_pnnx_fuse_scaled_dot_product_attention_pnnx
    (expected to be produced by that conversion) and calls its test_inference().

    Returns:
        True when the conversion command succeeds, both sides return the same
        number of outputs and every pair matches under
        torch.allclose(rtol=1e-4, atol=1e-4); False otherwise.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 24, 64)
    y = torch.rand(1, 24, 64)
    z = torch.rand(1, 24, 64)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_pnnx_fuse_scaled_dot_product_attention.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the converter failed; stop here so that a module
    # left over from an earlier run cannot make the test pass.
    if os.system("../src/pnnx test_pnnx_fuse_scaled_dot_product_attention.pt inputshape=[1,24,64],[1,24,64],[1,24,64]") != 0:
        return False

    # pnnx inference
    import test_pnnx_fuse_scaled_dot_product_attention_pnnx
    b = test_pnnx_fuse_scaled_dot_product_attention_pnnx.test_inference()

    # zip() stops at the shorter sequence, so a missing output would otherwise
    # go unnoticed.
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-4, 1e-4):
            return False
    return True
if __name__ == "__main__":
    # Report the outcome through the process exit status: 0 when test()
    # succeeds, 1 otherwise. SystemExit is raised directly because the exit()
    # helper is injected by the site module for interactive sessions and is
    # not guaranteed to exist (for example under `python -S`).
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,65 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from packaging import version | |||
class Model(nn.Module):
    """Parameter-free module that applies torch.repeat_interleave to three inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z):
        # Integer repeat count, no dim argument.
        x = torch.repeat_interleave(x, 2)
        # Integer repeat count along dim 1.
        y = torch.repeat_interleave(y, 3, dim=1)

        # Tensor of repeat counts along dim 0; output_size is passed only when
        # the installed torch reports version 1.10 or newer.
        accepts_output_size = version.parse(torch.__version__) >= version.parse('1.10')
        if not accepts_output_size:
            z = torch.repeat_interleave(z, torch.tensor([2, 1, 3]), dim=0)
        else:
            z = torch.repeat_interleave(z, torch.tensor([2, 1, 3]), dim=0, output_size=6)

        return x, y, z
def test():
    """Compare eager PyTorch outputs of Model with the pnnx-converted model.

    Traces Model with TorchScript, saves it, runs the ../src/pnnx executable on
    the saved file, then imports test_torch_repeat_interleave_pnnx (expected to
    be produced by that conversion) and calls its test_inference().

    Returns:
        True when the conversion command succeeds, both sides return the same
        number of outputs and every pair is exactly equal under torch.equal;
        False otherwise.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3)
    y = torch.rand(4, 5)
    z = torch.rand(3, 7, 8)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_torch_repeat_interleave.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the converter failed; stop here so that a module
    # left over from an earlier run cannot make the test pass.
    if os.system("../src/pnnx test_torch_repeat_interleave.pt inputshape=[3],[4,5],[3,7,8]") != 0:
        return False

    # pnnx inference
    import test_torch_repeat_interleave_pnnx
    b = test_torch_repeat_interleave_pnnx.test_inference()

    # zip() stops at the shorter sequence, so a missing output would otherwise
    # go unnoticed.
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True
if __name__ == "__main__":
    # Report the outcome through the process exit status: 0 when test()
    # succeeds, 1 otherwise. SystemExit is raised directly because the exit()
    # helper is injected by the site module for interactive sessions and is
    # not guaranteed to exist (for example under `python -S`).
    raise SystemExit(0 if test() else 1)