|
|
|
@@ -483,9 +483,9 @@ def test(): |
|
|
|
ny = y.numpy() |
|
|
|
nz = z.numpy() |
|
|
|
|
|
|
|
npy.save("x.npy", nx) |
|
|
|
npy.save("y.npy", ny) |
|
|
|
npy.save("z.npy", nz) |
|
|
|
npy.save("test_pnnx_fuse_multiheadattention_x.npy", nx) |
|
|
|
npy.save("test_pnnx_fuse_multiheadattention_y.npy", ny) |
|
|
|
npy.save("test_pnnx_fuse_multiheadattention_z.npy", nz) |
|
|
|
|
|
|
|
a = net(x, y, z) |
|
|
|
|
|
|
|
@@ -495,7 +495,7 @@ def test(): |
|
|
|
|
|
|
|
# torchscript to pnnx |
|
|
|
import os |
|
|
|
os.system("../../src/pnnx test_pnnx_fuse_multiheadattention.pt input=x.npy,y.npy,z.npy") |
|
|
|
os.system("../../src/pnnx test_pnnx_fuse_multiheadattention.pt input=test_pnnx_fuse_multiheadattention_x.npy,test_pnnx_fuse_multiheadattention_y.npy,test_pnnx_fuse_multiheadattention_z.npy") |
|
|
|
|
|
|
|
# pnnx inference |
|
|
|
import test_pnnx_fuse_multiheadattention_pnnx |
|
|
|
|