From 4a865ae589d840d64d0793fc593ea493acadca48 Mon Sep 17 00:00:00 2001 From: AtomAlpaca Date: Sat, 2 Aug 2025 12:52:55 +0800 Subject: [PATCH] Add input_content and input_contents params --- tools/pnnx/src/load_torchscript.cpp | 72 +++++++++++++++++++++++------ tools/pnnx/src/load_torchscript.h | 4 +- 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/tools/pnnx/src/load_torchscript.cpp b/tools/pnnx/src/load_torchscript.cpp index 01fa3937a..7e330c7ca 100644 --- a/tools/pnnx/src/load_torchscript.cpp +++ b/tools/pnnx/src/load_torchscript.cpp @@ -577,8 +577,10 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph, const std::string& device, const std::vector >& input_shapes, const std::vector& input_types, + const std::vector >& input_contents, const std::vector >& input_shapes2, const std::vector& input_types2, + const std::vector >& input_contents2, const std::vector& customop_modules, const std::vector& module_operators, const std::string& foldable_constants_zippath, @@ -640,31 +642,73 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph, } std::vector input_tensors; - for (size_t i = 0; i < traced_input_shapes.size(); i++) + if (input_contents.size() != 0) + { + for (size_t i = 0; i < traced_input_shapes.size(); i++) + { + const std::vector& shape = traced_input_shapes[i]; + const std::string& type = traced_input_types[i]; + + at::TensorOptions options(input_type_to_c10_ScalarType(type)); + at::IntArrayRef shape2(shape); + at::Tensor t = torch::from_blob((void*)input_contents[i].data(), shape2, options); + if (device == "gpu") + t = t.cuda(); + + input_tensors.push_back(t); + } + } + else { - const std::vector& shape = traced_input_shapes[i]; - const std::string& type = traced_input_types[i]; + for (size_t i = 0; i < traced_input_shapes.size(); i++) + { + const std::vector& shape = traced_input_shapes[i]; + const std::string& type = traced_input_types[i]; - at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type)); - if (device == "gpu") - t = t.cuda(); + at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type)); + if (device == "gpu") + t = t.cuda(); - input_tensors.push_back(t); + input_tensors.push_back(t); + } } std::vector input_tensors2; - for (size_t i = 0; i < input_shapes2.size(); i++) + if(input_contents2.size() != 0) { - const std::vector& shape = input_shapes2[i]; - const std::string& type = input_types2[i]; + for (size_t i = 0; i < traced_input_shapes.size(); i++) + { + const std::vector& shape = input_shapes2[i]; + const std::string& type = input_types2[i]; + + at::TensorOptions options(input_type_to_c10_ScalarType(type)); + at::IntArrayRef shape2(shape); + at::Tensor t = torch::from_blob((void*)input_contents2[i].data(), shape2, options); + if (device == "gpu") + t = t.cuda(); - at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type)); - if (device == "gpu") - t = t.cuda(); + input_tensors2.push_back(t); - input_tensors2.push_back(t); + } } + else + { + for (size_t i = 0; i < traced_input_shapes.size(); i++) + { + std::vector input_tensors2; + for (size_t i = 0; i < input_shapes2.size(); i++) + { + const std::vector& shape = input_shapes2[i]; + const std::string& type = input_types2[i]; + at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type)); + if (device == "gpu") + t = t.cuda(); + + input_tensors2.push_back(t); + } + } + } torch::jit::Module mod; try diff --git a/tools/pnnx/src/load_torchscript.h b/tools/pnnx/src/load_torchscript.h index a973441a4..4fe1e63aa 100644 --- a/tools/pnnx/src/load_torchscript.h +++ b/tools/pnnx/src/load_torchscript.h @@ -8,12 +8,14 @@ namespace pnnx { -int load_torchscript(const std::string& ptpath, Graph& g, +int load_torchscript(const std::string& ptpath, Graph& pnnx_graph, const std::string& device, const std::vector >& input_shapes, const std::vector& input_types, + const std::vector >& input_contents, const std::vector >& input_shapes2, const std::vector& input_types2, + const std::vector >& input_contents2, const std::vector& customop_modules, const std::vector& module_operators, const std::string& foldable_constants_zippath,