Browse Source

fix pnnx slice copy shape type, inplace op link output (#4914)

tags/20230816
nihui GitHub 2 years ago
parent
commit
e02b6e8521
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 0 deletions
  1. +8
    -0
      tools/pnnx/src/pass_level0/shape_inference.cpp
  2. +1
    -0
      tools/pnnx/src/pass_level2.cpp
  3. +2
    -0
      tools/pnnx/src/pass_level5/fuse_slice_copy.cpp

+ 8
- 0
tools/pnnx/src/pass_level0/shape_inference.cpp View File

@@ -77,6 +77,14 @@ static bool value_link_output(const torch::jit::Value* v, const std::vector<torc
if (link)
return true;
}

std::string op_type = node->kind().toDisplayString();
bool is_inplace_op = op_type.size() > 2 && op_type[op_type.size() - 2] != '_' && op_type[op_type.size() - 1] == '_';
if (is_inplace_op)
{
// optimize me: track other inplace op inputs
return true;
}
}

return false;


+ 1
- 0
tools/pnnx/src/pass_level2.cpp View File

@@ -1142,6 +1142,7 @@ static void fix_inplace_copy_output(Graph& graph)
Operator* op_copy = graph.new_operator_after("aten::copy", op->name + "_copy", op);
Operand* copy_out = graph.new_operand(op->name + "_copy_out");

copy_out->type = in0->type;
copy_out->shape = in0->shape;

op_copy->inputs.push_back(op->inputs[0]);


+ 2
- 0
tools/pnnx/src/pass_level5/fuse_slice_copy.cpp View File

@@ -177,6 +177,7 @@ void fuse_slice_copy(Graph& graph)
Operator* op_clone = graph.new_operator_before("Tensor.clone", op->name + "_ncnnclone", top_sop);
Operand* clone_out = graph.new_operand(op->name + "_ncnnclone_out");

clone_out->type = top_sop->inputs[0]->type;
clone_out->shape = top_sop->inputs[0]->shape;

op_clone->inputs.push_back(top_sop->inputs[0]);
@@ -255,6 +256,7 @@ void fuse_slice_copy(Graph& graph)

op_view->params["shape"] = target_shape;

view_out->type = op->inputs[1]->type;
view_out->shape = target_shape;

op_view->inputs.push_back(op->inputs[1]);


Loading…
Cancel
Save