Browse Source

Add input_content and input_contents params

pull/6227/head
AtomAlpaca 11 months ago
parent
commit
4a865ae589
No known key found for this signature in database GPG Key ID: A003E4A074A5B2CA
2 changed files with 61 additions and 15 deletions
  1. +58
    -14
      tools/pnnx/src/load_torchscript.cpp
  2. +3
    -1
      tools/pnnx/src/load_torchscript.h

+ 58
- 14
tools/pnnx/src/load_torchscript.cpp View File

@@ -577,8 +577,10 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
const std::string& device, const std::string& device,
const std::vector<std::vector<int64_t> >& input_shapes, const std::vector<std::vector<int64_t> >& input_shapes,
const std::vector<std::string>& input_types, const std::vector<std::string>& input_types,
const std::vector<std::vector<char> >& input_contents,
const std::vector<std::vector<int64_t> >& input_shapes2, const std::vector<std::vector<int64_t> >& input_shapes2,
const std::vector<std::string>& input_types2, const std::vector<std::string>& input_types2,
const std::vector<std::vector<char> >& input_contents2,
const std::vector<std::string>& customop_modules, const std::vector<std::string>& customop_modules,
const std::vector<std::string>& module_operators, const std::vector<std::string>& module_operators,
const std::string& foldable_constants_zippath, const std::string& foldable_constants_zippath,
@@ -640,31 +642,73 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
} }


std::vector<at::Tensor> input_tensors; std::vector<at::Tensor> input_tensors;
for (size_t i = 0; i < traced_input_shapes.size(); i++)
if (input_contents.size() != 0)
{
for (size_t i = 0; i < traced_input_shapes.size(); i++)
{
const std::vector<int64_t>& shape = traced_input_shapes[i];
const std::string& type = traced_input_types[i];

at::TensorOptions options(input_type_to_c10_ScalarType(type));
at::IntArrayRef shape2(shape);
at::Tensor t = torch::from_blob((void*)input_contents[i].data(), shape2, options);
if (device == "gpu")
t = t.cuda();

input_tensors.push_back(t);
}
}
else
{ {
const std::vector<int64_t>& shape = traced_input_shapes[i];
const std::string& type = traced_input_types[i];
for (size_t i = 0; i < traced_input_shapes.size(); i++)
{
const std::vector<int64_t>& shape = traced_input_shapes[i];
const std::string& type = traced_input_types[i];


at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();
at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();


input_tensors.push_back(t);
input_tensors.push_back(t);
}
} }


std::vector<at::Tensor> input_tensors2; std::vector<at::Tensor> input_tensors2;
for (size_t i = 0; i < input_shapes2.size(); i++)
if(input_contents2.size() != 0)
{ {
const std::vector<int64_t>& shape = input_shapes2[i];
const std::string& type = input_types2[i];
for (size_t i = 0; i < traced_input_shapes.size(); i++)
{
const std::vector<int64_t>& shape = input_shapes2[i];
const std::string& type = input_types2[i];
at::TensorOptions options(input_type_to_c10_ScalarType(type));
at::IntArrayRef shape2(shape);
at::Tensor t = torch::from_blob((void*)input_contents2[i].data(), shape2, options);
if (device == "gpu")
t = t.cuda();


at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();
input_tensors2.push_back(t);


input_tensors2.push_back(t);
}
} }
else
{
for (size_t i = 0; i < traced_input_shapes.size(); i++)
{
std::vector<at::Tensor> input_tensors2;
for (size_t i = 0; i < input_shapes2.size(); i++)
{
const std::vector<int64_t>& shape = input_shapes2[i];
const std::string& type = input_types2[i];


at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();

input_tensors2.push_back(t);
}
}
}
torch::jit::Module mod; torch::jit::Module mod;


try try


+ 3
- 1
tools/pnnx/src/load_torchscript.h View File

@@ -8,12 +8,14 @@


namespace pnnx { namespace pnnx {


int load_torchscript(const std::string& ptpath, Graph& g,
int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
const std::string& device, const std::string& device,
const std::vector<std::vector<int64_t> >& input_shapes, const std::vector<std::vector<int64_t> >& input_shapes,
const std::vector<std::string>& input_types, const std::vector<std::string>& input_types,
const std::vector<std::vector<char> >& input_contents,
const std::vector<std::vector<int64_t> >& input_shapes2, const std::vector<std::vector<int64_t> >& input_shapes2,
const std::vector<std::string>& input_types2, const std::vector<std::string>& input_types2,
const std::vector<std::vector<char> >& input_contents2,
const std::vector<std::string>& customop_modules, const std::vector<std::string>& customop_modules,
const std::vector<std::string>& module_operators, const std::vector<std::string>& module_operators,
const std::string& foldable_constants_zippath, const std::string& foldable_constants_zippath,


Loading…
Cancel
Save