diff --git a/tools/pnnx/src/pass_level5/eval_expression.cpp b/tools/pnnx/src/pass_level5/eval_expression.cpp index 441cfca1c..3619a452f 100644 --- a/tools/pnnx/src/pass_level5/eval_expression.cpp +++ b/tools/pnnx/src/pass_level5/eval_expression.cpp @@ -110,14 +110,75 @@ static std::string eval_expression(const Operator* op) } // scan and stack - std::stack exprstack; + struct typed_expr + { + std::string expr; + int type; // 0=i 1=f 2=cp 3=other + int literal; + int i; + float f; + + typed_expr() + : type(3), literal(0), i(0), f(0.f) + { + } + + typed_expr(int _i) + : type(0), literal(1), i(_i), f(0.f) + { + // fprintf(stderr, "typed_expr i %d\n", i); + } + + typed_expr(float _f) + : type(1), literal(1), i(0), f(_f) + { + // fprintf(stderr, "typed_expr f %f\n", f); + } + + typed_expr(const std::string& _expr) + : expr(_expr), type(3), literal(0), i(0), f(0.f) + { + // fprintf(stderr, "typed_expr ? %s\n", expr.c_str()); + } + + typed_expr(const std::string& _expr, int _type) + : expr(_expr), type(_type), literal(0), i(0), f(0.f) + { + // fprintf(stderr, "typed_expr %d %s\n", type, expr.c_str()); + } + + bool is_literal() const + { + return literal == 1; + } + + bool is_interger_literal() const + { + return type == 0 && literal == 1; + } + + std::string to_expr() const + { + if (literal == 1) + { + if (type == 0) + return std::to_string(i); + if (type == 1) + return std::to_string(f); + } + + return expr; + } + }; + + std::stack exprstack; for (int i = (int)tokens.size() - 1; i >= 0; i--) { const std::string& t = tokens[i]; if (t == "size") { - std::string a = exprstack.top(); + std::string a = exprstack.top().to_expr(); exprstack.pop(); if (exprstack.empty()) @@ -127,20 +188,21 @@ static std::string eval_expression(const Operator* op) } else { - std::string b = exprstack.top(); + typed_expr b = exprstack.top(); exprstack.pop(); - if (token_is_argument(a) && token_is_literal(b)) + if (token_is_argument(a) && b.is_interger_literal()) { + int bi = b.i; + int input_index = std::stoi(a.substr(1)); if (op->inputs[input_index]->shape.empty()) { - std::string r = std::string("size(") + a + "," + b + ")"; - exprstack.push(r); + std::string r = std::string("size(") + a + "," + std::to_string(bi) + ")"; + exprstack.push(typed_expr(r, 0)); } else { - int bi = std::stoi(b); if (bi < 0) bi = op->inputs[input_index]->shape.size() + bi; int r = op->inputs[input_index]->shape[bi]; @@ -148,386 +210,491 @@ static std::string eval_expression(const Operator* op) { // do not evaluate dynamic size info as -1 // just keep the size expression - std::string r = std::string("size(") + a + "," + b + ")"; - exprstack.push(r); + std::string r = std::string("size(") + a + "," + std::to_string(bi) + ")"; + exprstack.push(typed_expr(r, 0)); } else { - exprstack.push(std::to_string(r)); + exprstack.push(r); } } } else { - std::string r = std::string("size(") + a + "," + b + ")"; - exprstack.push(r); + std::string r = std::string("size(") + a + "," + b.to_expr() + ")"; + exprstack.push(typed_expr(r, 0)); } } } else if (t == "int" - || t == "abs" + || t == "ceil" + || t == "floor" + || t == "round" + || t == "trunc") + { + typed_expr a = exprstack.top(); + exprstack.pop(); + + if (a.is_interger_literal()) + { + // noop + exprstack.push(a); + } + else if (a.is_literal()) + { + const float af = a.f; + + int r = 0; + if (t == "int") + { + r = int(af); + } + if (t == "ceil") + { + r = ceil(af); + } + if (t == "floor") + { + r = floor(af); + } + if (t == "round") + { + // round to nearest even + int old_rm = fegetround(); + fesetround(FE_TONEAREST); + r = nearbyintf(af); + fesetround(old_rm); + } + if (t == "trunc") + { + r = trunc(af); + } + exprstack.push(r); + } + else if (a.type == 0) + { + // noop + exprstack.push(a); + } + else + { + std::string r = t + "(" + a.to_expr() + ")"; + if (a.type < 2) + { + exprstack.push(typed_expr(r, 0)); + } + else + { + exprstack.push(r); + } + } + } + else if (t == "neg" + || t == "sign" + || t == "square") + { + typed_expr a = exprstack.top(); + exprstack.pop(); + + if (a.is_interger_literal()) + { + const int ai = a.i; + + int r = 0; + if (t == "neg") + { + r = -ai; + } + if (t == "sign") + { + r = ai > 0 ? 1 : (ai == 0 ? 0 : -1); + } + if (t == "square") + { + r = ai * ai; + } + + exprstack.push(r); + } + else if (a.is_literal()) + { + const float af = a.f; + + float r = 0; + if (t == "neg") + { + r = -af; + } + if (t == "sign") + { + r = af > 0.f ? 1.f : (af == 0.f ? 0.f : -1.f); + } + if (t == "square") + { + r = af * af; + } + exprstack.push(r); + } + else + { + std::string r = t + "(" + a.to_expr() + ")"; + exprstack.push(typed_expr(r, a.type)); + } + } + else if (t == "abs" || t == "acos" || t == "acosh" || t == "asin" || t == "asinh" || t == "atan" || t == "atanh" - || t == "ceil" || t == "cos" || t == "cosh" || t == "erf" || t == "exp" - || t == "floor" || t == "log" || t == "log10" - || t == "neg" || t == "reciprocal" - || t == "round" || t == "rsqrt" - || t == "sign" || t == "sin" || t == "sinh" || t == "sqrt" - || t == "square" || t == "tan" || t == "tanh" - || t == "trunc" - || t == "torch.bool" - || t == "torch.float" - || t == "torch.long") + || t == "torch.float") { - std::string a = exprstack.top(); + typed_expr a = exprstack.top(); exprstack.pop(); - if (token_is_literal(a)) + if (a.is_literal()) { - float af = std::stof(a); - - if (t == "int") - { - int r = int(af); - if (token_is_interger_literal(a)) - { - r = std::stoi(a); - } + float af = a.type == 0 ? a.i : a.f; - exprstack.push(std::to_string(r)); - } + float r = 0.f; if (t == "abs") { - float r = abs(af); - exprstack.push(std::to_string(r)); + r = abs(af); } if (t == "acos") { - float r = acos(af); - exprstack.push(std::to_string(r)); + r = acos(af); } if (t == "acosh") { - float r = acosh(af); - exprstack.push(std::to_string(r)); + r = acosh(af); } if (t == "asin") { - float r = asin(af); - exprstack.push(std::to_string(r)); + r = asin(af); } if (t == "asinh") { - float r = asinh(af); - exprstack.push(std::to_string(r)); + r = asinh(af); } if (t == "atan") { - float r = atan(af); - exprstack.push(std::to_string(r)); + r = atan(af); } if (t == "atanh") { - float r = atanh(af); - exprstack.push(std::to_string(r)); - } - if (t == "ceil") - { - float r = ceil(af); - exprstack.push(std::to_string(r)); + r = atanh(af); } if (t == "cos") { - float r = cos(af); - exprstack.push(std::to_string(r)); + r = cos(af); } if (t == "cosh") { - float r = cosh(af); - exprstack.push(std::to_string(r)); + r = cosh(af); } if (t == "erf") { - float r = erf(af); - exprstack.push(std::to_string(r)); + r = erf(af); } if (t == "exp") { - float r = exp(af); - exprstack.push(std::to_string(r)); - } - if (t == "floor") - { - float r = floor(af); - exprstack.push(std::to_string(r)); + r = exp(af); } if (t == "log") { - float r = log(af); - exprstack.push(std::to_string(r)); + r = log(af); } if (t == "log10") { - float r = log10(af); - exprstack.push(std::to_string(r)); - } - if (t == "neg") - { - float r = -af; - exprstack.push(std::to_string(r)); + r = log10(af); } if (t == "reciprocal") { - float r = 1.f / af; - exprstack.push(std::to_string(r)); - } - if (t == "round") - { - // round to nearest even - int old_rm = fegetround(); - fesetround(FE_TONEAREST); - float r = nearbyintf(af); - fesetround(old_rm); - exprstack.push(std::to_string(r)); + r = 1.f / af; } if (t == "rsqrt") { - float r = 1.f / sqrt(af); - exprstack.push(std::to_string(r)); - } - if (t == "sign") - { - float r = af > 0.f ? 1.f : (af == 0.f ? 0.f : -1.f); - exprstack.push(std::to_string(r)); + r = 1.f / sqrt(af); } if (t == "sin") { - float r = sin(af); - exprstack.push(std::to_string(r)); + r = sin(af); } if (t == "sinh") { - float r = sinh(af); - exprstack.push(std::to_string(r)); + r = sinh(af); } if (t == "sqrt") { - float r = sqrt(af); - exprstack.push(std::to_string(r)); - } - if (t == "square") - { - float r = af * af; - exprstack.push(std::to_string(r)); + r = sqrt(af); } if (t == "tan") { - float r = tan(af); - exprstack.push(std::to_string(r)); + r = tan(af); } if (t == "tanh") { - float r = tanh(af); - exprstack.push(std::to_string(r)); + r = tanh(af); } - if (t == "trunc") + if (t == "torch.float") { - float r = trunc(af); - exprstack.push(std::to_string(r)); + // noop + r = af; } - if (t == "torch.bool") - { - int r = int(af); - if (token_is_interger_literal(a)) - { - r = std::stoi(a); - } + exprstack.push(r); + } + else + { + std::string r = t + "(" + a.to_expr() + ")"; + exprstack.push(r); + } + } + else if (t == "torch.bool" + || t == "torch.long") + { + typed_expr a = exprstack.top(); + exprstack.pop(); - exprstack.push(r == 0 ? "False" : "True"); - } - if (t == "torch.float") + if (a.is_literal()) + { + if (t == "torch.bool") { - float r = af; - exprstack.push(std::to_string(r)); + bool r = a.type == 0 ? (a.i != 0) : (a.f != 0.f); + std::string rs = r ? "True" : "False"; + exprstack.push(rs); } if (t == "torch.long") { - long r = long(af); - if (token_is_interger_literal(a)) - { - r = std::stol(a); - } - + long r = a.type == 0 ? long(a.i) : long(a.f); exprstack.push(std::to_string(r)); } } else { - std::string r = t + "(" + a + ")"; + std::string r = t + "(" + a.to_expr() + ")"; exprstack.push(r); } } - else if (t == "atan2" - || t == "add" + else if (t == "add" || t == "sub" || t == "max" || t == "maximum" || t == "min" || t == "minimum" || t == "mul" - || t == "div" || t == "floor_divide" - || t == "fmod" - || t == "pow" - || t == "remainder" - || t == "logaddexp") + || t == "remainder") { - std::string a = exprstack.top(); + typed_expr a = exprstack.top(); exprstack.pop(); - std::string b = exprstack.top(); + typed_expr b = exprstack.top(); exprstack.pop(); - if (token_is_literal(a) && token_is_literal(b)) + if (a.is_interger_literal() && b.is_interger_literal()) { - float af = std::stof(a); - float bf = std::stof(b); + const int ai = a.i; + const int bi = b.i; - if (t == "atan2") - { - float r = atan2(af, bf); - exprstack.push(std::to_string(r)); - } + int r = 0; if (t == "add") { - float r = af + bf; - exprstack.push(std::to_string(r)); + r = ai + bi; } if (t == "sub") { - float r = af - bf; - exprstack.push(std::to_string(r)); + r = ai - bi; } if (t == "max" || t == "maximum") { - float r = std::max(af, bf); - exprstack.push(std::to_string(r)); + r = std::max(ai, bi); } - if (t == "minimum") + if (t == "min" || t == "minimum") { - float r = std::min(af, bf); - exprstack.push(std::to_string(r)); + r = std::min(ai, bi); } if (t == "mul") { - float r = af * bf; - exprstack.push(std::to_string(r)); + r = ai * bi; } - if (t == "div") + if (t == "floor_divide") { - float r = af / bf; - exprstack.push(std::to_string(r)); + r = ai / bi; } - if (t == "fmod") + if (t == "remainder") { - float r = fmod(af, bf); - exprstack.push(std::to_string(r)); + r = ai % bi; } - if (t == "floor_divide") + exprstack.push(r); + } + else if (a.is_literal() && b.is_literal()) + { + const float af = a.type == 0 ? a.i : a.f; + const float bf = b.type == 0 ? b.i : b.f; + + float r = 0.f; + if (t == "add") { - int r = (int)af / (int)bf; - exprstack.push(std::to_string(r)); + r = af + bf; } - if (t == "pow") + if (t == "sub") { - float r = pow(af, bf); - exprstack.push(std::to_string(r)); + r = af - bf; + } + if (t == "max" || t == "maximum") + { + r = std::max(af, bf); + } + if (t == "min" || t == "minimum") + { + r = std::min(af, bf); + } + if (t == "mul") + { + r = af * bf; + } + if (t == "floor_divide") + { + r = (int)af / (int)bf; } if (t == "remainder") { - float r = fmod(af, bf); + r = fmod(af, bf); if (af * bf < 0) r += bf; - exprstack.push(std::to_string(r)); + } + exprstack.push(r); + } + else + { + std::string r = t + "(" + a.to_expr() + "," + b.to_expr() + ")"; + if (a.type == 0 && b.type == 0) + { + exprstack.push(typed_expr(r, 0)); + } + else if (a.type == 1 || b.type == 1) + { + exprstack.push(typed_expr(r, 1)); + } + else + { + exprstack.push(r); + } + } + } + else if (t == "atan2" + || t == "div" + || t == "fmod" + || t == "pow" + || t == "logaddexp") + { + typed_expr a = exprstack.top(); + exprstack.pop(); + typed_expr b = exprstack.top(); + exprstack.pop(); + + if (a.is_literal() && b.is_literal()) + { + const float af = a.type == 0 ? a.i : a.f; + const float bf = b.type == 0 ? b.i : b.f; + + float r = 0.f; + if (t == "atan2") + { + r = atan2(af, bf); + } + if (t == "div") + { + r = af / bf; + } + if (t == "fmod") + { + r = fmod(af, bf); + } + if (t == "pow") + { + r = pow(af, bf); } if (t == "logaddexp") { - float r = log(exp(af) + exp(bf)); - exprstack.push(std::to_string(r)); + r = log(exp(af) + exp(bf)); } + exprstack.push(r); } else { - std::string r = t + "(" + a + "," + b + ")"; - exprstack.push(r); + std::string r = t + "(" + a.to_expr() + "," + b.to_expr() + ")"; + if (a.type == 1 || b.type == 1) + { + exprstack.push(typed_expr(r, 1)); + } + else + { + exprstack.push(r); + } } } else if (t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift") { - std::string a = exprstack.top(); + typed_expr a = exprstack.top(); exprstack.pop(); - std::string b = exprstack.top(); + typed_expr b = exprstack.top(); exprstack.pop(); - if (token_is_interger_literal(a) && token_is_interger_literal(b)) + if (a.is_interger_literal() && b.is_interger_literal()) { - int ai = std::stoi(a); - int bi = std::stoi(b); + const int ai = a.i; + const int bi = b.i; + int r = 0; if (t == "and") { - int r = ai & bi; - exprstack.push(std::to_string(r)); + r = ai & bi; } if (t == "or") { - int r = ai | bi; - exprstack.push(std::to_string(r)); + r = ai | bi; } if (t == "xor") { - int r = ai ^ bi; - exprstack.push(std::to_string(r)); + r = ai ^ bi; } if (t == "lshift") { - int r = ai << bi; - exprstack.push(std::to_string(r)); + r = ai << bi; } if (t == "rshift") { - int r = ai >> bi; - exprstack.push(std::to_string(r)); + r = ai >> bi; } + exprstack.push(r); } else { - std::string r = t + "(" + a + "," + b + ")"; - exprstack.push(r); + std::string r = t + "(" + a.to_expr() + "," + b.to_expr() + ")"; + exprstack.push(typed_expr(r, 0)); // bitwise always produce integer } } else if (t == "[") // list { - std::vector elements; + std::vector elements; while (!exprstack.empty()) { - std::string a = exprstack.top(); + typed_expr a = exprstack.top(); exprstack.pop(); elements.push_back(a); @@ -536,13 +703,13 @@ static std::string eval_expression(const Operator* op) std::string r = "["; for (int j = 0; j < (int)elements.size() - 1; j++) { - r += elements[j]; + r += elements[j].to_expr(); if (j + 1 != (int)elements.size()) r += ","; } if (!elements.empty()) { - r += elements[elements.size() - 1]; + r += elements[elements.size() - 1].to_expr(); } r += "]"; @@ -555,19 +722,34 @@ static std::string eval_expression(const Operator* op) else { // literal - exprstack.push(t); + if (token_is_complex(t)) + { + exprstack.push(t); + } + else if (token_is_interger_literal(t)) + { + exprstack.push(std::stoi(t)); + } + else if (token_is_literal(t)) + { + exprstack.push(std::stof(t)); + } + else + { + exprstack.push(t); + } } } - std::string r = exprstack.top(); + std::string r = exprstack.top().to_expr(); exprstack.pop(); while (!exprstack.empty()) { - r += std::string(",") + exprstack.top(); + r += std::string(",") + exprstack.top().to_expr(); exprstack.pop(); } - // fprintf(stderr, "eval_expression return %s\n", r.c_str()); + // fprintf(stderr, "eval_expression return %s\n", r.c_str()); return r; }