Browse Source

pnnx export_onnx function (#3784)

tags/20220701
nihui GitHub 4 years ago
parent
commit
d476191ff1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 121 additions and 0 deletions
  1. +121
    -0
      tools/pnnx/src/ir.cpp

+ 121
- 0
tools/pnnx/src/ir.cpp View File

@@ -2005,6 +2005,127 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)

fprintf(pyfp, "\n");

// export onnx
{
fprintf(pyfp, "def export_onnx():\n");
fprintf(pyfp, " net = Model()\n");
fprintf(pyfp, " net.eval()\n");
fprintf(pyfp, "\n");
fprintf(pyfp, " torch.manual_seed(0)\n");

std::vector<std::string> input_names;
for (const Operator* op : ops)
{
if (op->type != "pnnx.Input")
continue;

const Operand* r = op->outputs[0];
std::string input_name = std::string("v_") + sanitize_identifier(r->name);
if (type_is_integer(r->type))
{
fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size() || r->shape.size() == 1)
fprintf(pyfp, ", ");
}
fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
}
else
{
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d, ", r->shape[i]);
}
fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
}

input_names.push_back(input_name);
}

fprintf(pyfp, "\n");

// torch.onnx._export(net, v_0, "test_swin_t.onnx", export_params=True, opset_version=14, input_names=['in0'], output_names=['out0'])

if (input_names.size() == 1)
{
fprintf(pyfp, " torch.onnx._export(net, %s", input_names[0].c_str());
}
else
{
fprintf(pyfp, " torch.onnx._export(net, (");

for (size_t i = 0; i < input_names.size(); i++)
{
fprintf(pyfp, "%s", input_names[i].c_str());
if (i + 1 != input_names.size())
fprintf(pyfp, ", ");
}

fprintf(pyfp, ")");
}

fprintf(pyfp, ", \"%s.onnx\", export_params=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, opset_version=13", pypath.c_str());

fprintf(pyfp, ", input_names=[");
{
int input_count = 0;
{
for (const Operator* op : ops)
{
if (op->type == "pnnx.Input")
input_count++;
}
}

int input_index = 0;
for (const Operator* op : ops)
{
if (op->type != "pnnx.Input")
continue;

fprintf(pyfp, "'in%d'", input_index);
if (input_index + 1 != input_count)
fprintf(pyfp, ", ");

input_index++;
}
}
fprintf(pyfp, "]");

fprintf(pyfp, ", output_names=[");
{
int output_count = 0;
{
for (const Operator* op : ops)
{
if (op->type == "pnnx.Output")
output_count++;
}
}

int output_index = 0;
for (const Operator* op : ops)
{
if (op->type != "pnnx.Output")
continue;

fprintf(pyfp, "'out%d'", output_index);
if (output_index + 1 != output_count)
fprintf(pyfp, ", ");

output_index++;
}
}
fprintf(pyfp, "]");

fprintf(pyfp, ")\n");
}

fprintf(pyfp, "\n");

// test inference
{
fprintf(pyfp, "def test_inference():\n");


Loading…
Cancel
Save