| @@ -2005,6 +2005,127 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| fprintf(pyfp, "\n"); | |||
| // export onnx | |||
| { | |||
| fprintf(pyfp, "def export_onnx():\n"); | |||
| fprintf(pyfp, " net = Model()\n"); | |||
| fprintf(pyfp, " net.eval()\n"); | |||
| fprintf(pyfp, "\n"); | |||
| fprintf(pyfp, " torch.manual_seed(0)\n"); | |||
| std::vector<std::string> input_names; | |||
| for (const Operator* op : ops) | |||
| { | |||
| if (op->type != "pnnx.Input") | |||
| continue; | |||
| const Operand* r = op->outputs[0]; | |||
| std::string input_name = std::string("v_") + sanitize_identifier(r->name); | |||
| if (type_is_integer(r->type)) | |||
| { | |||
| fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| { | |||
| fprintf(pyfp, "%d", r->shape[i]); | |||
| if (i + 1 != r->shape.size() || r->shape.size() == 1) | |||
| fprintf(pyfp, ", "); | |||
| } | |||
| fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| { | |||
| fprintf(pyfp, "%d, ", r->shape[i]); | |||
| } | |||
| fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); | |||
| } | |||
| input_names.push_back(input_name); | |||
| } | |||
| fprintf(pyfp, "\n"); | |||
| // torch.onnx._export(net, v_0, "test_swin_t.onnx", export_params=True, opset_version=14, input_names=['in0'], output_names=['out0']) | |||
| if (input_names.size() == 1) | |||
| { | |||
| fprintf(pyfp, " torch.onnx._export(net, %s", input_names[0].c_str()); | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, " torch.onnx._export(net, ("); | |||
| for (size_t i = 0; i < input_names.size(); i++) | |||
| { | |||
| fprintf(pyfp, "%s", input_names[i].c_str()); | |||
| if (i + 1 != input_names.size()) | |||
| fprintf(pyfp, ", "); | |||
| } | |||
| fprintf(pyfp, ")"); | |||
| } | |||
| fprintf(pyfp, ", \"%s.onnx\", export_params=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, opset_version=13", pypath.c_str()); | |||
| fprintf(pyfp, ", input_names=["); | |||
| { | |||
| int input_count = 0; | |||
| { | |||
| for (const Operator* op : ops) | |||
| { | |||
| if (op->type == "pnnx.Input") | |||
| input_count++; | |||
| } | |||
| } | |||
| int input_index = 0; | |||
| for (const Operator* op : ops) | |||
| { | |||
| if (op->type != "pnnx.Input") | |||
| continue; | |||
| fprintf(pyfp, "'in%d'", input_index); | |||
| if (input_index + 1 != input_count) | |||
| fprintf(pyfp, ", "); | |||
| input_index++; | |||
| } | |||
| } | |||
| fprintf(pyfp, "]"); | |||
| fprintf(pyfp, ", output_names=["); | |||
| { | |||
| int output_count = 0; | |||
| { | |||
| for (const Operator* op : ops) | |||
| { | |||
| if (op->type == "pnnx.Output") | |||
| output_count++; | |||
| } | |||
| } | |||
| int output_index = 0; | |||
| for (const Operator* op : ops) | |||
| { | |||
| if (op->type != "pnnx.Output") | |||
| continue; | |||
| fprintf(pyfp, "'out%d'", output_index); | |||
| if (output_index + 1 != output_count) | |||
| fprintf(pyfp, ", "); | |||
| output_index++; | |||
| } | |||
| } | |||
| fprintf(pyfp, "]"); | |||
| fprintf(pyfp, ")\n"); | |||
| } | |||
| fprintf(pyfp, "\n"); | |||
| // test inference | |||
| { | |||
| fprintf(pyfp, "def test_inference():\n"); | |||