diff --git a/tools/pnnx/src/pass_level2.cpp b/tools/pnnx/src/pass_level2.cpp index 1ff4451e5..be00396ac 100644 --- a/tools/pnnx/src/pass_level2.cpp +++ b/tools/pnnx/src/pass_level2.cpp @@ -821,6 +821,8 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde int q = graph_op_count - 1; for (; q >= 1; q--) { + matched = true; + for (const Operator* pattern : pattern_graph_output_operators) { for (size_t i = 0; i < pattern->inputs.size(); i++) @@ -900,6 +902,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde matched_outputs.clear(); captured_params.clear(); captured_attrs.clear(); + matched = false; continue; } diff --git a/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp b/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp index 6865ec72d..258b26176 100644 --- a/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp +++ b/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp @@ -167,6 +167,32 @@ static void solve_batch_index_forward(Operand* operand) solve_batch_index_backward(r); } } + else if (op->type == "torch.transpose") + { + const int dim0 = op->params.at("dim0").i; + const int dim1 = op->params.at("dim1").i; + + int batch_index_transposed = batch_index; + if (dim0 == batch_index) + { + batch_index_transposed = dim1; + } + else if (dim1 == batch_index) + { + batch_index_transposed = dim0; + } + + for (Operand* r : op->outputs) + { + if (r->params.find("__batch_index") != r->params.end()) + continue; + + r->params["__batch_index"] = batch_index_transposed; + + solve_batch_index_forward(r); + solve_batch_index_backward(r); + } + } else if (op->type == "Tensor.reshape" || op->type == "Tensor.view") { if (op->params.find("shape") == op->params.end()) @@ -241,6 +267,32 @@ static void solve_batch_index_backward(Operand* operand) solve_batch_index_forward(r); } } + else if (op->type == "torch.transpose") + { + const int dim0 = op->params.at("dim0").i; + const int dim1 = op->params.at("dim1").i; + + int batch_index_transposed = batch_index; + if (dim0 == batch_index) + { + batch_index_transposed = dim1; + } + else if (dim1 == batch_index) + { + batch_index_transposed = dim0; + } + + for (Operand* r : op->inputs) + { + if (r->params.find("__batch_index") != r->params.end()) + continue; + + r->params["__batch_index"] = batch_index_transposed; + + solve_batch_index_backward(r); + solve_batch_index_forward(r); + } + } else if (op->type == "Tensor.reshape" || op->type == "Tensor.view") { if (op->params.find("shape") == op->params.end())