diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 80613dc1a..d00b9c505 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2005,6 +2005,127 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) fprintf(pyfp, "\n"); + // export onnx + { + fprintf(pyfp, "def export_onnx():\n"); + fprintf(pyfp, " net = Model()\n"); + fprintf(pyfp, " net.eval()\n"); + fprintf(pyfp, "\n"); + fprintf(pyfp, " torch.manual_seed(0)\n"); + + std::vector input_names; + for (const Operator* op : ops) + { + if (op->type != "pnnx.Input") + continue; + + const Operand* r = op->outputs[0]; + std::string input_name = std::string("v_") + sanitize_identifier(r->name); + if (type_is_integer(r->type)) + { + fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d", r->shape[i]); + if (i + 1 != r->shape.size() || r->shape.size() == 1) + fprintf(pyfp, ", "); + } + fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); + } + else + { + fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d, ", r->shape[i]); + } + fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); + } + + input_names.push_back(input_name); + } + + fprintf(pyfp, "\n"); + + // torch.onnx._export(net, v_0, "test_swin_t.onnx", export_params=True, opset_version=14, input_names=['in0'], output_names=['out0']) + + if (input_names.size() == 1) + { + fprintf(pyfp, " torch.onnx._export(net, %s", input_names[0].c_str()); + } + else + { + fprintf(pyfp, " torch.onnx._export(net, ("); + + for (size_t i = 0; i < input_names.size(); i++) + { + fprintf(pyfp, "%s", input_names[i].c_str()); + if (i + 1 != input_names.size()) + fprintf(pyfp, ", "); + } + + fprintf(pyfp, ")"); + } + + fprintf(pyfp, ", \"%s.onnx\", export_params=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, opset_version=13", pypath.c_str()); + + fprintf(pyfp, ", input_names=["); + { + int input_count = 0; + { + for (const Operator* op : ops) + { + if (op->type == "pnnx.Input") + input_count++; + } + } + + int input_index = 0; + for (const Operator* op : ops) + { + if (op->type != "pnnx.Input") + continue; + + fprintf(pyfp, "'in%d'", input_index); + if (input_index + 1 != input_count) + fprintf(pyfp, ", "); + + input_index++; + } + } + fprintf(pyfp, "]"); + + fprintf(pyfp, ", output_names=["); + { + int output_count = 0; + { + for (const Operator* op : ops) + { + if (op->type == "pnnx.Output") + output_count++; + } + } + + int output_index = 0; + for (const Operator* op : ops) + { + if (op->type != "pnnx.Output") + continue; + + fprintf(pyfp, "'out%d'", output_index); + if (output_index + 1 != output_count) + fprintf(pyfp, ", "); + + output_index++; + } + } + fprintf(pyfp, "]"); + + fprintf(pyfp, ")\n"); + } + + fprintf(pyfp, "\n"); + // test inference { fprintf(pyfp, "def test_inference():\n");