Browse Source

pnnx ncnn convert select to crop and squeeze (#5826)

tags/20241226
nihui GitHub 1 year ago
parent
commit
a12baae13b
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed files with 13 additions and 10 deletions
  1. +13
    -10
      tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp

+ 13
- 10
tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp View File

@@ -54,6 +54,7 @@ void convert_Tensor_select(Graph& graph)
if (axis > batch_index)
axis -= 1;

int dim = op->params.at("dim").i;
int index = op->params.at("index").i;

op->params["9"] = std::vector<int> {index};
@@ -63,24 +64,26 @@ void convert_Tensor_select(Graph& graph)
op->params.erase("dim");
op->params.erase("index");

// reshape for output, squeezing the select dim
// squeezing the select dim
{
Operand* out = op->outputs[0];

Operator* reshape = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape", op);
Operator* squeeze = graph.new_operator_after("torch.squeeze", op->name + "_ncnnsqueeze", op);

Operand* reshape_in = graph.new_operand(op->name + "_ncnnreshape_in");
Operand* squeeze_in = graph.new_operand(op->name + "_ncnnsqueeze_in");

reshape->inputs.push_back(reshape_in);
reshape->outputs.push_back(out);
squeeze->inputs.push_back(squeeze_in);
squeeze->outputs.push_back(out);

op->outputs[0] = reshape_in;
op->outputs[0] = squeeze_in;

out->producer = reshape;
reshape_in->producer = op;
reshape_in->consumers.push_back(reshape);
out->producer = squeeze;
squeeze_in->producer = op;
squeeze_in->consumers.push_back(squeeze);

reshape->params["shape"] = out->shape;
squeeze->params["dim"] = dim;

squeeze_in->params["__batch_index"] = batch_index;
}

break;


Loading…
Cancel
Save