From c41aa2fdfdeb80753b02b0ce80b8a0271d49fe1e Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 20 Dec 2023 16:46:38 +0800 Subject: [PATCH] pnnx export with ptpath (#5239) * pnnx export with ptpath * build and test python pnnx --- .ci/pnnx.yml | 9 +++++++++ tools/pnnx/python/README.md | 6 +++--- tools/pnnx/python/pnnx/utils/export.py | 5 ++--- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index 3f116a4fa..16602627c 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -127,3 +127,12 @@ jobs: export MKL_ENABLE_INSTRUCTIONS=SSE4_2 cd tools/pnnx cd build && ctest --output-on-failure -j 16 + + - name: python-pnnx + run: | + export PYTHONUSERBASE=${{ci.workspace}}/torch-${{matrix.torch-version}} + export PNNX_WHEEL_WITHOUT_BUILD=ON + cd tools/pnnx + cp build/src/pnnx python/pnnx/ + python3 setup.py install --user + pytest python/tests/ diff --git a/tools/pnnx/python/README.md b/tools/pnnx/python/README.md index 8b8ce4691..964d8fa53 100644 --- a/tools/pnnx/python/README.md +++ b/tools/pnnx/python/README.md @@ -76,8 +76,8 @@ net = models.resnet18(pretrained=True) x = torch.rand(1, 3, 224, 224) # You could try disabling checking when torch tracing raises error -# opt_net = pnnx.export(net, "resnet18", x, check_trace=False) -opt_net = pnnx.export(net, "resnet18", x) +# opt_net = pnnx.export(net, "resnet18.pt", x, check_trace=False) +opt_net = pnnx.export(net, "resnet18.pt", x) ``` 2. convert existing model to pnnx @@ -94,7 +94,7 @@ opt_net = pnnx.convert("resnet18.pt", x) `model` (torch.nn.Model): model to be exported. -`filename` (str): the file name. +`ptpath` (str): the torchscript name. `inputs` (torch.Tensor of list of torch.Tensor) expected inputs of the model. diff --git a/tools/pnnx/python/pnnx/utils/export.py b/tools/pnnx/python/pnnx/utils/export.py index f6dc1a446..3a04dae73 100644 --- a/tools/pnnx/python/pnnx/utils/export.py +++ b/tools/pnnx/python/pnnx/utils/export.py @@ -15,7 +15,7 @@ import torch import os -def export(model, filename, inputs = None, inputs2 = None, input_shapes = None, input_types = None, +def export(model, ptpath, inputs = None, inputs2 = None, input_shapes = None, input_types = None, input_shapes2 = None, input_types2 = None, device = None, customop = None, moduleop = None, optlevel = None, pnnxparam = None, pnnxbin = None, pnnxpy = None, pnnxonnx = None, ncnnparam = None, ncnnbin = None, ncnnpy = None, @@ -27,8 +27,7 @@ def export(model, filename, inputs = None, inputs2 = None, input_shapes = None, model.eval() mod = torch.jit.trace(model, inputs, check_trace=check_trace) - mod.save(filename) - ptpath = os.path.abspath(filename) + mod.save(ptpath) from . import convert return convert(ptpath, inputs, inputs2, input_shapes, input_types, input_shapes2, input_types2, device, customop, moduleop, optlevel, pnnxparam, pnnxbin, pnnxpy, pnnxonnx, ncnnparam, ncnnbin, ncnnpy)