diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp index c8319d130..e7c2d759e 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.cpp +++ b/tools/pnnx/src/pass_level0/shape_inference.cpp @@ -77,6 +77,14 @@ static bool value_link_output(const torch::jit::Value* v, const std::vectorkind().toDisplayString(); + bool is_inplace_op = op_type.size() > 2 && op_type[op_type.size() - 2] != '_' && op_type[op_type.size() - 1] == '_'; + if (is_inplace_op) + { + // optimize me: track other inplace op inputs + return true; + } } return false; diff --git a/tools/pnnx/src/pass_level2.cpp b/tools/pnnx/src/pass_level2.cpp index 41f366d5e..d951498e7 100644 --- a/tools/pnnx/src/pass_level2.cpp +++ b/tools/pnnx/src/pass_level2.cpp @@ -1142,6 +1142,7 @@ static void fix_inplace_copy_output(Graph& graph) Operator* op_copy = graph.new_operator_after("aten::copy", op->name + "_copy", op); Operand* copy_out = graph.new_operand(op->name + "_copy_out"); + copy_out->type = in0->type; copy_out->shape = in0->shape; op_copy->inputs.push_back(op->inputs[0]); diff --git a/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp b/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp index 0a5fabab7..ce5a06710 100644 --- a/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp +++ b/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp @@ -177,6 +177,7 @@ void fuse_slice_copy(Graph& graph) Operator* op_clone = graph.new_operator_before("Tensor.clone", op->name + "_ncnnclone", top_sop); Operand* clone_out = graph.new_operand(op->name + "_ncnnclone_out"); + clone_out->type = top_sop->inputs[0]->type; clone_out->shape = top_sop->inputs[0]->shape; op_clone->inputs.push_back(top_sop->inputs[0]); @@ -255,6 +256,7 @@ void fuse_slice_copy(Graph& graph) op_view->params["shape"] = target_shape; + view_out->type = op->inputs[1]->type; view_out->shape = target_shape; op_view->inputs.push_back(op->inputs[1]);