Browse Source

fix out of index access on pattern match failure, fix #5347 (#5350)

tags/20240410
nihui GitHub 2 years ago
parent
commit
cf293ec35e
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
2 changed files with 55 additions and 0 deletions
  1. +3
    -0
      tools/pnnx/src/pass_level2.cpp
  2. +52
    -0
      tools/pnnx/src/pass_ncnn/solve_batch_index.cpp

+ 3
- 0
tools/pnnx/src/pass_level2.cpp View File

@@ -821,6 +821,8 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
int q = graph_op_count - 1;
for (; q >= 1; q--)
{
matched = true;

for (const Operator* pattern : pattern_graph_output_operators)
{
for (size_t i = 0; i < pattern->inputs.size(); i++)
@@ -900,6 +902,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
matched_outputs.clear();
captured_params.clear();
captured_attrs.clear();
matched = false;
continue;
}



+ 52
- 0
tools/pnnx/src/pass_ncnn/solve_batch_index.cpp View File

@@ -167,6 +167,32 @@ static void solve_batch_index_forward(Operand* operand)
solve_batch_index_backward(r);
}
}
else if (op->type == "torch.transpose")
{
const int dim0 = op->params.at("dim0").i;
const int dim1 = op->params.at("dim1").i;

int batch_index_transposed = batch_index;
if (dim0 == batch_index)
{
batch_index_transposed = dim1;
}
else if (dim1 == batch_index)
{
batch_index_transposed = dim0;
}

for (Operand* r : op->outputs)
{
if (r->params.find("__batch_index") != r->params.end())
continue;

r->params["__batch_index"] = batch_index_transposed;

solve_batch_index_forward(r);
solve_batch_index_backward(r);
}
}
else if (op->type == "Tensor.reshape" || op->type == "Tensor.view")
{
if (op->params.find("shape") == op->params.end())
@@ -241,6 +267,32 @@ static void solve_batch_index_backward(Operand* operand)
solve_batch_index_forward(r);
}
}
else if (op->type == "torch.transpose")
{
const int dim0 = op->params.at("dim0").i;
const int dim1 = op->params.at("dim1").i;

int batch_index_transposed = batch_index;
if (dim0 == batch_index)
{
batch_index_transposed = dim1;
}
else if (dim1 == batch_index)
{
batch_index_transposed = dim0;
}

for (Operand* r : op->inputs)
{
if (r->params.find("__batch_index") != r->params.end())
continue;

r->params["__batch_index"] = batch_index_transposed;

solve_batch_index_backward(r);
solve_batch_index_forward(r);
}
}
else if (op->type == "Tensor.reshape" || op->type == "Tensor.view")
{
if (op->params.find("shape") == op->params.end())


Loading…
Cancel
Save