Browse Source

pnnx rewrite multiple ops (#4780)

fuse F.scaled_dot_product_attention
tags/20230816
nihui GitHub 3 years ago
parent
commit
b8cf8cb73e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 2280 additions and 2031 deletions
  1. +3
    -0
      tools/pnnx/src/CMakeLists.txt
  2. +160
    -72
      tools/pnnx/src/ir.cpp
  3. +4
    -0
      tools/pnnx/src/ir.h
  4. +2
    -2
      tools/pnnx/src/pass_level1.cpp
  5. +569
    -20
      tools/pnnx/src/pass_level2.cpp
  6. +7
    -1
      tools/pnnx/src/pass_level2.h
  7. +68
    -0
      tools/pnnx/src/pass_level2/torch_repeat_interleave.cpp
  8. +4
    -0
      tools/pnnx/src/pass_level5.cpp
  9. +2
    -2
      tools/pnnx/src/pass_level5/fold_constants.cpp
  10. +5
    -48
      tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp
  11. +87
    -0
      tools/pnnx/src/pass_level5/fuse_layernorm.cpp
  12. +21
    -0
      tools/pnnx/src/pass_level5/fuse_layernorm.h
  13. +584
    -806
      tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
  14. +92
    -0
      tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp
  15. +21
    -0
      tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.h
  16. +64
    -232
      tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp
  17. +132
    -333
      tools/pnnx/src/pass_level5/fuse_static_conv.cpp
  18. +84
    -174
      tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp
  19. +9
    -30
      tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp
  20. +29
    -104
      tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp
  21. +9
    -29
      tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp
  22. +33
    -103
      tools/pnnx/src/pass_level5/fuse_static_linear.cpp
  23. +12
    -30
      tools/pnnx/src/pass_ncnn/F_batch_norm.cpp
  24. +2
    -7
      tools/pnnx/src/pass_ncnn/F_embedding.cpp
  25. +4
    -11
      tools/pnnx/src/pass_ncnn/F_instance_norm.cpp
  26. +2
    -7
      tools/pnnx/src/pass_ncnn/F_prelu.cpp
  27. +6
    -20
      tools/pnnx/src/pass_ncnn/torch_addmm.cpp
  28. +3
    -0
      tools/pnnx/tests/CMakeLists.txt
  29. +70
    -0
      tools/pnnx/tests/test_pnnx_fuse_layernorm.py
  30. +127
    -0
      tools/pnnx/tests/test_pnnx_fuse_scaled_dot_product_attention.py
  31. +65
    -0
      tools/pnnx/tests/test_torch_repeat_interleave.py

+ 3
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -242,6 +242,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_randn.cpp
pass_level2/torch_randn_like.cpp
pass_level2/torch_real.cpp
pass_level2/torch_repeat_interleave.cpp
pass_level2/torch_roll.cpp
pass_level2/torch_scatter_add.cpp
pass_level2/torch_split.cpp
@@ -332,7 +333,9 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_pad_conv1d.cpp
pass_level5/fuse_pad_conv2d.cpp
pass_level5/fuse_pixel_unshuffle.cpp
pass_level5/fuse_layernorm.cpp
pass_level5/fuse_multiheadattention.cpp
pass_level5/fuse_scaled_dot_product_attention.cpp
pass_level5/fuse_select_to_unbind.cpp
pass_level5/fuse_slice_copy.cpp
pass_level5/fuse_slice_indices.cpp


+ 160
- 72
tools/pnnx/src/ir.cpp View File

@@ -588,6 +588,14 @@ Attribute operator+(const Attribute& a, const Attribute& b)

Parameter Parameter::parse_from_string(const std::string& value)
{
if (value.find('%') != std::string::npos)
{
Parameter p;
p.type = 4;
p.s = value;
return p;
}

Parameter p;
p.type = 0;

@@ -659,6 +667,96 @@ Parameter Parameter::parse_from_string(const std::string& value)
return p;
}

// Encode a Parameter as the text that Graph::save() writes after "key=".
//
// The output depends on param.type (0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as
// 10=c 11=ac):
//   0      -> "None"
//   1      -> "True" / "False"
//   2      -> decimal integer
//   3      -> "%e" formatted float
//   4      -> the string itself, unquoted
//   5/6/7  -> "(e0,e1,...)" list of ints / "%e" floats / strings
//   10     -> "%e+%ej" complex number
//   11     -> "(c0,c1,...)" list of "%e+%ej" complex numbers
//
// Any other type is reported on stderr and yields an empty string.
std::string Parameter::encode_to_string(const Parameter& param)
{
    switch (param.type)
    {
    case 0:
        return std::string("None");

    case 1:
        return std::string(param.b ? "True" : "False");

    case 2:
        return std::to_string(param.i);

    case 3:
    {
        char buf[64];
        // bounded write: never trust a fixed buffer to an unbounded sprintf
        snprintf(buf, sizeof(buf), "%e", param.f);
        return std::string(buf);
    }

    case 4:
        return param.s;

    case 5:
    {
        std::string s("(");
        for (size_t i = 0; i < param.ai.size(); i++)
        {
            if (i != 0)
                s += ',';
            s += std::to_string(param.ai[i]);
        }
        s += ')';
        return s;
    }

    case 6:
    {
        std::string s("(");
        for (size_t i = 0; i < param.af.size(); i++)
        {
            if (i != 0)
                s += ',';
            char buf[64];
            snprintf(buf, sizeof(buf), "%e", param.af[i]);
            s += buf;
        }
        s += ')';
        return s;
    }

    case 7:
    {
        std::string s("(");
        for (size_t i = 0; i < param.as.size(); i++)
        {
            if (i != 0)
                s += ',';
            s += param.as[i];
        }
        s += ')';
        return s;
    }

    case 10:
    {
        char buf[128];
        snprintf(buf, sizeof(buf), "%e+%ej", param.c.real(), param.c.imag());
        return std::string(buf);
    }

    case 11:
    {
        std::string s("(");
        for (size_t i = 0; i < param.ac.size(); i++)
        {
            if (i != 0)
                s += ',';
            char buf[128];
            snprintf(buf, sizeof(buf), "%e+%ej", param.ac[i].real(), param.ac[i].imag());
            s += buf;
        }
        s += ')';
        return s;
    }

    default:
        break;
    }

    fprintf(stderr, "unknown parameter type %d\n", param.type);
    return std::string();
}

// Default constructor; the body performs no work.
Graph::Graph()
{
}
@@ -752,6 +850,14 @@ static void load_shape(Operator* op, const std::string& key, const std::string&
{
operand->shape.push_back(-1);
}
else if (elem[0] == '%')
{
// encode %abc as symbolic tag
operand->shape.push_back(-233);
int index = operand->shape.size() - 1;
std::string key = elem.substr(1);
operand->params[std::string("__shape_") + std::to_string(index)] = key;
}
else
{
int i = std::stoi(elem);
@@ -965,77 +1071,8 @@ int Graph::save(const std::string& parampath, const std::string& binpath)
fprintf(paramfp, " %s=", it.first.c_str());

const Parameter& param = it.second;
if (param.type == 0)
{
fprintf(paramfp, "None");
}
if (param.type == 1)
{
if (param.b)
fprintf(paramfp, "True");
else
fprintf(paramfp, "False");
}
if (param.type == 2)
{
fprintf(paramfp, "%d", param.i);
}
if (param.type == 3)
{
fprintf(paramfp, "%e", param.f);
}
if (param.type == 4)
{
fprintf(paramfp, "%s", param.s.c_str());
}
if (param.type == 5)
{
fprintf(paramfp, "(");
for (size_t i = 0; i < param.ai.size(); i++)
{
fprintf(paramfp, "%d", param.ai[i]);
if (i + 1 != param.ai.size())
fprintf(paramfp, ",");
}
fprintf(paramfp, ")");
}
if (param.type == 6)
{
fprintf(paramfp, "(");
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(paramfp, "%e", param.af[i]);
if (i + 1 != param.af.size())
fprintf(paramfp, ",");
}
fprintf(paramfp, ")");
}
if (param.type == 7)
{
fprintf(paramfp, "(");
for (size_t i = 0; i < param.as.size(); i++)
{
fprintf(paramfp, "%s", param.as[i].c_str());
if (i + 1 != param.as.size())
fprintf(paramfp, ",");
}
fprintf(paramfp, ")");
}
if (param.type == 10)
{
fprintf(paramfp, "%e+%ej", param.c.real(), param.c.imag());
}
if (param.type == 11)
{
fprintf(paramfp, "(");
for (size_t i = 0; i < param.ac.size(); i++)
{
fprintf(paramfp, "%e+%ej", param.ac[i].real(), param.ac[i].imag());
if (i + 1 != param.ac.size())
fprintf(paramfp, ",");
}
fprintf(paramfp, ")");
}
std::string s = Parameter::encode_to_string(param);
fprintf(paramfp, "%s", s.c_str());
}

for (const auto& it : op->attrs)
@@ -2638,11 +2675,62 @@ int Graph::parse(const std::string& param)
{
// attribute
// load_attribute(op, key.substr(1), value, szr);
op->attrs[key.substr(1)] = Attribute();

Attribute& attr = op->attrs[key.substr(1)];

attr.type = 0;
if (value.empty())
continue;

if (value[0] == '%')
{
// @data=%op1.data
attr.data = std::vector<char>(value.begin(), value.end());
}

if (value[0] == '(')
{
// @data=(1,%c,?,4)f32

// type
std::string typestr = value.substr(value.find_last_of(')') + 1);
attr.type = string_to_type(typestr.c_str());

// shape
std::string lc = value.substr(1, value.find_last_of(')') - 1);
std::istringstream lcss(lc);

attr.shape.clear();
while (!lcss.eof())
{
std::string elem;
std::getline(lcss, elem, ',');

if (elem == "?")
{
attr.shape.push_back(-1);
}
else if (elem[0] == '%')
{
// encode %abc as symbolic tag
attr.shape.push_back(-233);
int index = attr.shape.size() - 1;
std::string key = elem.substr(1);
attr.params[std::string("__shape_") + std::to_string(index)] = key;
}
else
{
int i = std::stoi(elem);
attr.shape.push_back(i);
}
}
}
}
else if (key[0] == '$')
{
// operand input key
// load_input_key(op, key.substr(1), value);
load_input_key(op, key.substr(1), value);
}
else if (key[0] == '#')
{


+ 4
- 0
tools/pnnx/src/ir.h View File

@@ -171,6 +171,7 @@ public:
#endif // BUILD_PNNX

static Parameter parse_from_string(const std::string& value);
static std::string encode_to_string(const Parameter& param);

// 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others 10=c 11=ac
int type;
@@ -217,6 +218,8 @@ public:
std::vector<int> shape;

std::vector<char> data;

std::map<std::string, Parameter> params;
};

bool operator==(const Attribute& lhs, const Attribute& rhs);
@@ -246,6 +249,7 @@ private:
friend class Graph;
Operand()
{
type = 0;
}
};



+ 2
- 2
tools/pnnx/src/pass_level1.cpp View File

@@ -131,7 +131,7 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit

// sub_mod.dump(true, true, true);

op->attrs[name] = sub_mod.attr(name).toTensor();
op->attrs["data"] = sub_mod.attr(name).toTensor();
}
}
else if (n->kind() == c10::prim::Constant) // || n->kind() == c10::prim::ListConstruct)
@@ -165,7 +165,7 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit

op->params.erase("value");

op->attrs[name] = n->t(torch::jit::attr::value);
op->attrs["data"] = n->t(torch::jit::attr::value);
}
}
else if (n->kind() == c10::prim::CallMethod)


+ 569
- 20
tools/pnnx/src/pass_level2.cpp View File

@@ -24,6 +24,17 @@ GraphRewriterPass::~GraphRewriterPass()
{
}

// Base implementation: there is no replacement pattern graph.
// Callers compare the result against 0 to pick the single-operator path.
const char* GraphRewriterPass::replace_pattern_graph() const
{
    return nullptr;
}

// Base implementation: complains on stderr that the pass did not provide
// its own type_str() and returns the placeholder type "unk".
const char* GraphRewriterPass::type_str() const
{
    fputs("GraphRewriterPass type_str() should be implemented\n", stderr);
    return "unk";
}

const char* GraphRewriterPass::name_str() const
{
return type_str();
@@ -46,15 +57,96 @@ bool GraphRewriterPass::match(const std::map<std::string, const Operator*>& /*ma

void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
for (auto x : captured_params)
if (replace_pattern_graph() == 0)
{
for (auto x : captured_params)
{
op->params[x.first] = x.second;
}

return;
}

for (auto x : op->params)
{
op->params[x.first] = x.second;
if (x.second.type != 4)
continue;

std::string str = x.second.s;
if (str.find('%') == std::string::npos)
continue;

// search % token and replace with captured
size_t pos = str.find('%');
while (pos != std::string::npos)
{
// %xyz
char buf[256];
sscanf(str.c_str() + pos + 1, "%255[^][,() ]", buf);
std::string key(buf);

if (captured_params.find(key) == captured_params.end())
{
fprintf(stderr, "replace pattern param %%%s missing captured\n", key.c_str());
return;
}

// replace %xyz with encoded_str
std::string encoded_str = Parameter::encode_to_string(captured_params.at(key));
str.replace(pos, key.size() + 1, encoded_str);

pos = str.find('%', pos + 1);
}

op->params[x.first] = Parameter::parse_from_string(str);
}
}

void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
// Fill in one operator from the captured values.
//
// Parameters are written by write(op, captured_params). Afterwards every
// attribute of op that is still a placeholder (type 0 whose data holds the
// text "%<name>") is replaced by captured_attrs["<name>"].
//
// op              operator to update in place
// captured_params parameters captured while matching the pattern
// captured_attrs  attributes captured while matching the pattern
void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
    write(op, captured_params);

    for (auto& entry : op->attrs)
    {
        Attribute& attr = entry.second;
        if (attr.type != 0)
            continue;

        // data is a plain byte buffer with no terminating NUL, so build the
        // string from the iterator range rather than from a C string
        const std::string ref(attr.data.begin(), attr.data.end());
        if (ref.empty() || ref[0] != '%')
            continue;

        // captured attributes are keyed without the leading '%'
        const auto found = captured_attrs.find(ref.substr(1));
        if (found == captured_attrs.end())
        {
            fprintf(stderr, "replace pattern attribute %s missing captured\n", ref.c_str());
            continue;
        }

        attr = found->second;
    }
}

// Apply the captured parameters to every operator in ops by delegating to
// the single-operator write(op, captured_params).
void GraphRewriterPass::write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params) const
{
    for (std::map<std::string, Operator*>::const_iterator it = ops.begin(); it != ops.end(); ++it)
    {
        write(it->second, captured_params);
    }
}

// Fill in all operators of a replacement graph from the captured values.
//
// Parameters are written by write(ops, captured_params). Afterwards every
// attribute that is still a placeholder (type 0 whose data holds the text
// "%<name>") is replaced by captured_attrs["<name>"].
//
// ops             replacement operators keyed by their pattern name
// captured_params parameters captured while matching the pattern
// captured_attrs  attributes captured while matching the pattern
void GraphRewriterPass::write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
    write(ops, captured_params);

    for (const auto& named_op : ops)
    {
        Operator* op = named_op.second;

        // iterate by reference: an Attribute owns its tensor bytes and a
        // by-value loop variable would deep-copy them on every step
        for (auto& entry : op->attrs)
        {
            Attribute& attr = entry.second;
            if (attr.type != 0)
                continue;

            const std::string ref(attr.data.begin(), attr.data.end());
            if (ref.empty() || ref[0] != '%')
                continue;

            // captured attributes are keyed without the leading '%'
            const auto found = captured_attrs.find(ref.substr(1));
            if (found == captured_attrs.end())
            {
                fprintf(stderr, "replace pattern attribute %s missing captured\n", ref.c_str());
                continue;
            }

            attr = found->second;
        }
    }
}

static std::map<int, std::vector<const GraphRewriterPass*> > g_global_pnnx_graph_rewriter_passes;
@@ -75,6 +167,126 @@ GraphRewriterPassRegister::~GraphRewriterPassRegister()
delete pass;
}

// Returns true when t is an expression argument reference: an '@' followed
// by one or more decimal digits, e.g. "@0" or "@12".
static bool token_is_argument(const std::string& t)
{
    const bool has_at_prefix = t.size() >= 2 && t[0] == '@';
    if (!has_at_prefix)
        return false;

    // everything after the '@' must be a digit
    return t.find_first_not_of("0123456789", 1) == std::string::npos;
}

static bool match_expression(const Operator* a, const Operator* b, std::map<std::string, Parameter>& captured_params)
{
if (a->params.size() != 1 || a->params.find("expr") == a->params.end())
return false;

if (b->params.size() != 1 || b->params.find("expr") == b->params.end())
return false;

const std::string& a_expr = a->params.at("expr").s;
const std::string& b_expr = b->params.at("expr").s;

if (a_expr == b_expr)
return true;

// split into tokens
std::vector<std::string> a_tokens;
std::vector<std::string> b_tokens;
{
std::string t;
for (size_t i = 0; i < a_expr.size(); i++)
{
char ch = a_expr[i];

if (ch == '[') // list
{
t += ch;
a_tokens.push_back(t);
t.clear();
}
else if (ch == '(' || ch == ')' || ch == ',' || ch == ']')
{
if (!t.empty())
{
a_tokens.push_back(t);
t.clear();
}
}
else
{
t += ch;
}
}

if (!t.empty())
{
a_tokens.push_back(t);
}
}
{
std::string t;
for (size_t i = 0; i < b_expr.size(); i++)
{
char ch = b_expr[i];

if (ch == '[') // list
{
t += ch;
b_tokens.push_back(t);
t.clear();
}
else if (ch == '(' || ch == ')' || ch == ',' || ch == ']')
{
if (!t.empty())
{
b_tokens.push_back(t);
t.clear();
}
}
else
{
t += ch;
}
}

if (!t.empty())
{
b_tokens.push_back(t);
}
}

if (a_tokens.size() != b_tokens.size())
return false;

// capture values
for (size_t i = 0; i < a_tokens.size(); i++)
{
const std::string& at = a_tokens[i];
const std::string& bt = b_tokens[i];

if (at == bt)
continue;

if (bt[0] != '%')
return false;

if (token_is_argument(at))
return false;

std::string key = bt.substr(1);

captured_params[key] = Parameter::parse_from_string(at);
}

return true;
}

static bool match_parameter(const Parameter& a, const Parameter& b, std::map<std::string, Parameter>& captured_params)
{
if (b.type == 4 && b.s[0] == '%')
@@ -97,6 +309,77 @@ static bool match_parameter(const Parameter& a, const Parameter& b, std::map<std
return true;
}

if (b.type == 4 && (b.s[0] == '(' || b.s[0] == '[') && b.s.find('%') != std::string::npos)
{
// list with pattern
if (a.type != 5 && a.type != 6 && a.type != 7)
return false;

std::string lc = b.s.substr(1, b.s.size() - 2);
std::istringstream lcss(lc);

size_t i = 0;
while (!lcss.eof())
{
std::string elem;
std::getline(lcss, elem, ',');

if (elem[0] == '%')
{
std::string key = elem.substr(1);
if (captured_params.find(key) != captured_params.end())
{
// match previous captured parameter
if (a.type == 5 && captured_params.at(key).i != a.ai[i])
return false;
if (a.type == 6 && captured_params.at(key).f != a.af[i])
return false;
if (a.type == 7 && captured_params.at(key).s != a.as[i])
return false;
}

// captured parameter
if (a.type == 5)
captured_params[key] = a.ai[i];
if (a.type == 6)
captured_params[key] = a.af[i];
if (a.type == 7)
captured_params[key] = a.as[i];
}
else if ((elem[0] != '-' && (elem[0] < '0' || elem[0] > '9')) || (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9')))
{
// string
if (a.type != 7)
return false;

if (a.as[i] != elem)
return false;
}
else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos)
{
// float
if (a.type != 6)
return false;

if (a.af[i] != std::stof(elem))
return false;
}
else
{
// integer
if (a.type != 5)
return false;

if (a.ai[i] != std::stoi(elem))
return false;
}

i++;
}

return true;
}

if (a.type != b.type)
{
if (a.type == 2 && b.type == 3)
@@ -174,6 +457,76 @@ static bool match_parameter(const Parameter& a, const Parameter& b, std::map<std
return false;
}

static bool match_attribute(const Attribute& a, const Attribute& b, std::map<std::string, Parameter>& captured_params, const std::string& attrname, std::map<std::string, Attribute>& captured_attrs)
{
// @data
// @data=(1,2,3,4)f32
// @data=%op1.data

if (b.type == 0)
{
std::string bs(b.data.begin(), b.data.end());
if (bs.empty())
{
// capture any shape
captured_attrs[attrname] = a;
return true;
}

if (bs[0] == '%')
{
// the captured replace
return true;
}

fprintf(stderr, "malformed attribute pattern %s\n", bs.c_str());
return false;
}

const std::vector<int>& a_shape = a.shape;
const std::vector<int>& b_shape = b.shape;
if (b_shape.empty())
return false;

if (a_shape.empty())
return false;

if (a_shape.size() != b_shape.size())
return false;

for (size_t j = 0; j < a_shape.size(); j++)
{
int ai = a_shape[j];
int bi = b_shape[j];
if (ai == bi)
continue;

if (bi == -1)
continue;

if (bi > 0)
return false;

if (bi != -233)
return false;

std::string key = b.params.at(std::string("__shape_") + std::to_string(j)).s;

if (captured_params.find(key) != captured_params.end())
{
// match previous captured parameter
if (captured_params.at(key).i != ai)
return false;
}

// captured parameter
captured_params[key] = ai;
}

captured_attrs[attrname] = a;
return true;
}

static bool match_operator(const Operator* a, const Operator* b, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
{
if (a->type != b->type)
@@ -197,6 +550,11 @@ static bool match_operator(const Operator* a, const Operator* b, std::map<std::s
captured_params[b->name + '.' + pkey] = pp;
}
}
else if (a->type == "pnnx.Expression")
{
if (!match_expression(a, b, captured_params))
return false;
}
else
{
if (a->params.size() != b->params.size())
@@ -215,13 +573,120 @@ static bool match_operator(const Operator* a, const Operator* b, std::map<std::s
}
}

// match shapes
for (size_t i = 0; i < a->inputs.size(); i++)
{
int a_type = a->inputs[i]->type;
int b_type = b->inputs[i]->type;
if (b_type != 0 && a_type != b_type)
return false;

const std::vector<int>& a_shape = a->inputs[i]->shape;
const std::vector<int>& b_shape = b->inputs[i]->shape;
if (b_shape.empty())
continue;

if (a_shape.empty())
return false;

if (a_shape.size() != b_shape.size())
return false;

for (size_t j = 0; j < a_shape.size(); j++)
{
int ai = a_shape[j];
int bi = b_shape[j];
if (ai == bi)
continue;

if (bi == -1)
continue;

if (bi > 0)
return false;

if (bi != -233)
return false;

std::string key = b->inputs[i]->params.at(std::string("__shape_") + std::to_string(j)).s;

if (captured_params.find(key) != captured_params.end())
{
// match previous captured parameter
if (captured_params.at(key).i != ai)
return false;
}

// captured parameter
captured_params[key] = ai;
}
}

for (size_t i = 0; i < a->outputs.size(); i++)
{
int a_type = a->outputs[i]->type;
int b_type = b->outputs[i]->type;
if (b_type != 0 && a_type != b_type)
return false;

const std::vector<int>& a_shape = a->outputs[i]->shape;
const std::vector<int>& b_shape = b->outputs[i]->shape;
if (b_shape.empty())
continue;

if (a_shape.empty())
return false;

if (a_shape.size() != b_shape.size())
return false;

for (size_t j = 0; j < a_shape.size(); j++)
{
int ai = a_shape[j];
int bi = b_shape[j];
if (ai == bi)
continue;

if (bi == -1)
continue;

if (bi > 0)
return false;

if (bi != -233)
return false;

std::string key = b->outputs[i]->params.at(std::string("__shape_") + std::to_string(j)).s;

if (captured_params.find(key) != captured_params.end())
{
// match previous captured parameter
if (captured_params.at(key).i != ai)
return false;
}

// captured parameter
captured_params[key] = ai;
}
}

for (const auto& p : a->attrs)
{
const std::string& akey = p.first;
const Attribute& aa = p.second;

// capture all attributes
captured_attrs[b->name + '.' + akey] = aa;
std::string attrname = b->name + '.' + akey;

if (b->attrs.find(akey) == b->attrs.end())
{
// capture all attributes
captured_attrs[attrname] = aa;
}
else
{
if (!match_attribute(aa, b->attrs.at(akey), captured_params, attrname, captured_attrs))
return false;
}
}

return true;
@@ -484,27 +949,111 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
cur = graph.ops[cur_index];
}

Operator* op = graph.new_operator_before(pass->type_str(), std::string(pass->name_str()), cur);

for (const auto& k : pattern_graph_inputs)
if (pass->replace_pattern_graph() == 0)
{
Operand* r = (Operand*)matched_inputs.at(k);
r->consumers.push_back(op);
op->inputs.push_back(r);
// insert single
Operator* op = graph.new_operator_before(pass->type_str(), std::string(pass->name_str()), cur);

op->inputnames.push_back(k);
}
for (const auto& k : pattern_graph_inputs)
{
Operand* r = (Operand*)matched_inputs.at(k);
r->consumers.push_back(op);
op->inputs.push_back(r);

for (const auto& k : pattern_graph_outputs)
{
Operand* r = (Operand*)matched_outputs.at(k);
r->producer = op;
op->outputs.push_back(r);
op->inputnames.push_back(k);
}

for (const auto& k : pattern_graph_outputs)
{
Operand* r = (Operand*)matched_outputs.at(k);
r->producer = op;
op->outputs.push_back(r);
}

pass->write(op, captured_params, captured_attrs);

new_ops.push_back(op);
}
else
{
// insert multiple
Graph replace_graph;
replace_graph.parse(pass->replace_pattern_graph());

// move operators and operands from replace_graph to graph except input and output
std::map<std::string, Operator*> ops;
for (size_t i = 0; i < replace_graph.ops.size(); i++)
{
Operator* op = replace_graph.ops[i];
if (op->type == "pnnx.Input" || op->type == "pnnx.Output")
continue;

graph.ops.insert(std::find(graph.ops.begin(), graph.ops.end(), cur), op);
replace_graph.ops[i] = 0;
ops[op->name] = op;
}

for (size_t i = 0; i < replace_graph.operands.size(); i++)
{
Operand* r = replace_graph.operands[i];
if (r->producer->type == "pnnx.Input" || (r->consumers.size() == 1 && r->consumers[0]->type == "pnnx.Output"))
continue;

graph.operands.push_back(r);
replace_graph.operands[i] = 0;
}

replace_graph.ops.erase(std::remove(replace_graph.ops.begin(), replace_graph.ops.end(), (Operator*)0), replace_graph.ops.end());
replace_graph.operands.erase(std::remove(replace_graph.operands.begin(), replace_graph.operands.end(), (Operand*)0), replace_graph.operands.end());

pass->write(op, captured_params, captured_attrs);
for (size_t i = 0; i < pattern_graph_inputs.size(); i++)
{
const std::string& k = pattern_graph_inputs[i];
Operand* r = (Operand*)matched_inputs.at(k);
const Operand* rr = replace_graph.get_operand(k);

for (auto x : rr->consumers)
{
r->consumers.push_back(x);

new_ops.push_back(op);
x->inputnames.resize(x->inputs.size());
for (size_t j = 0; j < x->inputs.size(); j++)
{
if (x->inputs[j]->name == k)
{
x->inputs[j] = r;
x->inputnames[j] = k;
break;
}
}
}
}

for (size_t i = 0; i < pattern_graph_outputs.size(); i++)
{
const std::string& k = pattern_graph_outputs[i];
Operand* r = (Operand*)matched_outputs.at(k);
const Operand* rr = replace_graph.get_operand(k);

r->producer = rr->producer;

for (size_t j = 0; j < r->producer->outputs.size(); j++)
{
if (r->producer->outputs[j]->name == k)
{
r->producer->outputs[j] = r;
break;
}
}
}

pass->write(ops, captured_params, captured_attrs);

for (auto x : ops)
{
new_ops.push_back(x.second);
}
}
}

// assign new op name number


+ 7
- 1
tools/pnnx/src/pass_level2.h View File

@@ -26,7 +26,9 @@ public:

virtual const char* match_pattern_graph() const = 0;

virtual const char* type_str() const = 0;
virtual const char* replace_pattern_graph() const;

virtual const char* type_str() const;

virtual const char* name_str() const;

@@ -39,6 +41,10 @@ public:
virtual void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const;

virtual void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const;

virtual void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params) const;

virtual void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const;
};

class GraphRewriterPassRegister


+ 68
- 0
tools/pnnx/src/pass_level2/torch_repeat_interleave.cpp View File

@@ -0,0 +1,68 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

#include <torch/csrc/api/include/torch/torch.h>

namespace pnnx {

// Rewriter pass for the four-input form of aten::repeat_interleave
// (input, repeats, dim, output_size), where output_size is produced by a
// prim::Constant. The matched subgraph becomes one operator of the type
// returned by type_str().
class torch_repeat_interleave : public GraphRewriterPass
{
public:
// Pattern graph in PNNX IR text form: 6 operators, 5 operands.
// NOTE(review): "value=*" presumably accepts any constant value - confirm
// against the pattern parameter matcher.
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 repeats
pnnx.Input input_2 0 1 dim
prim::Constant op_0 0 1 output_size value=*
aten::repeat_interleave op_1 4 1 input repeats dim output_size out
pnnx.Output output 1 0 out
)PNNXIR";
}

// Type name given to the operator that replaces the matched subgraph.
const char* type_str() const
{
return "torch.repeat_interleave";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_repeat_interleave, 20)

// Rewriter pass for the three-input form of aten::repeat_interleave
// (input, repeats, dim), i.e. without an output_size operand. The matched
// subgraph becomes one operator of the type returned by type_str().
class torch_repeat_interleave_1 : public GraphRewriterPass
{
public:
// Pattern graph in PNNX IR text form: 5 operators, 4 operands.
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 repeats
pnnx.Input input_2 0 1 dim
aten::repeat_interleave op_0 3 1 input repeats dim out
pnnx.Output output 1 0 out
)PNNXIR";
}

// Type name given to the operator that replaces the matched subgraph.
const char* type_str() const
{
return "torch.repeat_interleave";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_repeat_interleave_1, 20)

} // namespace pnnx

+ 4
- 0
tools/pnnx/src/pass_level5.cpp View File

@@ -35,10 +35,12 @@
#include "pass_level5/fuse_convtranspose1d_batchnorm1d.h"
#include "pass_level5/fuse_convtranspose2d_batchnorm2d.h"
#include "pass_level5/fuse_contiguous_view.h"
#include "pass_level5/fuse_layernorm.h"
#include "pass_level5/fuse_linear_batchnorm1d.h"
#include "pass_level5/fuse_multiheadattention.h"
#include "pass_level5/fuse_pad_conv1d.h"
#include "pass_level5/fuse_pad_conv2d.h"
#include "pass_level5/fuse_scaled_dot_product_attention.h"
#include "pass_level5/fuse_select_to_unbind.h"
#include "pass_level5/fuse_slice_copy.h"
#include "pass_level5/fuse_slice_indices.h"
@@ -124,7 +126,9 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons
eliminate_reshape_shape_expression(g);

fuse_channel_shuffle(g);
fuse_layernorm(g);
fuse_multiheadattention(g);
fuse_scaled_dot_product_attention(g);

fuse_index_expression(g);



+ 2
- 2
tools/pnnx/src/pass_level5/fold_constants.cpp View File

@@ -43,9 +43,9 @@ void fold_constants(Graph& graph, const std::set<std::string>& foldable_constant
// replace producer with attribute
Operator* op_new = graph.new_operator_before("pnnx.Attribute", std::string("pnnx_fold_") + name, op);

op_new->attrs[std::string("pnnx_fold_") + name] = Attribute();
op_new->attrs["data"] = Attribute();

Attribute& t2 = op_new->attrs[std::string("pnnx_fold_") + name];
Attribute& t2 = op_new->attrs["data"];
t2.type = operand->type;
t2.shape = operand->shape;
size_t size = zip.get_file_size(name);


+ 5
- 48
tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp View File

@@ -26,9 +26,9 @@ public:
return R"PNNXIR(7767517
7 6
pnnx.Input input 0 1 input
pnnx.Expression op_0 1 1 input 13 expr=%expr
pnnx.Expression op_0 1 1 input 13 expr=[int(size(@0,0)),2,int(floor_divide(size(@0,1),%groups)),int(size(@0,2)),int(size(@0,3))]
Tensor.view op_1 2 1 input 13 14
pnnx.Expression op_2 1 1 input 15 expr=%expr2
pnnx.Expression op_2 1 1 input 15 expr=[int(size(@0,0)),-1,int(size(@0,2)),int(size(@0,3))]
torch.transpose op_3 1 1 14 16 dim0=1 dim1=2
Tensor.reshape op_4 2 1 16 15 out
pnnx.Output output 1 0 out
@@ -44,32 +44,6 @@ pnnx.Output output 1 0 out
{
return "channelshuffle";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
const std::string& expr = captured_params.at("expr").s;
const std::string& expr2 = captured_params.at("expr2").s;

if (expr2 != "[int(size(@0,0)),-1,int(size(@0,2)),int(size(@0,3))]")
return false;

int groups;
int nscan = sscanf(expr.c_str(), "[int(size(@0,0)),2,int(floor_divide(size(@0,1),%d)),int(size(@0,2)),int(size(@0,3))]", &groups);
if (nscan != 1)
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const std::string& expr = captured_params.at("expr").s;

int groups;
sscanf(expr.c_str(), "[int(size(@0,0)),2,int(floor_divide(size(@0,1),%d)),int(size(@0,2)),int(size(@0,3))]", &groups);

op->params["groups"] = groups;
}
};

class fuse_channel_shuffle_pass_1 : public GraphRewriterPass
@@ -80,9 +54,9 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
Tensor.view op_0 1 1 input 13 shape=%shape
Tensor.view op_0 1 1 input 13 shape=(%batch,%groups,%channels_per_group,%h,%w)
torch.transpose op_1 1 1 13 14 dim0=1 dim1=2
Tensor.reshape op_2 1 1 14 out shape=%shape2
Tensor.reshape op_2 1 1 14 out shape=(%batch,-1,%h,%w)
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -97,26 +71,9 @@ pnnx.Output output 1 0 out
return "channelshuffle";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
// (1,2,58,28,28)
// (1,-1,28,28)
const std::vector<int>& shape = captured_params.at("shape").ai;
const std::vector<int>& shape2 = captured_params.at("shape2").ai;

if (shape[0] != 1 || shape2[0] != 1 || shape2[1] != -1 || shape2[2] != shape[3] || shape2[3] != shape[4])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const std::vector<int>& shape = captured_params.at("shape").ai;

int groups = shape[1];

op->params["groups"] = groups;
op->params["groups"] = captured_params.at("groups");
}
};



+ 87
- 0
tools/pnnx/src/pass_level5/fuse_layernorm.cpp View File

@@ -0,0 +1,87 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_layernorm.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

#include <torch/csrc/api/include/torch/torch.h>

namespace pnnx {

// Rewriter pass that recognises a hand-written normalisation over dim 1 of a
// (1,c,?,?) f32 input - mean, mean of squared difference, then
// weight * (x - mean) / sqrt(var + eps) + bias with (c,1,1) weight and bias -
// and replaces it with permute -> nn.LayerNorm -> permute.
class fuse_layernorm_pass : public GraphRewriterPass
{
public:
// Pattern to match. %c ties the channel extent of the input to the first
// extent of weight and bias; %eps captures the epsilon literal in the
// final expression.
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
8 7
pnnx.Input input 0 1 input #input=(1,%c,?,?)f32
pnnx.Attribute op_0 0 1 weight @data #weight=(%c,1,1)f32
pnnx.Attribute op_1 0 1 bias @data #bias=(%c,1,1)f32
torch.mean op_2 1 1 input mean dim=(1) keepdim=True
pnnx.Expression op_3 2 1 input mean 2 expr=pow(sub(@0,@1),2)
torch.mean op_4 1 1 2 var dim=(1) keepdim=True
pnnx.Expression op_5 5 1 weight input mean var bias out expr=add(mul(@0,div(sub(@1,@2),sqrt(add(@3,%eps)))),@4)
pnnx.Output output 1 0 out
)PNNXIR";
}

// Replacement graph: permute to (0,2,3,1), apply nn.LayerNorm with the
// captured eps, normalized_shape=(%c) and the captured weight/bias data,
// then permute back with (0,3,1,2).
// The two branches differ only in the permute operator type:
// torch.permute for torch >= 1.9, Tensor.permute otherwise.
const char* replace_pattern_graph() const
{
#if TORCH_VERSION_MAJOR >= 2 || TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 9
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
torch.permute op_0 1 1 input a dims=(0,2,3,1)
nn.LayerNorm op_1 1 1 a b elementwise_affine=True eps=%eps normalized_shape=(%c) @weight=%op_0.data @bias=%op_1.data
torch.permute op_2 1 1 b out dims=(0,3,1,2)
pnnx.Output output 1 0 out
)PNNXIR";
#else
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
Tensor.permute op_0 1 1 input a dims=(0,2,3,1)
nn.LayerNorm op_1 1 1 a b elementwise_affine=True eps=%eps normalized_shape=(%c) @weight=%op_0.data @bias=%op_1.data
Tensor.permute op_2 1 1 b out dims=(0,3,1,2)
pnnx.Output output 1 0 out
)PNNXIR";
#endif
}

// Runs the base-class write for all replacement operators, then overrides
// the shape of the LayerNorm (op_1) weight and bias attributes.
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
GraphRewriterPass::write(ops, captured_params, captured_attrs);

// fix weight bias shape from (c,1,1) to (c)
const int c = captured_params.at("c").i;
Operator* op_1 = ops.at("op_1");
op_1->attrs["weight"].shape = {c};
op_1->attrs["bias"].shape = {c};
}
};

// Applies fuse_layernorm_pass to the given graph through pnnx_graph_rewrite.
void fuse_layernorm(Graph& graph)
{
    // op index handed to the rewriter, starting at 0
    int op_index = 0;
    fuse_layernorm_pass layernorm_pass;

    pnnx_graph_rewrite(graph, &layernorm_pass, op_index);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_layernorm.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

// Graph pass entry point: rewrites matching layer-normalization subgraphs in
// `graph` in place (the graph is taken by non-const reference).
void fuse_layernorm(Graph& graph);

} // namespace pnnx

+ 584
- 806
tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
File diff suppressed because it is too large
View File


+ 92
- 0
tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp View File

@@ -0,0 +1,92 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_scaled_dot_product_attention.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

#include <torch/csrc/api/include/torch/torch.h>

namespace pnnx {

// Approximate floating-point comparison.
//
// Returns true when a and b are exactly equal, when their absolute difference
// is at most epsilon, or when that difference is smaller than epsilon scaled
// by the larger of |a| and |b| (relative error). Any comparison involving NaN
// yields false.
//
// Only <math.h> functions are used (fabsf / fmaxf), so the helper does not
// depend on <algorithm> being pulled in by some other header, and all
// arithmetic stays in float.
static bool NearlyEqual(float a, float b, float epsilon)
{
    // exact match; also covers equal infinities, for which a - b is NaN
    if (a == b)
        return true;

    const float diff = fabsf(a - b);
    if (diff <= epsilon)
        return true;

    // relative error
    return diff < epsilon * fmaxf(fabsf(a), fabsf(b));
}

// Graph rewriter that collapses an explicit attention computation
//   softmax(matmul(query, permute(key)) / divisor, dim=-1) matmul value
// into a single F.scaled_dot_product_attention operator.
class fuse_scaled_dot_product_attention_pass : public GraphRewriterPass
{
public:
    // Pattern to look for. query/key/value are 4-D f32 tensors sharing
    // %batch, %num_heads and %feat_per_head; the divisor of the expression
    // is captured as %sqrt_embed_dim_per_head.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
9 8
pnnx.Input input_0 0 1 query #query=(%batch,%num_heads,%qsize,%feat_per_head)f32
pnnx.Input input_1 0 1 key #key=(%batch,%num_heads,%kvsize,%feat_per_head)f32
pnnx.Input input_2 0 1 value #value=(%batch,%num_heads,%kvsize,%feat_per_head)f32
torch.permute op_0 1 1 key 59 dims=(0,1,3,2)
torch.matmul op_1 2 1 query 59 61
pnnx.Expression op_2 1 1 61 62 expr=div(@0,%sqrt_embed_dim_per_head)
F.softmax op_3 1 1 62 63 dim=-1
torch.matmul op_4 2 1 63 value out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement graph: one fused attention op without mask, dropout or
    // causal masking.
    const char* replace_pattern_graph() const
    {
        return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
F.scaled_dot_product_attention op_0 3 1 query key value out attn_mask=None dropout_p=0.0 is_causal=False
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Accepts the match only when the captured divisor is, within a 0.001
    // tolerance, the square root of the captured per-head feature size.
    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const int head_dim = captured_params.at("feat_per_head").i;
        const float divisor = captured_params.at("sqrt_embed_dim_per_head").f;

        return NearlyEqual(divisor, sqrt(head_dim), 0.001);
    }
};

// Applies fuse_scaled_dot_product_attention_pass to the given graph.
// The rewrite is compiled in only when TORCH_VERSION_MAJOR >= 2; otherwise
// this function leaves the graph untouched.
void fuse_scaled_dot_product_attention(Graph& graph)
{
#if TORCH_VERSION_MAJOR >= 2
    // op index handed to the rewriter, starting at 0
    int op_index = 0;
    fuse_scaled_dot_product_attention_pass sdpa_pass;

    pnnx_graph_rewrite(graph, &sdpa_pass, op_index);
#endif
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

// Graph pass entry point: rewrites matching attention subgraphs in `graph`
// in place (the graph is taken by non-const reference).
void fuse_scaled_dot_product_attention(Graph& graph);

} // namespace pnnx

+ 64
- 232
tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp View File

@@ -29,21 +29,21 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32
pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32
F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.BatchNorm1d";
}
const char* name_str() const
{
return "batchnorm";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.BatchNorm1d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=False @running_mean=%op_mean.data @running_var=%op_var.data
pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
@@ -51,26 +51,6 @@ pnnx.Output output 1 0 out
size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 2 || input_rank == 3;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute running_mean;
Attribute running_var;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = false;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
}
};

class fuse_static_Fbatchnorm_pass_1d_1 : public GraphRewriterPass
@@ -81,23 +61,23 @@ public:
return R"PNNXIR(7767517
7 6
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32
pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32
pnnx.Attribute op_weight 0 1 weight @data
pnnx.Attribute op_bias 0 1 bias @data
F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.BatchNorm1d";
}
const char* name_str() const
{
return "batchnorm";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.BatchNorm1d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=True @running_mean=%op_mean.data @running_var=%op_var.data @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
@@ -105,34 +85,6 @@ pnnx.Output output 1 0 out
size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 2 || input_rank == 3;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute running_mean;
Attribute running_var;
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Fbatchnorm_pass_2d : public GraphRewriterPass
@@ -143,47 +95,21 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps
pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32
pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32
F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps #input=(?,?,?,?)f32
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.BatchNorm2d";
}

const char* name_str() const
{
return "batchnorm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 4;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
const char* replace_pattern_graph() const
{
Attribute running_mean;
Attribute running_var;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = false;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.BatchNorm2d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=False @running_mean=%op_mean.data @running_var=%op_var.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};

@@ -195,57 +121,23 @@ public:
return R"PNNXIR(7767517
7 6
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps
pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32
pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32
pnnx.Attribute op_weight 0 1 weight @data
pnnx.Attribute op_bias 0 1 bias @data
F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps #input=(?,?,?,?)f32
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.BatchNorm2d";
}

const char* name_str() const
{
return "batchnorm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 4;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
const char* replace_pattern_graph() const
{
Attribute running_mean;
Attribute running_var;
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.BatchNorm2d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=True @running_mean=%op_mean.data @running_var=%op_var.data @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};

@@ -257,47 +149,21 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps
pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32
pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32
F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps #input=(?,?,?,?,?)f32
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.BatchNorm3d";
}

const char* name_str() const
{
return "batchnorm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 5;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
const char* replace_pattern_graph() const
{
Attribute running_mean;
Attribute running_var;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = false;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.BatchNorm3d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=False @running_mean=%op_mean.data @running_var=%op_var.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};

@@ -309,57 +175,23 @@ public:
return R"PNNXIR(7767517
7 6
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps
pnnx.Attribute op_mean 0 1 running_mean @data=(%num_features)f32
pnnx.Attribute op_var 0 1 running_var @data=(%num_features)f32
pnnx.Attribute op_weight 0 1 weight @data
pnnx.Attribute op_bias 0 1 bias @data
F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps #input=(?,?,?,?,?)f32
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.BatchNorm3d";
}

const char* name_str() const
{
return "batchnorm";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 5;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
const char* replace_pattern_graph() const
{
Attribute running_mean;
Attribute running_var;
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = running_mean.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;

op->attrs["running_mean"] = running_mean;
op->attrs["running_var"] = running_var;
op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.BatchNorm3d batchnorm 1 1 input out num_features=%num_features eps=%eps affine=True @running_mean=%op_mean.data @running_var=%op_var.data @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};



+ 132
- 333
tools/pnnx/src/pass_level5/fuse_static_conv.cpp View File

@@ -29,42 +29,30 @@ public:
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kw)f32
F.conv1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Conv1d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv1d conv1d 1 1 input out out_channels=%out_channels kernel_size=(%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* name_str() const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
return "conv1d";
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = false;

op->attrs["weight"] = weight;
const int in_channels_per_group = captured_params.at("in_channels_per_group").i;
const int groups = captured_params.at("groups").i;

ops.at("conv1d")->params["in_channels"] = in_channels_per_group * groups;
}
};

@@ -76,47 +64,31 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kw)f32
pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32
F.conv1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Conv1d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv1d conv1d 1 1 input out out_channels=%out_channels kernel_size=(%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* name_str() const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
return "conv1d";
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
const int in_channels_per_group = captured_params.at("in_channels_per_group").i;
const int groups = captured_params.at("groups").i;

ops.at("conv1d")->params["in_channels"] = in_channels_per_group * groups;
}
};

@@ -128,71 +100,32 @@ public:
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kw)f32
pnnx.Attribute op_bias 0 1 bias @data=(1,%out_channels,1)f32
F.conv1d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Expression op_1 2 1 a bias out expr=%expr
pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1)
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Conv1d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv1d conv1d 1 1 input out out_channels=%out_channels kernel_size=(%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* name_str() const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
return "conv1d";
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::string& expr = captured_params.at("expr").s;
if (expr != "add(@0,@1)")
return false;

Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

int out_channels = weight.shape[0];
if (bias.shape != std::vector<int>{1, out_channels, 1})
return false;

return true;
}
const int in_channels_per_group = captured_params.at("in_channels_per_group").i;
const int groups = captured_params.at("groups").i;

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
ops.at("conv1d")->params["in_channels"] = in_channels_per_group * groups;
}
};

@@ -204,42 +137,30 @@ public:
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kh,%kw)f32
F.conv2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Conv2d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv2d conv2d 1 1 input out out_channels=%out_channels kernel_size=(%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* name_str() const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
return "conv2d";
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = false;

op->attrs["weight"] = weight;
const int in_channels_per_group = captured_params.at("in_channels_per_group").i;
const int groups = captured_params.at("groups").i;

ops.at("conv2d")->params["in_channels"] = in_channels_per_group * groups;
}
};

@@ -251,47 +172,31 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kh,%kw)f32
pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32
F.conv2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Conv2d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv2d conv2d 1 1 input out out_channels=%out_channels kernel_size=(%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* name_str() const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
return "conv2d";
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
const int in_channels_per_group = captured_params.at("in_channels_per_group").i;
const int groups = captured_params.at("groups").i;

ops.at("conv2d")->params["in_channels"] = in_channels_per_group * groups;
}
};

@@ -303,71 +208,32 @@ public:
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kh,%kw)f32
pnnx.Attribute op_bias 0 1 bias @data=(1,%out_channels,1,1)f32
F.conv2d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Expression op_1 2 1 a bias out expr=%expr
pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1)
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Conv2d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv2d conv2d 1 1 input out out_channels=%out_channels kernel_size=(%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* name_str() const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
return "conv2d";
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::string& expr = captured_params.at("expr").s;
if (expr != "add(@0,@1)")
return false;

Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

int out_channels = weight.shape[0];
if (bias.shape != std::vector<int>{1, out_channels, 1, 1})
return false;

return true;
}
const int in_channels_per_group = captured_params.at("in_channels_per_group").i;
const int groups = captured_params.at("groups").i;

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
ops.at("conv2d")->params["in_channels"] = in_channels_per_group * groups;
}
};

@@ -379,42 +245,30 @@ public:
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kd,%kh,%kw)f32
F.conv3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Conv3d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv3d conv3d 1 1 input out out_channels=%out_channels kernel_size=(%kd,%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* name_str() const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
return "conv3d";
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3], weight.shape[4]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = false;

op->attrs["weight"] = weight;
const int in_channels_per_group = captured_params.at("in_channels_per_group").i;
const int groups = captured_params.at("groups").i;

ops.at("conv3d")->params["in_channels"] = in_channels_per_group * groups;
}
};

@@ -426,47 +280,31 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kd,%kh,%kw)f32
pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32
F.conv3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Conv3d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv3d conv3d 1 1 input out out_channels=%out_channels kernel_size=(%kd,%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* name_str() const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
return "conv3d";
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3], weight.shape[4]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
const int in_channels_per_group = captured_params.at("in_channels_per_group").i;
const int groups = captured_params.at("groups").i;

ops.at("conv3d")->params["in_channels"] = in_channels_per_group * groups;
}
};

@@ -478,71 +316,32 @@ public:
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_channels,%in_channels_per_group,%kd,%kh,%kw)f32
pnnx.Attribute op_bias 0 1 bias @data=(1,%out_channels,1,1,1)f32
F.conv3d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Expression op_1 2 1 a bias out expr=%expr
pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1)
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Conv3d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv3d conv3d 1 1 input out out_channels=%out_channels kernel_size=(%kd,%kh,%kw) padding_mode=zeros stride=%stride padding=%padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* name_str() const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
return "conv3d";
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

// Extra acceptance check on a candidate match.
// Returns true only when the captured "expr" parameter is exactly
// "add(@0,@1)" and the captured bias attribute has shape
// {1, out_channels, 1, 1, 1}, where out_channels is the first dimension of
// the captured weight attribute. Returns false otherwise.
bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
// std::map::at throws std::out_of_range if "expr" was not captured
const std::string& expr = captured_params.at("expr").s;
if (expr != "add(@0,@1)")
return false;

// select the captured attributes by key prefix; if several keys share a
// prefix, the last one visited in map order wins
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

// NOTE(review): shape[0] is read without checking that a weight was found or
// that its shape is non-empty - presumably guaranteed by the match pattern; confirm
int out_channels = weight.shape[0];
if (bias.shape != std::vector<int>{1, out_channels, 1, 1, 1})
return false;

return true;
}
const int in_channels_per_group = captured_params.at("in_channels_per_group").i;
const int groups = captured_params.at("groups").i;

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i;
op->params["out_channels"] = weight.shape[0];
op->params["kernel_size"] = std::vector<int>{weight.shape[2], weight.shape[3], weight.shape[4]};
op->params["padding_mode"] = std::string("zeros");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
ops.at("conv3d")->params["in_channels"] = in_channels_per_group * groups;
}
};



+ 84
- 174
tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp View File

@@ -29,44 +29,30 @@ public:
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kw)f32
F.conv_transpose1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.ConvTranspose1d";
}
const char* name_str() const
{
return "conv_transpose1d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.ConvTranspose1d conv_transpose1d 1 1 input out in_channels=%in_channels kernel_size=(%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

const int out_channels_per_group = captured_params.at("out_channels_per_group").i;
const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = false;

op->attrs["weight"] = weight;
ops.at("conv_transpose1d")->params["out_channels"] = out_channels_per_group * groups;
}
};

@@ -78,49 +64,33 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kw)f32
pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32
F.conv_transpose1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

// Returns the constant string "nn.ConvTranspose1d".
// Presumably the operator type given to the fused node - confirm against the base pass class.
const char* type_str() const
{
return "nn.ConvTranspose1d";
}

const char* name_str() const
const char* replace_pattern_graph() const
{
return "conv_transpose1d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.ConvTranspose1d conv_transpose1d 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=(%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

const int out_channels_per_group = captured_params.at("out_channels_per_group").i;
const int out_channels = captured_params.at("out_channels").i;
const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
if (out_channels != out_channels_per_group * groups)
return false;

return true;
}
};

@@ -132,44 +102,30 @@ public:
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kh,%kw)f32
F.conv_transpose2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

// Returns the constant string "nn.ConvTranspose2d".
// Presumably the operator type given to the fused node - confirm against the base pass class.
const char* type_str() const
{
return "nn.ConvTranspose2d";
}

const char* name_str() const
const char* replace_pattern_graph() const
{
return "conv_transpose2d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.ConvTranspose2d conv_transpose2d 1 1 input out in_channels=%in_channels kernel_size=(%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

const int out_channels_per_group = captured_params.at("out_channels_per_group").i;
const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = false;

op->attrs["weight"] = weight;
ops.at("conv_transpose2d")->params["out_channels"] = out_channels_per_group * groups;
}
};

@@ -181,49 +137,33 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kh,%kw)f32
pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32
F.conv_transpose2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

// Returns the constant string "nn.ConvTranspose2d".
// Presumably the operator type given to the fused node - confirm against the base pass class.
const char* type_str() const
{
return "nn.ConvTranspose2d";
}

const char* name_str() const
const char* replace_pattern_graph() const
{
return "conv_transpose2d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.ConvTranspose2d conv_transpose2d 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=(%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

const int out_channels_per_group = captured_params.at("out_channels_per_group").i;
const int out_channels = captured_params.at("out_channels").i;
const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
if (out_channels != out_channels_per_group * groups)
return false;

return true;
}
};

@@ -235,44 +175,30 @@ public:
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kd,%kh,%kw)f32
F.conv_transpose3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

// Returns the constant string "nn.ConvTranspose3d".
// Presumably the operator type given to the fused node - confirm against the base pass class.
const char* type_str() const
{
return "nn.ConvTranspose3d";
}

const char* name_str() const
const char* replace_pattern_graph() const
{
return "conv_transpose3d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.ConvTranspose3d conv_transpose3d 1 1 input out in_channels=%in_channels kernel_size=(%kd,%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=False @weight=%op_weight.data
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

const int out_channels_per_group = captured_params.at("out_channels_per_group").i;
const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3], weight.shape[4]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = false;

op->attrs["weight"] = weight;
ops.at("conv_transpose3d")->params["out_channels"] = out_channels_per_group * groups;
}
};

@@ -284,49 +210,33 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%in_channels,%out_channels_per_group,%kd,%kh,%kw)f32
pnnx.Attribute op_bias 0 1 bias @data=(%out_channels)f32
F.conv_transpose3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups
pnnx.Output output 1 0 out
)PNNXIR";
}

// Returns the constant string "nn.ConvTranspose3d".
// Presumably the operator type given to the fused node - confirm against the base pass class.
const char* type_str() const
{
return "nn.ConvTranspose3d";
}

const char* name_str() const
const char* replace_pattern_graph() const
{
return "conv_transpose3d";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.ConvTranspose3d conv_transpose3d 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=(%kd,%kh,%kw) stride=%stride padding=%padding output_padding=%output_padding dilation=%dilation groups=%groups bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
bool match(const std::map<std::string, Parameter>& captured_params) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

const int out_channels_per_group = captured_params.at("out_channels_per_group").i;
const int out_channels = captured_params.at("out_channels").i;
const int groups = captured_params.at("groups").i;

op->params["groups"] = groups;
op->params["in_channels"] = weight.shape[0];
op->params["out_channels"] = weight.shape[1] * groups;
op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3], weight.shape[4]};
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
if (out_channels != out_channels_per_group * groups)
return false;

return true;
}
};



+ 9
- 30
tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp View File

@@ -29,42 +29,21 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%num_channels)f32
pnnx.Attribute op_bias 0 1 bias @data=(%num_channels)f32
F.group_norm op_0 3 1 input weight bias out num_groups=%num_groups eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.GroupNorm";
}

// Returns the constant string "group_norm".
// Presumably the name given to the fused node - confirm against the base pass class.
const char* name_str() const
{
return "group_norm";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_channels"] = weight.shape[0];
op->params["num_groups"] = captured_params.at("num_groups");
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.GroupNorm group_norm 1 1 input out num_channels=%num_channels num_groups=%num_groups eps=%eps affine=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};



+ 29
- 104
tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp View File

@@ -29,21 +29,21 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%num_features)f32
pnnx.Attribute op_bias 0 1 bias @data=(%num_features)f32
F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.InstanceNorm1d";
}
const char* name_str() const
{
return "instance_norm";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.InstanceNorm1d instance_norm 1 1 input out num_features=%num_features eps=%eps affine=True track_running_stats=False @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
@@ -51,27 +51,6 @@ pnnx.Output output 1 0 out
size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 2 || input_rank == 3;
}

// Populates op from the captured values.
// Params written: num_features (first dimension of the captured weight),
// eps (copied from captured_params), affine = true,
// track_running_stats = false.
// Attrs written: weight and bias, taken from the captured attributes whose
// keys start with "op_weight." and "op_bias." respectively.
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
// select the captured attributes by key prefix; if several keys share a
// prefix, the last one visited in map order wins
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

// NOTE(review): shape[0] is read without checking that a weight was found or
// that its shape is non-empty - presumably guaranteed by the match pattern; confirm
op->params["num_features"] = weight.shape[0];
// std::map::at throws std::out_of_range if "eps" was not captured
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;
op->params["track_running_stats"] = false;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
}
};

class fuse_static_Finstancenorm_pass_2d : public GraphRewriterPass
@@ -82,48 +61,21 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps
pnnx.Attribute op_weight 0 1 weight @data=(%num_features)f32
pnnx.Attribute op_bias 0 1 bias @data=(%num_features)f32
F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps #input=(?,?,?,?)f32
pnnx.Output output 1 0 out
)PNNXIR";
}

// Returns the constant string "nn.InstanceNorm1d".
// NOTE(review): the match() that follows in this hunk accepts only rank-4
// input and the replacement graph in the same hunk names nn.InstanceNorm2d,
// so "1d" here looks inconsistent - confirm.
const char* type_str() const
{
return "nn.InstanceNorm1d";
}

// Returns the constant string "instance_norm".
// Presumably the name given to the fused node - confirm against the base pass class.
const char* name_str() const
{
return "instance_norm";
}

// Accepts the candidate only when the first input of the operator matched
// under the key "op_0" has a shape with exactly 4 dimensions.
bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
// std::map::at throws std::out_of_range if "op_0" is not among the matched operators
size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 4;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
const char* replace_pattern_graph() const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = weight.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;
op->params["track_running_stats"] = false;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.InstanceNorm2d instance_norm 1 1 input out num_features=%num_features eps=%eps affine=True track_running_stats=False @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};

@@ -135,48 +87,21 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps
pnnx.Attribute op_weight 0 1 weight @data=(%num_features)f32
pnnx.Attribute op_bias 0 1 bias @data=(%num_features)f32
F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps #input=(?,?,?,?,?)f32
pnnx.Output output 1 0 out
)PNNXIR";
}

// Returns the constant string "nn.InstanceNorm1d".
// NOTE(review): the match() that follows in this hunk accepts only rank-5
// input and the replacement graph in the same hunk names nn.InstanceNorm3d,
// so "1d" here looks inconsistent - confirm.
const char* type_str() const
{
return "nn.InstanceNorm1d";
}

// Returns the constant string "instance_norm".
// Presumably the name given to the fused node - confirm against the base pass class.
const char* name_str() const
{
return "instance_norm";
}

// Accepts the candidate only when the first input of the operator matched
// under the key "op_0" has a shape with exactly 5 dimensions.
bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
// std::map::at throws std::out_of_range if "op_0" is not among the matched operators
size_t input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
return input_rank == 5;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
const char* replace_pattern_graph() const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["num_features"] = weight.shape[0];
op->params["eps"] = captured_params.at("eps");
op->params["affine"] = true;
op->params["track_running_stats"] = false;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.InstanceNorm3d instance_norm 1 1 input out num_features=%num_features eps=%eps affine=True track_running_stats=False @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};



+ 9
- 29
tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp View File

@@ -29,41 +29,21 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data
pnnx.Attribute op_bias 0 1 bias @data
F.layer_norm op_0 3 1 input weight bias out normalized_shape=%normalized_shape eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.LayerNorm";
}

// Returns the constant string "layer_norm".
// Presumably the name given to the fused node - confirm against the base pass class.
const char* name_str() const
{
return "layer_norm";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["normalized_shape"] = captured_params.at("normalized_shape");
op->params["eps"] = captured_params.at("eps");
op->params["elementwise_affine"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.LayerNorm layer_norm 1 1 input out normalized_shape=%normalized_shape eps=%eps elementwise_affine=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};



+ 33
- 103
tools/pnnx/src/pass_level5/fuse_static_linear.cpp View File

@@ -29,36 +29,20 @@ public:
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_features,%in_features)f32
F.linear op_0 2 1 input weight out bias=None
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Linear";
}

// Returns the constant string "linear".
// Presumably the name given to the fused node - confirm against the base pass class.
const char* name_str() const
{
return "linear";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["in_features"] = weight.shape[1];
op->params["out_features"] = weight.shape[0];
op->params["bias"] = false;

op->attrs["weight"] = weight;
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Linear linear 1 1 input out in_features=%in_features out_features=%out_features bias=False @weight=%op_weight.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};

@@ -70,41 +54,21 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data=(%out_features,%in_features)f32
pnnx.Attribute op_bias 0 1 bias @data=(%out_features)f32
F.linear op_0 3 1 input weight bias out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
const char* replace_pattern_graph() const
{
return "nn.Linear";
}

// Returns the constant string "linear".
// Presumably the name given to the fused node - confirm against the base pass class.
const char* name_str() const
{
return "linear";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_features"] = weight.shape[1];
op->params["out_features"] = weight.shape[0];
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Linear linear 1 1 input out in_features=%in_features out_features=%out_features bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};

@@ -116,65 +80,31 @@ public:
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
F.linear op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups
pnnx.Expression op_1 2 1 a bias out expr=%expr
pnnx.Attribute op_weight 0 1 weight @data=(%out_features,%in_features)f32
pnnx.Attribute op_bias 0 1 bias @data=(1,%out_features,1)f32
F.linear op_0 2 1 input weight a
pnnx.Expression op_1 2 1 a bias out expr=add(@0,@1)
pnnx.Output output 1 0 out
)PNNXIR";
}

// Returns the constant string "nn.Linear".
// Presumably the operator type given to the fused node - confirm against the base pass class.
const char* type_str() const
{
return "nn.Linear";
}

const char* name_str() const
const char* replace_pattern_graph() const
{
return "linear";
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Linear linear 1 1 input out in_features=%in_features out_features=%out_features bias=True @weight=%op_weight.data @bias=%op_bias.data
pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::string& expr = captured_params.at("expr").s;
if (expr != "add(@0,@1)")
return false;

Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

int out_channels = weight.shape[0];
if (bias.shape != std::vector<int>{1, out_channels, 1})
return false;

return true;
}
GraphRewriterPass::write(ops, captured_params, captured_attrs);

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

op->params["in_features"] = weight.shape[1];
op->params["out_features"] = weight.shape[0];
op->params["bias"] = true;

op->attrs["weight"] = weight;
op->attrs["bias"] = bias;
// fix bias shape
const int out_features = captured_params.at("out_features").i;
ops.at("linear")->attrs.at("bias").shape = {out_features};
}
};



+ 12
- 30
tools/pnnx/src/pass_ncnn/F_batch_norm.cpp View File

@@ -26,8 +26,8 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
pnnx.Attribute op_mean 0 1 running_mean @data
pnnx.Attribute op_var 0 1 running_var @data
F.batch_norm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
@@ -45,15 +45,8 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute running_mean;
Attribute running_var;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
}
Attribute running_mean = captured_attrs.at("op_mean.data");
Attribute running_var = captured_attrs.at("op_var.data");

op->params["0"] = running_mean.shape[0];
op->params["1"] = captured_params.at("eps");
@@ -77,10 +70,10 @@ public:
return R"PNNXIR(7767517
7 6
pnnx.Input input 0 1 input
pnnx.Attribute op_mean 0 1 running_mean @qwq
pnnx.Attribute op_var 0 1 running_var @qwq
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_mean 0 1 running_mean @data
pnnx.Attribute op_var 0 1 running_var @data
pnnx.Attribute op_weight 0 1 weight @data
pnnx.Attribute op_bias 0 1 bias @data
F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
@@ -98,21 +91,10 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute running_mean;
Attribute running_var;
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 8) == "op_mean.")
running_mean = x.second;
if (x.first.substr(0, 7) == "op_var.")
running_var = x.second;
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}
Attribute running_mean = captured_attrs.at("op_mean.data");
Attribute running_var = captured_attrs.at("op_var.data");
Attribute weight = captured_attrs.at("op_weight.data");
Attribute bias = captured_attrs.at("op_bias.data");

op->params["0"] = running_mean.shape[0];
op->params["1"] = captured_params.at("eps");


+ 2
- 7
tools/pnnx/src/pass_ncnn/F_embedding.cpp View File

@@ -26,7 +26,7 @@ public:
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_weight 0 1 weight @data
F.embedding op_0 2 1 input weight out scale_grad_by_freq=False sparse=False
pnnx.Output output 1 0 out
)PNNXIR";
@@ -44,12 +44,7 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}
Attribute weight = captured_attrs.at("op_weight.data");

op->params["0"] = weight.shape[1];
op->params["1"] = weight.shape[0];


+ 4
- 11
tools/pnnx/src/pass_ncnn/F_instance_norm.cpp View File

@@ -67,8 +67,8 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @data
pnnx.Attribute op_bias 0 1 bias @data
F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
@@ -86,15 +86,8 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}
Attribute weight = captured_attrs.at("op_weight.data");
Attribute bias = captured_attrs.at("op_bias.data");

op->params["0"] = weight.shape[0];
op->params["1"] = captured_params.at("eps");


+ 2
- 7
tools/pnnx/src/pass_ncnn/F_prelu.cpp View File

@@ -26,7 +26,7 @@ public:
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_weight 0 1 weight @data
F.prelu op_0 2 1 input weight out
pnnx.Output output 1 0 out
)PNNXIR";
@@ -44,12 +44,7 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}
Attribute weight = captured_attrs.at("op_weight.data");

op->params["0"] = weight.shape[0];



+ 6
- 20
tools/pnnx/src/pass_ncnn/torch_addmm.cpp View File

@@ -26,8 +26,8 @@ public:
return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 mat1
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @qwq
pnnx.Attribute op_bias 0 1 bias @data
pnnx.Attribute op_weight 0 1 weight @data
torch.addmm op_0 3 1 bias mat1 weight out alpha=%alpha beta=%beta
pnnx.Output output 1 0 out
)PNNXIR";
@@ -69,15 +69,8 @@ pnnx.Output output 1 0 out
if (alpha != 1.f || beta != 1.f)
return false;

Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}
Attribute weight = captured_attrs.at("op_weight.data");
Attribute bias = captured_attrs.at("op_bias.data");

if (weight.shape.size() != 2 || bias.shape.size() != 1)
return false;
@@ -90,15 +83,8 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}
Attribute weight = captured_attrs.at("op_weight.data");
Attribute bias = captured_attrs.at("op_bias.data");

// transpose weight inch-outch to outch-inch
const int inch = weight.shape[0];


+ 3
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -213,6 +213,7 @@ pnnx_add_test(torch_ones)
pnnx_add_test(torch_ones_like)
pnnx_add_test(torch_permute)
pnnx_add_test(torch_prod)
pnnx_add_test(torch_repeat_interleave)
pnnx_add_test(torch_scatter_add)
pnnx_add_test(torch_sum)
pnnx_add_test(torch_split)
@@ -298,8 +299,10 @@ pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d)
pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d)
pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d)
pnnx_add_test(pnnx_fuse_input_unpack)
pnnx_add_test(pnnx_fuse_layernorm)
pnnx_add_test(pnnx_fuse_linear_batchnorm1d)
pnnx_add_test(pnnx_fuse_multiheadattention)
pnnx_add_test(pnnx_fuse_scaled_dot_product_attention)
pnnx_add_test(pnnx_fuse_select_to_unbind)
pnnx_add_test(pnnx_fuse_slice_to_tensor_split)
pnnx_add_test(pnnx_fuse_adjacent_reshape)


+ 70
- 0
tools/pnnx/tests/test_pnnx_fuse_layernorm.py View File

@@ -0,0 +1,70 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm2d(nn.Module):
    """Hand-written layer normalisation over dimension 1 of the input.

    The statistics are computed with explicit ``mean``/``pow``/``sqrt`` calls
    instead of a library layer-norm call. NOTE(review): presumably this
    spelled-out op sequence is what the pnnx layernorm fusion exercised by
    this file is meant to recognise, so the statements should not be
    rearranged -- confirm against the fusion pass.

    Args:
        num_channels: length of the learnable ``weight`` and ``bias`` vectors.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Scale starts at one and shift at zero.
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along dim 1, then apply the per-channel scale and shift."""
        # Mean over dim 1, kept as a size-1 dim so it broadcasts against x.
        u = x.mean(1, keepdim=True)
        # Mean squared deviation from u over the same dim.
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        # [:, None, None] views the 1-D parameters as (C, 1, 1) so they
        # broadcast over the two trailing dims of x.
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x

class Model(nn.Module):
    """Minimal network that wraps a single 64-channel LayerNorm2d."""

    def __init__(self):
        super().__init__()
        self.ln_0 = LayerNorm2d(64)

    def forward(self, x):
        """Pass the input through the layer norm and hand back the result."""
        return self.ln_0(x)

def test():
    """Round-trip the model through TorchScript and pnnx and compare outputs.

    Returns:
        True when the pnnx-generated module reproduces the eager output
        within rtol=1e-4 / atol=1e-4; False when the converter reports a
        failure or the outputs differ.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 64, 16, 16)

    a0 = net(x)

    # export torchscript
    mod = torch.jit.trace(net, x)
    mod.save("test_pnnx_fuse_layernorm.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the converter failed; stop here instead of
    # importing a generated module that is missing or left over from an
    # earlier run.
    if os.system("../src/pnnx test_pnnx_fuse_layernorm.pt inputshape=[1,64,16,16]") != 0:
        return False

    # pnnx inference
    import test_pnnx_fuse_layernorm_pnnx
    b0 = test_pnnx_fuse_layernorm_pnnx.test_inference()

    return torch.allclose(a0, b0, 1e-4, 1e-4)

if __name__ == "__main__":
    # Report the result through the process exit status (0 = pass, 1 = fail).
    # SystemExit is raised directly because the exit() helper is injected by
    # the site module for interactive use and is absent when Python runs
    # without it.
    raise SystemExit(0 if test() else 1)

+ 127
- 0
tools/pnnx/tests/test_pnnx_fuse_scaled_dot_product_attention.py View File

@@ -0,0 +1,127 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

from einops import rearrange
from typing import Any, Optional, Tuple, Union
from torch import Tensor
import math

class sam_Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.

    The attention itself is written out as a matmul, a division by the square
    root of the per-head width, a softmax and a second matmul. NOTE(review):
    presumably this explicit form is the pattern the pnnx
    scaled-dot-product-attention fusion exercised by this file has to
    recognise, so the statements should stay as they are -- confirm against
    the fusion pass.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        # embedding_dim: feature size of the q/k/v inputs and of the output.
        # num_heads: number of heads the projected features are split into.
        # downsample_rate: integer divisor applied to embedding_dim to get the
        #     width of the q/k/v projections.
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        # The check is made on internal_dim (the projected width), even though
        # the message names embedding_dim.
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        # q/k/v are projected to internal_dim; out_proj maps back to embedding_dim.
        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Split the last dim of a 3-D tensor into ``num_heads`` heads and move the head dim to position 1."""
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Inverse of ``_separate_heads``: merge the head dim back into the last dim."""
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Project q/k/v, run multi-head attention and project the result back."""
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        # permute(0, 1, 3, 2) swaps the last two dims of k before the matmul.
        attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out

class Model(nn.Module):
    """Two sam_Attention layers fed the same q/k/v: one at full width, one downsampled by 2."""

    def __init__(self):
        super().__init__()
        self.attention_0 = sam_Attention(embedding_dim=64, num_heads=4, downsample_rate=1)
        self.attention_1 = sam_Attention(embedding_dim=64, num_heads=4, downsample_rate=2)

    def forward(self, q, k, v):
        """Return the outputs of both attention layers as a pair."""
        full_width = self.attention_0(q, k, v)
        half_width = self.attention_1(q, k, v)
        return full_width, half_width

def test():
    """Round-trip the model through TorchScript and pnnx and compare outputs.

    Returns:
        True when every output of the pnnx-generated module matches the
        eager output within rtol=1e-4 / atol=1e-4; False when the converter
        reports a failure or any output differs.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 24, 64)
    y = torch.rand(1, 24, 64)
    z = torch.rand(1, 24, 64)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_pnnx_fuse_scaled_dot_product_attention.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the converter failed; stop here instead of
    # importing a generated module that is missing or left over from an
    # earlier run.
    if os.system("../src/pnnx test_pnnx_fuse_scaled_dot_product_attention.pt inputshape=[1,24,64],[1,24,64],[1,24,64]") != 0:
        return False

    # pnnx inference
    import test_pnnx_fuse_scaled_dot_product_attention_pnnx
    b = test_pnnx_fuse_scaled_dot_product_attention_pnnx.test_inference()

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-4, 1e-4):
            return False
    return True

if __name__ == "__main__":
    # Report the result through the process exit status (0 = pass, 1 = fail).
    # SystemExit is raised directly because the exit() helper is injected by
    # the site module for interactive use and is absent when Python runs
    # without it.
    raise SystemExit(0 if test() else 1)

+ 65
- 0
tools/pnnx/tests/test_torch_repeat_interleave.py View File

@@ -0,0 +1,65 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

class Model(nn.Module):
    """Exercises torch.repeat_interleave with a scalar repeat, a dim, and a tensor of repeats."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z):
        """Return the three repeated tensors as a tuple ``(x, y, z)``."""
        # output_size is only passed when the installed torch is 1.10 or newer.
        pass_output_size = version.parse(torch.__version__) >= version.parse('1.10')

        x = torch.repeat_interleave(x, 2)
        y = torch.repeat_interleave(y, 3, dim=1)
        if not pass_output_size:
            z = torch.repeat_interleave(z, torch.tensor([2, 1, 3]), dim=0)
        else:
            z = torch.repeat_interleave(z, torch.tensor([2, 1, 3]), dim=0, output_size=6)
        return x, y, z

def test():
    """Round-trip the model through TorchScript and pnnx and compare outputs.

    Returns:
        True when every output of the pnnx-generated module is exactly equal
        to the eager output; False when the converter reports a failure or
        any output differs.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3)
    y = torch.rand(4, 5)
    z = torch.rand(3, 7, 8)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_torch_repeat_interleave.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the converter failed; stop here instead of
    # importing a generated module that is missing or left over from an
    # earlier run.
    if os.system("../src/pnnx test_torch_repeat_interleave.pt inputshape=[3],[4,5],[3,7,8]") != 0:
        return False

    # pnnx inference
    import test_torch_repeat_interleave_pnnx
    b = test_torch_repeat_interleave_pnnx.test_inference()

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True

if __name__ == "__main__":
    # Report the result through the process exit status (0 = pass, 1 = fail).
    # SystemExit is raised directly because the exit() helper is injected by
    # the site module for interactive use and is absent when Python runs
    # without it.
    raise SystemExit(0 if test() else 1)

Loading…
Cancel
Save