Browse Source

pnnx eval typed expression, optimize type conversion (#5921)

tags/20250428
nihui GitHub 1 year ago
parent
commit
67dd7a7bd5
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed files with 385 additions and 203 deletions
  1. +385
    -203
      tools/pnnx/src/pass_level5/eval_expression.cpp

+ 385
- 203
tools/pnnx/src/pass_level5/eval_expression.cpp View File

@@ -110,14 +110,75 @@ static std::string eval_expression(const Operator* op)
}

// scan and stack
std::stack<std::string> exprstack;
struct typed_expr
{
std::string expr;
int type; // 0=i 1=f 2=cp 3=other
int literal;
int i;
float f;

typed_expr()
: type(3), literal(0), i(0), f(0.f)
{
}

typed_expr(int _i)
: type(0), literal(1), i(_i), f(0.f)
{
// fprintf(stderr, "typed_expr i %d\n", i);
}

typed_expr(float _f)
: type(1), literal(1), i(0), f(_f)
{
// fprintf(stderr, "typed_expr f %f\n", f);
}

typed_expr(const std::string& _expr)
: expr(_expr), type(3), literal(0), i(0), f(0.f)
{
// fprintf(stderr, "typed_expr ? %s\n", expr.c_str());
}

typed_expr(const std::string& _expr, int _type)
: expr(_expr), type(_type), literal(0), i(0), f(0.f)
{
// fprintf(stderr, "typed_expr %d %s\n", type, expr.c_str());
}

bool is_literal() const
{
return literal == 1;
}

bool is_interger_literal() const
{
return type == 0 && literal == 1;
}

std::string to_expr() const
{
if (literal == 1)
{
if (type == 0)
return std::to_string(i);
if (type == 1)
return std::to_string(f);
}

return expr;
}
};

std::stack<typed_expr> exprstack;
for (int i = (int)tokens.size() - 1; i >= 0; i--)
{
const std::string& t = tokens[i];

if (t == "size")
{
std::string a = exprstack.top();
std::string a = exprstack.top().to_expr();
exprstack.pop();

if (exprstack.empty())
@@ -127,20 +188,21 @@ static std::string eval_expression(const Operator* op)
}
else
{
std::string b = exprstack.top();
typed_expr b = exprstack.top();
exprstack.pop();

if (token_is_argument(a) && token_is_literal(b))
if (token_is_argument(a) && b.is_interger_literal())
{
int bi = b.i;

int input_index = std::stoi(a.substr(1));
if (op->inputs[input_index]->shape.empty())
{
std::string r = std::string("size(") + a + "," + b + ")";
exprstack.push(r);
std::string r = std::string("size(") + a + "," + std::to_string(bi) + ")";
exprstack.push(typed_expr(r, 0));
}
else
{
int bi = std::stoi(b);
if (bi < 0)
bi = op->inputs[input_index]->shape.size() + bi;
int r = op->inputs[input_index]->shape[bi];
@@ -148,386 +210,491 @@ static std::string eval_expression(const Operator* op)
{
// do not evaluate dynamic size info as -1
// just keep the size expression
std::string r = std::string("size(") + a + "," + b + ")";
exprstack.push(r);
std::string r = std::string("size(") + a + "," + std::to_string(bi) + ")";
exprstack.push(typed_expr(r, 0));
}
else
{
exprstack.push(std::to_string(r));
exprstack.push(r);
}
}
}
else
{
std::string r = std::string("size(") + a + "," + b + ")";
exprstack.push(r);
std::string r = std::string("size(") + a + "," + b.to_expr() + ")";
exprstack.push(typed_expr(r, 0));
}
}
}
else if (t == "int"
|| t == "abs"
|| t == "ceil"
|| t == "floor"
|| t == "round"
|| t == "trunc")
{
typed_expr a = exprstack.top();
exprstack.pop();

if (a.is_interger_literal())
{
// noop
exprstack.push(a);
}
else if (a.is_literal())
{
const float af = a.f;

int r = 0;
if (t == "int")
{
r = int(af);
}
if (t == "ceil")
{
r = ceil(af);
}
if (t == "floor")
{
r = floor(af);
}
if (t == "round")
{
// round to nearest even
int old_rm = fegetround();
fesetround(FE_TONEAREST);
r = nearbyintf(af);
fesetround(old_rm);
}
if (t == "trunc")
{
r = trunc(af);
}
exprstack.push(r);
}
else if (a.type == 0)
{
// noop
exprstack.push(a);
}
else
{
std::string r = t + "(" + a.to_expr() + ")";
if (a.type < 2)
{
exprstack.push(typed_expr(r, 0));
}
else
{
exprstack.push(r);
}
}
}
else if (t == "neg"
|| t == "sign"
|| t == "square")
{
typed_expr a = exprstack.top();
exprstack.pop();

if (a.is_interger_literal())
{
const int ai = a.i;

int r = 0;
if (t == "neg")
{
r = -ai;
}
if (t == "sign")
{
r = ai > 0 ? 1 : (ai == 0 ? 0 : -1);
}
if (t == "square")
{
r = ai * ai;
}

exprstack.push(r);
}
else if (a.is_literal())
{
const float af = a.f;

float r = 0;
if (t == "neg")
{
r = -af;
}
if (t == "sign")
{
r = af > 0.f ? 1.f : (af == 0.f ? 0.f : -1.f);
}
if (t == "square")
{
r = af * af;
}
exprstack.push(r);
}
else
{
std::string r = t + "(" + a.to_expr() + ")";
exprstack.push(typed_expr(r, a.type));
}
}
else if (t == "abs"
|| t == "acos"
|| t == "acosh"
|| t == "asin"
|| t == "asinh"
|| t == "atan"
|| t == "atanh"
|| t == "ceil"
|| t == "cos"
|| t == "cosh"
|| t == "erf"
|| t == "exp"
|| t == "floor"
|| t == "log"
|| t == "log10"
|| t == "neg"
|| t == "reciprocal"
|| t == "round"
|| t == "rsqrt"
|| t == "sign"
|| t == "sin"
|| t == "sinh"
|| t == "sqrt"
|| t == "square"
|| t == "tan"
|| t == "tanh"
|| t == "trunc"
|| t == "torch.bool"
|| t == "torch.float"
|| t == "torch.long")
|| t == "torch.float")
{
std::string a = exprstack.top();
typed_expr a = exprstack.top();
exprstack.pop();

if (token_is_literal(a))
if (a.is_literal())
{
float af = std::stof(a);

if (t == "int")
{
int r = int(af);
if (token_is_interger_literal(a))
{
r = std::stoi(a);
}
float af = a.type == 0 ? a.i : a.f;

exprstack.push(std::to_string(r));
}
float r = 0.f;
if (t == "abs")
{
float r = abs(af);
exprstack.push(std::to_string(r));
r = abs(af);
}
if (t == "acos")
{
float r = acos(af);
exprstack.push(std::to_string(r));
r = acos(af);
}
if (t == "acosh")
{
float r = acosh(af);
exprstack.push(std::to_string(r));
r = acosh(af);
}
if (t == "asin")
{
float r = asin(af);
exprstack.push(std::to_string(r));
r = asin(af);
}
if (t == "asinh")
{
float r = asinh(af);
exprstack.push(std::to_string(r));
r = asinh(af);
}
if (t == "atan")
{
float r = atan(af);
exprstack.push(std::to_string(r));
r = atan(af);
}
if (t == "atanh")
{
float r = atanh(af);
exprstack.push(std::to_string(r));
}
if (t == "ceil")
{
float r = ceil(af);
exprstack.push(std::to_string(r));
r = atanh(af);
}
if (t == "cos")
{
float r = cos(af);
exprstack.push(std::to_string(r));
r = cos(af);
}
if (t == "cosh")
{
float r = cosh(af);
exprstack.push(std::to_string(r));
r = cosh(af);
}
if (t == "erf")
{
float r = erf(af);
exprstack.push(std::to_string(r));
r = erf(af);
}
if (t == "exp")
{
float r = exp(af);
exprstack.push(std::to_string(r));
}
if (t == "floor")
{
float r = floor(af);
exprstack.push(std::to_string(r));
r = exp(af);
}
if (t == "log")
{
float r = log(af);
exprstack.push(std::to_string(r));
r = log(af);
}
if (t == "log10")
{
float r = log10(af);
exprstack.push(std::to_string(r));
}
if (t == "neg")
{
float r = -af;
exprstack.push(std::to_string(r));
r = log10(af);
}
if (t == "reciprocal")
{
float r = 1.f / af;
exprstack.push(std::to_string(r));
}
if (t == "round")
{
// round to nearest even
int old_rm = fegetround();
fesetround(FE_TONEAREST);
float r = nearbyintf(af);
fesetround(old_rm);
exprstack.push(std::to_string(r));
r = 1.f / af;
}
if (t == "rsqrt")
{
float r = 1.f / sqrt(af);
exprstack.push(std::to_string(r));
}
if (t == "sign")
{
float r = af > 0.f ? 1.f : (af == 0.f ? 0.f : -1.f);
exprstack.push(std::to_string(r));
r = 1.f / sqrt(af);
}
if (t == "sin")
{
float r = sin(af);
exprstack.push(std::to_string(r));
r = sin(af);
}
if (t == "sinh")
{
float r = sinh(af);
exprstack.push(std::to_string(r));
r = sinh(af);
}
if (t == "sqrt")
{
float r = sqrt(af);
exprstack.push(std::to_string(r));
}
if (t == "square")
{
float r = af * af;
exprstack.push(std::to_string(r));
r = sqrt(af);
}
if (t == "tan")
{
float r = tan(af);
exprstack.push(std::to_string(r));
r = tan(af);
}
if (t == "tanh")
{
float r = tanh(af);
exprstack.push(std::to_string(r));
r = tanh(af);
}
if (t == "trunc")
if (t == "torch.float")
{
float r = trunc(af);
exprstack.push(std::to_string(r));
// noop
r = af;
}
if (t == "torch.bool")
{
int r = int(af);
if (token_is_interger_literal(a))
{
r = std::stoi(a);
}
exprstack.push(r);
}
else
{
std::string r = t + "(" + a.to_expr() + ")";
exprstack.push(r);
}
}
else if (t == "torch.bool"
|| t == "torch.long")
{
typed_expr a = exprstack.top();
exprstack.pop();

exprstack.push(r == 0 ? "False" : "True");
}
if (t == "torch.float")
if (a.is_literal())
{
if (t == "torch.bool")
{
float r = af;
exprstack.push(std::to_string(r));
bool r = a.type == 0 ? (a.i != 0) : (a.f != 0.f);
std::string rs = r ? "True" : "False";
exprstack.push(rs);
}
if (t == "torch.long")
{
long r = long(af);
if (token_is_interger_literal(a))
{
r = std::stol(a);
}

long r = a.type == 0 ? long(a.i) : long(a.f);
exprstack.push(std::to_string(r));
}
}
else
{
std::string r = t + "(" + a + ")";
std::string r = t + "(" + a.to_expr() + ")";
exprstack.push(r);
}
}
else if (t == "atan2"
|| t == "add"
else if (t == "add"
|| t == "sub"
|| t == "max"
|| t == "maximum"
|| t == "min"
|| t == "minimum"
|| t == "mul"
|| t == "div"
|| t == "floor_divide"
|| t == "fmod"
|| t == "pow"
|| t == "remainder"
|| t == "logaddexp")
|| t == "remainder")
{
std::string a = exprstack.top();
typed_expr a = exprstack.top();
exprstack.pop();
std::string b = exprstack.top();
typed_expr b = exprstack.top();
exprstack.pop();

if (token_is_literal(a) && token_is_literal(b))
if (a.is_interger_literal() && b.is_interger_literal())
{
float af = std::stof(a);
float bf = std::stof(b);
const int ai = a.i;
const int bi = b.i;

if (t == "atan2")
{
float r = atan2(af, bf);
exprstack.push(std::to_string(r));
}
int r = 0;
if (t == "add")
{
float r = af + bf;
exprstack.push(std::to_string(r));
r = ai + bi;
}
if (t == "sub")
{
float r = af - bf;
exprstack.push(std::to_string(r));
r = ai - bi;
}
if (t == "max" || t == "maximum")
{
float r = std::max(af, bf);
exprstack.push(std::to_string(r));
r = std::max(ai, bi);
}
if (t == "minimum")
if (t == "min" || t == "minimum")
{
float r = std::min(af, bf);
exprstack.push(std::to_string(r));
r = std::min(ai, bi);
}
if (t == "mul")
{
float r = af * bf;
exprstack.push(std::to_string(r));
r = ai * bi;
}
if (t == "div")
if (t == "floor_divide")
{
float r = af / bf;
exprstack.push(std::to_string(r));
r = ai / bi;
}
if (t == "fmod")
if (t == "remainder")
{
float r = fmod(af, bf);
exprstack.push(std::to_string(r));
r = ai % bi;
}
if (t == "floor_divide")
exprstack.push(r);
}
else if (a.is_literal() && b.is_literal())
{
const float af = a.type == 0 ? a.i : a.f;
const float bf = b.type == 0 ? b.i : b.f;

float r = 0.f;
if (t == "add")
{
int r = (int)af / (int)bf;
exprstack.push(std::to_string(r));
r = af + bf;
}
if (t == "pow")
if (t == "sub")
{
float r = pow(af, bf);
exprstack.push(std::to_string(r));
r = af - bf;
}
if (t == "max" || t == "maximum")
{
r = std::max(af, bf);
}
if (t == "min" || t == "minimum")
{
r = std::min(af, bf);
}
if (t == "mul")
{
r = af * bf;
}
if (t == "floor_divide")
{
r = (int)af / (int)bf;
}
if (t == "remainder")
{
float r = fmod(af, bf);
r = fmod(af, bf);
if (af * bf < 0)
r += bf;
exprstack.push(std::to_string(r));
}
exprstack.push(r);
}
else
{
std::string r = t + "(" + a.to_expr() + "," + b.to_expr() + ")";
if (a.type == 0 && b.type == 0)
{
exprstack.push(typed_expr(r, 0));
}
else if (a.type == 1 || b.type == 1)
{
exprstack.push(typed_expr(r, 1));
}
else
{
exprstack.push(r);
}
}
}
else if (t == "atan2"
|| t == "div"
|| t == "fmod"
|| t == "pow"
|| t == "logaddexp")
{
typed_expr a = exprstack.top();
exprstack.pop();
typed_expr b = exprstack.top();
exprstack.pop();

if (a.is_literal() && b.is_literal())
{
const float af = a.type == 0 ? a.i : a.f;
const float bf = b.type == 0 ? b.i : b.f;

float r = 0.f;
if (t == "atan2")
{
r = atan2(af, bf);
}
if (t == "div")
{
r = af / bf;
}
if (t == "fmod")
{
r = fmod(af, bf);
}
if (t == "pow")
{
r = pow(af, bf);
}
if (t == "logaddexp")
{
float r = log(exp(af) + exp(bf));
exprstack.push(std::to_string(r));
r = log(exp(af) + exp(bf));
}
exprstack.push(r);
}
else
{
std::string r = t + "(" + a + "," + b + ")";
exprstack.push(r);
std::string r = t + "(" + a.to_expr() + "," + b.to_expr() + ")";
if (a.type == 1 || b.type == 1)
{
exprstack.push(typed_expr(r, 1));
}
else
{
exprstack.push(r);
}
}
}
else if (t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
{
std::string a = exprstack.top();
typed_expr a = exprstack.top();
exprstack.pop();
std::string b = exprstack.top();
typed_expr b = exprstack.top();
exprstack.pop();

if (token_is_interger_literal(a) && token_is_interger_literal(b))
if (a.is_interger_literal() && b.is_interger_literal())
{
int ai = std::stoi(a);
int bi = std::stoi(b);
const int ai = a.i;
const int bi = b.i;

int r = 0;
if (t == "and")
{
int r = ai & bi;
exprstack.push(std::to_string(r));
r = ai & bi;
}
if (t == "or")
{
int r = ai | bi;
exprstack.push(std::to_string(r));
r = ai | bi;
}
if (t == "xor")
{
int r = ai ^ bi;
exprstack.push(std::to_string(r));
r = ai ^ bi;
}
if (t == "lshift")
{
int r = ai << bi;
exprstack.push(std::to_string(r));
r = ai << bi;
}
if (t == "rshift")
{
int r = ai >> bi;
exprstack.push(std::to_string(r));
r = ai >> bi;
}
exprstack.push(r);
}
else
{
std::string r = t + "(" + a + "," + b + ")";
exprstack.push(r);
std::string r = t + "(" + a.to_expr() + "," + b.to_expr() + ")";
exprstack.push(typed_expr(r, 0)); // bitwise always produce integer
}
}
else if (t == "[") // list
{
std::vector<std::string> elements;
std::vector<typed_expr> elements;
while (!exprstack.empty())
{
std::string a = exprstack.top();
typed_expr a = exprstack.top();
exprstack.pop();

elements.push_back(a);
@@ -536,13 +703,13 @@ static std::string eval_expression(const Operator* op)
std::string r = "[";
for (int j = 0; j < (int)elements.size() - 1; j++)
{
r += elements[j];
r += elements[j].to_expr();
if (j + 1 != (int)elements.size())
r += ",";
}
if (!elements.empty())
{
r += elements[elements.size() - 1];
r += elements[elements.size() - 1].to_expr();
}
r += "]";

@@ -555,19 +722,34 @@ static std::string eval_expression(const Operator* op)
else
{
// literal
exprstack.push(t);
if (token_is_complex(t))
{
exprstack.push(t);
}
else if (token_is_interger_literal(t))
{
exprstack.push(std::stoi(t));
}
else if (token_is_literal(t))
{
exprstack.push(std::stof(t));
}
else
{
exprstack.push(t);
}
}
}

std::string r = exprstack.top();
std::string r = exprstack.top().to_expr();
exprstack.pop();
while (!exprstack.empty())
{
r += std::string(",") + exprstack.top();
r += std::string(",") + exprstack.top().to_expr();
exprstack.pop();
}

// fprintf(stderr, "eval_expression return %s\n", r.c_str());
// fprintf(stderr, "eval_expression return %s\n", r.c_str());

return r;
}


Loading…
Cancel
Save