diff --git a/tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp b/tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp index 36079eeac..e62aa580a 100644 --- a/tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp +++ b/tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp @@ -54,6 +54,7 @@ void convert_Tensor_select(Graph& graph) if (axis > batch_index) axis -= 1; + int dim = op->params.at("dim").i; int index = op->params.at("index").i; op->params["9"] = std::vector {index}; @@ -63,24 +64,26 @@ void convert_Tensor_select(Graph& graph) op->params.erase("dim"); op->params.erase("index"); - // reshape for output, squeezing the select dim + // squeezing the select dim { Operand* out = op->outputs[0]; - Operator* reshape = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape", op); + Operator* squeeze = graph.new_operator_after("torch.squeeze", op->name + "_ncnnsqueeze", op); - Operand* reshape_in = graph.new_operand(op->name + "_ncnnreshape_in"); + Operand* squeeze_in = graph.new_operand(op->name + "_ncnnsqueeze_in"); - reshape->inputs.push_back(reshape_in); - reshape->outputs.push_back(out); + squeeze->inputs.push_back(squeeze_in); + squeeze->outputs.push_back(out); - op->outputs[0] = reshape_in; + op->outputs[0] = squeeze_in; - out->producer = reshape; - reshape_in->producer = op; - reshape_in->consumers.push_back(reshape); + out->producer = squeeze; + squeeze_in->producer = op; + squeeze_in->consumers.push_back(squeeze); - reshape->params["shape"] = out->shape; + squeeze->params["dim"] = dim; + + squeeze_in->params["__batch_index"] = batch_index; } break;