Browse Source

fix pnnx ghost reshape shape expression inputs, fix intmax overflow on fuse/eval expression (#4923)

tags/20230816
nihui GitHub 2 years ago
parent
commit
60fedae38b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 1 deletions
  1. +2
    -0
      tools/pnnx/src/pass_level1.cpp
  2. +3
    -0
      tools/pnnx/src/pass_level3/fuse_expression.cpp
  3. +4
    -1
      tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp
  4. +5
    -0
      tools/pnnx/src/pass_level5/eval_expression.cpp

+ 2
- 0
tools/pnnx/src/pass_level1.cpp View File

@@ -132,6 +132,8 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit
// sub_mod.dump(true, true, true);

op->attrs["data"] = sub_mod.attr(name).toTensor();
op->outputs[0]->type = op->attrs["data"].type;
op->outputs[0]->shape = op->attrs["data"].shape;
}
}
else if (n->kind() == c10::prim::Constant) // || n->kind() == c10::prim::ListConstruct)


+ 3
- 0
tools/pnnx/src/pass_level3/fuse_expression.cpp View File

@@ -251,6 +251,9 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
int64_t v;
zip.read_file(operand->name, (char*)&v);

if (v == std::numeric_limits<int64_t>::max()) v = INT_MAX;
if (v == std::numeric_limits<int64_t>::min()) v = INT_MIN;

char tmp[32];
sprintf(tmp, "%ld", v);
expr += tmp;


+ 4
- 1
tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp View File

@@ -165,7 +165,10 @@ void eliminate_reshape_shape_expression(Graph& graph)
if (op_expr->outputs[0]->consumers.size() == 0)
{
// remove expression operator
op_expr->inputs[0]->remove_consumer(op_expr);
for (auto x : op_expr->inputs)
{
x->remove_consumer(op_expr);
}

Operand* op_expr_out = op_expr->outputs[0];



+ 5
- 0
tools/pnnx/src/pass_level5/eval_expression.cpp View File

@@ -193,6 +193,11 @@ static std::string eval_expression(const Operator* op)
if (t == "int")
{
int r = int(af);
if (token_is_interger_literal(a))
{
r = std::stoi(a);
}

exprstack.push(std::to_string(r));
}
if (t == "abs")


Loading…
Cancel
Save