You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_Tensor_repeat.py 1.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # Copyright 2021 Tencent
  2. # SPDX-License-Identifier: BSD-3-Clause
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. class Model(nn.Module):
  7. def __init__(self):
  8. super(Model, self).__init__()
  9. def forward(self, x, y, z):
  10. x = x.repeat(1, 2, 3)
  11. x = x.repeat(2, 3, 4)
  12. y = y.repeat(1, 2, 1, 4)
  13. y = y.repeat(3, 4, 5, 1)
  14. z = z.repeat(1, 2, 3, 1, 5)
  15. z = z.repeat(2, 3, 3, 1, 1)
  16. return x, y, z
  17. def test():
  18. net = Model()
  19. net.eval()
  20. torch.manual_seed(0)
  21. x = torch.rand(1, 3, 16)
  22. y = torch.rand(1, 5, 9, 11)
  23. z = torch.rand(14, 8, 5, 9, 10)
  24. a = net(x, y, z)
  25. # export torchscript
  26. mod = torch.jit.trace(net, (x, y, z))
  27. mod.save("test_Tensor_repeat.pt")
  28. # torchscript to pnnx
  29. import os
  30. os.system("../src/pnnx test_Tensor_repeat.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]")
  31. # pnnx inference
  32. import test_Tensor_repeat_pnnx
  33. b = test_Tensor_repeat_pnnx.test_inference()
  34. for a0, b0 in zip(a, b):
  35. if not torch.equal(a0, b0):
  36. return False
  37. return True
  38. if __name__ == "__main__":
  39. if test():
  40. exit(0)
  41. else:
  42. exit(1)