You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_torch_stack.py 1.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. # Copyright 2022 Tencent
  2. # SPDX-License-Identifier: BSD-3-Clause
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. class Model(nn.Module):
  7. def __init__(self):
  8. super(Model, self).__init__()
  9. def forward(self, x, y, z, w):
  10. out0 = torch.stack((x, y), dim=0)
  11. out1 = torch.stack((x, y), dim=2)
  12. out2 = torch.stack((z, w), dim=2)
  13. out3 = torch.stack((z, w), dim=-1)
  14. return out0, out1, out2, out3
  15. def test():
  16. net = Model()
  17. net.eval()
  18. torch.manual_seed(0)
  19. x = torch.rand(3, 16)
  20. y = torch.rand(3, 16)
  21. z = torch.rand(5, 9, 3)
  22. w = torch.rand(5, 9, 3)
  23. a = net(x, y, z, w)
  24. # export torchscript
  25. mod = torch.jit.trace(net, (x, y, z, w))
  26. mod.save("test_torch_stack.pt")
  27. # torchscript to pnnx
  28. import os
  29. os.system("../src/pnnx test_torch_stack.pt inputshape=[3,16],[3,16],[5,9,3],[5,9,3]")
  30. # pnnx inference
  31. import test_torch_stack_pnnx
  32. b = test_torch_stack_pnnx.test_inference()
  33. for a0, b0 in zip(a, b):
  34. if not torch.equal(a0, b0):
  35. return False
  36. return True
  37. if __name__ == "__main__":
  38. if test():
  39. exit(0)
  40. else:
  41. exit(1)