Browse Source

pnnx export with ptpath (#5239)

* pnnx export with ptpath

* build and test python pnnx
tags/20240102
nihui GitHub 2 years ago
parent
commit
c41aa2fdfd
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 6 deletions
  1. +9
    -0
      .ci/pnnx.yml
  2. +3
    -3
      tools/pnnx/python/README.md
  3. +2
    -3
      tools/pnnx/python/pnnx/utils/export.py

+ 9
- 0
.ci/pnnx.yml View File

@@ -127,3 +127,12 @@ jobs:
export MKL_ENABLE_INSTRUCTIONS=SSE4_2
cd tools/pnnx
cd build && ctest --output-on-failure -j 16

- name: python-pnnx
run: |
export PYTHONUSERBASE=${{ci.workspace}}/torch-${{matrix.torch-version}}
export PNNX_WHEEL_WITHOUT_BUILD=ON
cd tools/pnnx
cp build/src/pnnx python/pnnx/
python3 setup.py install --user
pytest python/tests/

+ 3
- 3
tools/pnnx/python/README.md View File

@@ -76,8 +76,8 @@ net = models.resnet18(pretrained=True)
x = torch.rand(1, 3, 224, 224)

# You could try disabling checking when torch tracing raises error
# opt_net = pnnx.export(net, "resnet18", x, check_trace=False)
opt_net = pnnx.export(net, "resnet18", x)
# opt_net = pnnx.export(net, "resnet18.pt", x, check_trace=False)
opt_net = pnnx.export(net, "resnet18.pt", x)
```

2. convert existing model to pnnx
@@ -94,7 +94,7 @@ opt_net = pnnx.convert("resnet18.pt", x)

`model` (torch.nn.Model): model to be exported.

`filename` (str): the file name.
`ptpath` (str): the torchscript name.

`inputs` (torch.Tensor of list of torch.Tensor) expected inputs of the model.



+ 2
- 3
tools/pnnx/python/pnnx/utils/export.py View File

@@ -15,7 +15,7 @@
import torch
import os

def export(model, filename, inputs = None, inputs2 = None, input_shapes = None, input_types = None,
def export(model, ptpath, inputs = None, inputs2 = None, input_shapes = None, input_types = None,
input_shapes2 = None, input_types2 = None, device = None, customop = None,
moduleop = None, optlevel = None, pnnxparam = None, pnnxbin = None,
pnnxpy = None, pnnxonnx = None, ncnnparam = None, ncnnbin = None, ncnnpy = None,
@@ -27,8 +27,7 @@ def export(model, filename, inputs = None, inputs2 = None, input_shapes = None,

model.eval()
mod = torch.jit.trace(model, inputs, check_trace=check_trace)
mod.save(filename)
ptpath = os.path.abspath(filename)
mod.save(ptpath)

from . import convert
return convert(ptpath, inputs, inputs2, input_shapes, input_types, input_shapes2, input_types2, device, customop, moduleop, optlevel, pnnxparam, pnnxbin, pnnxpy, pnnxonnx, ncnnparam, ncnnbin, ncnnpy)

Loading…
Cancel
Save