|
|
|
@@ -15,7 +15,7 @@ |
|
|
|
import torch |
|
|
|
import os |
|
|
|
|
|
|
|
def export(model, filename, inputs = None, inputs2 = None, input_shapes = None, input_types = None, |
|
|
|
def export(model, ptpath, inputs = None, inputs2 = None, input_shapes = None, input_types = None, |
|
|
|
input_shapes2 = None, input_types2 = None, device = None, customop = None, |
|
|
|
moduleop = None, optlevel = None, pnnxparam = None, pnnxbin = None, |
|
|
|
pnnxpy = None, pnnxonnx = None, ncnnparam = None, ncnnbin = None, ncnnpy = None, |
|
|
|
@@ -27,8 +27,7 @@ def export(model, filename, inputs = None, inputs2 = None, input_shapes = None, |
|
|
|
|
|
|
|
model.eval() |
|
|
|
mod = torch.jit.trace(model, inputs, check_trace=check_trace) |
|
|
|
mod.save(filename) |
|
|
|
ptpath = os.path.abspath(filename) |
|
|
|
mod.save(ptpath) |
|
|
|
|
|
|
|
from . import convert |
|
|
|
return convert(ptpath, inputs, inputs2, input_shapes, input_types, input_shapes2, input_types2, device, customop, moduleop, optlevel, pnnxparam, pnnxbin, pnnxpy, pnnxonnx, ncnnparam, ncnnbin, ncnnpy) |