|
|
|
@@ -54,6 +54,7 @@ void convert_Tensor_select(Graph& graph) |
|
|
|
if (axis > batch_index) |
|
|
|
axis -= 1; |
|
|
|
|
|
|
|
int dim = op->params.at("dim").i; |
|
|
|
int index = op->params.at("index").i; |
|
|
|
|
|
|
|
op->params["9"] = std::vector<int> {index}; |
|
|
|
@@ -63,24 +64,26 @@ void convert_Tensor_select(Graph& graph) |
|
|
|
op->params.erase("dim"); |
|
|
|
op->params.erase("index"); |
|
|
|
|
|
|
|
// reshape for output, squeezing the select dim |
|
|
|
// squeezing the select dim |
|
|
|
{ |
|
|
|
Operand* out = op->outputs[0]; |
|
|
|
|
|
|
|
Operator* reshape = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape", op); |
|
|
|
Operator* squeeze = graph.new_operator_after("torch.squeeze", op->name + "_ncnnsqueeze", op); |
|
|
|
|
|
|
|
Operand* reshape_in = graph.new_operand(op->name + "_ncnnreshape_in"); |
|
|
|
Operand* squeeze_in = graph.new_operand(op->name + "_ncnnsqueeze_in"); |
|
|
|
|
|
|
|
reshape->inputs.push_back(reshape_in); |
|
|
|
reshape->outputs.push_back(out); |
|
|
|
squeeze->inputs.push_back(squeeze_in); |
|
|
|
squeeze->outputs.push_back(out); |
|
|
|
|
|
|
|
op->outputs[0] = reshape_in; |
|
|
|
op->outputs[0] = squeeze_in; |
|
|
|
|
|
|
|
out->producer = reshape; |
|
|
|
reshape_in->producer = op; |
|
|
|
reshape_in->consumers.push_back(reshape); |
|
|
|
out->producer = squeeze; |
|
|
|
squeeze_in->producer = op; |
|
|
|
squeeze_in->consumers.push_back(squeeze); |
|
|
|
|
|
|
|
reshape->params["shape"] = out->shape; |
|
|
|
squeeze->params["dim"] = dim; |
|
|
|
|
|
|
|
squeeze_in->params["__batch_index"] = batch_index; |
|
|
|
} |
|
|
|
|
|
|
|
break; |
|
|
|
|