You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. import unittest
  2. import os
  3. os.environ["log_silent"] = "1"
  4. import torch
  5. import paddle
  6. import jittor
  7. from fastNLP.modules.mix_modules.utils import (
  8. paddle2torch,
  9. torch2paddle,
  10. jittor2torch,
  11. torch2jittor,
  12. )
  ############################################################################
#
# Tests for the paddle -> torch conversion (paddle2torch)
#
############################################################################
class Paddle2TorchTestCase(unittest.TestCase):
    """Tests for ``paddle2torch``.

    Several tests place tensors on fixed GPU indices (``cuda:0``, ``cuda:1``,
    ``cuda:2``). They are skipped when too few devices are visible instead of
    erroring out on hosts with fewer GPUs.
    """

    # Device counts are read once while the class body executes so that the
    # ``skipIf`` decorators below can reference them.
    _TORCH_GPUS = torch.cuda.device_count()
    # GPUs addressable by both paddle (source side) and torch (target side).
    _SHARED_GPUS = min(_TORCH_GPUS, paddle.device.cuda.device_count())

    def check_torch_tensor(self, tensor, device, requires_grad):
        """Assert that ``tensor`` is a ``torch.Tensor`` on ``device`` with the given ``requires_grad``.

        :param tensor: object produced by ``paddle2torch``.
        :param device: expected device string, e.g. ``"cpu"`` or ``"cuda:1"``.
        :param requires_grad: expected value of ``tensor.requires_grad``.
        """
        self.assertIsInstance(tensor, torch.Tensor)
        self.assertEqual(tensor.device, torch.device(device))
        self.assertEqual(tensor.requires_grad, requires_grad)

    def test_gradient(self):
        """Backpropagation through the converted tensor: d(sum(3 * y**2))/dy == 6 * y."""
        x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0, 5.0], stop_gradient=False)
        y = paddle2torch(x)
        z = 3 * (y ** 2)
        z.sum().backward()
        self.assertListEqual(y.grad.tolist(), [6, 12, 18, 24, 30])

    @unittest.skipIf(_TORCH_GPUS < 3, "requires at least 3 CUDA devices (uses cuda:1 and cuda:2)")
    def test_tensor_transfer(self):
        """Device placement and requires_grad of a single converted tensor.

        With ``no_gradient`` omitted or ``None`` the result is expected to mirror
        the source's ``stop_gradient``; ``True``/``False`` force requires_grad
        off/on.
        """
        paddle_tensor = paddle.rand((3, 4, 5)).cpu()
        res = paddle2torch(paddle_tensor)
        self.check_torch_tensor(res, "cpu", not paddle_tensor.stop_gradient)
        res = paddle2torch(paddle_tensor, target_device="cuda:2", no_gradient=None)
        self.check_torch_tensor(res, "cuda:2", not paddle_tensor.stop_gradient)
        res = paddle2torch(paddle_tensor, target_device="cuda:1", no_gradient=True)
        self.check_torch_tensor(res, "cuda:1", False)
        res = paddle2torch(paddle_tensor, target_device="cuda:1", no_gradient=False)
        self.check_torch_tensor(res, "cuda:1", True)

    @unittest.skipIf(_SHARED_GPUS < 2, "requires at least 2 GPUs visible to both paddle and torch")
    def test_list_transfer(self):
        """A list of tensors converts element-wise and stays a list.

        Without ``target_device`` the results are expected on the source's GPU (1).
        """
        paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for _ in range(10)]
        res = paddle2torch(paddle_list)
        self.assertIsInstance(res, list)
        for t in res:
            self.check_torch_tensor(t, "cuda:1", False)
        res = paddle2torch(paddle_list, target_device="cpu", no_gradient=False)
        self.assertIsInstance(res, list)
        for t in res:
            self.check_torch_tensor(t, "cpu", True)

    @unittest.skipIf(_SHARED_GPUS < 2, "requires at least 2 GPUs visible to both paddle and torch")
    def test_tensor_tuple_transfer(self):
        """A tuple of tensors converts element-wise and stays a tuple."""
        paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for _ in range(10)]
        paddle_tuple = tuple(paddle_list)
        res = paddle2torch(paddle_tuple)
        self.assertIsInstance(res, tuple)
        for t in res:
            self.check_torch_tensor(t, "cuda:1", False)

    @unittest.skipIf(_SHARED_GPUS < 1, "requires a GPU visible to both paddle and torch")
    def test_dict_transfer(self):
        """A nested dict is converted recursively; non-tensor values pass through unchanged."""
        paddle_dict = {
            "tensor": paddle.rand((3, 4)).cuda(0),
            "list": [paddle.rand((6, 4, 2)).cuda(0) for _ in range(10)],
            "dict": {
                "list": [paddle.rand((6, 4, 2)).cuda(0) for _ in range(10)],
                "tensor": paddle.rand((3, 4)).cuda(0),
            },
            "int": 2,
            "string": "test string",
        }
        res = paddle2torch(paddle_dict)
        self.assertIsInstance(res, dict)
        self.check_torch_tensor(res["tensor"], "cuda:0", False)
        self.assertIsInstance(res["list"], list)
        for t in res["list"]:
            self.check_torch_tensor(t, "cuda:0", False)
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_torch_tensor(t, "cuda:0", False)
        self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", False)
  ############################################################################
#
# Tests for the torch -> paddle conversion (torch2paddle)
#
############################################################################
class Torch2PaddleTestCase(unittest.TestCase):
    """Tests for ``torch2paddle``.

    Tests that target ``gpu:1``/``gpu:2`` are skipped when paddle sees too few
    GPUs instead of erroring out.
    """

    # Read once while the class body executes so the decorators below can use it.
    _PADDLE_GPUS = paddle.device.cuda.device_count()

    def check_paddle_tensor(self, tensor, device, stop_gradient):
        """Assert that ``tensor`` is a ``paddle.Tensor`` on ``device`` with the given ``stop_gradient``.

        :param tensor: object produced by ``torch2paddle``.
        :param device: ``"cpu"`` or a paddle GPU string such as ``"gpu:1"``.
        :param stop_gradient: expected value of ``tensor.stop_gradient``.
        :raises NotImplementedError: for any other device string.
        """
        self.assertIsInstance(tensor, paddle.Tensor)
        if device == "cpu":
            self.assertTrue(tensor.place.is_cpu_place())
        elif device.startswith("gpu"):
            paddle_device = paddle.device._convert_to_place(device)
            self.assertTrue(tensor.place.is_gpu_place())
            if hasattr(tensor.place, "gpu_device_id"):
                # Paddle has two Place types. paddle.fluid.core.Place is the one
                # a Tensor is created with; it exposes gpu_device_id().
                self.assertEqual(tensor.place.gpu_device_id(), paddle_device.get_device_id())
            else:
                # _convert_to_place returns a paddle.CUDAPlace, whose index is
                # read with get_device_id().
                self.assertEqual(tensor.place.get_device_id(), paddle_device.get_device_id())
        else:
            raise NotImplementedError
        self.assertEqual(tensor.stop_gradient, stop_gradient)

    def test_gradient(self):
        """Backpropagation through the converted tensor: d(sum(3 * y**2))/dy == 6 * y."""
        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
        y = torch2paddle(x)
        z = 3 * (y ** 2)
        z.sum().backward()
        self.assertListEqual(y.grad.tolist(), [6, 12, 18, 24, 30])

    @unittest.skipIf(_PADDLE_GPUS < 3, "requires at least 3 paddle GPUs (uses gpu:2)")
    def test_tensor_transfer(self):
        """Device placement and stop_gradient of a single converted tensor."""
        torch_tensor = torch.rand((3, 4, 5))
        res = torch2paddle(torch_tensor)
        self.check_paddle_tensor(res, "cpu", True)
        res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=None)
        self.check_paddle_tensor(res, "gpu:2", True)
        res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=True)
        self.check_paddle_tensor(res, "gpu:2", True)
        res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=False)
        self.check_paddle_tensor(res, "gpu:2", False)

    @unittest.skipIf(_PADDLE_GPUS < 2, "requires at least 2 paddle GPUs (uses gpu:1)")
    def test_tensor_list_transfer(self):
        """A list of tensors converts element-wise and stays a list."""
        torch_list = [torch.rand(6, 4, 2) for _ in range(10)]
        res = torch2paddle(torch_list)
        self.assertIsInstance(res, list)
        for t in res:
            self.check_paddle_tensor(t, "cpu", True)
        res = torch2paddle(torch_list, target_device="gpu:1", no_gradient=False)
        self.assertIsInstance(res, list)
        for t in res:
            self.check_paddle_tensor(t, "gpu:1", False)

    def test_tensor_tuple_transfer(self):
        """A tuple of tensors converts element-wise and stays a tuple."""
        torch_list = [torch.rand(6, 4, 2) for _ in range(10)]
        torch_tuple = tuple(torch_list)
        res = torch2paddle(torch_tuple, target_device="cpu")
        self.assertIsInstance(res, tuple)
        for t in res:
            self.check_paddle_tensor(t, "cpu", True)

    def test_dict_transfer(self):
        """A nested dict is converted recursively; non-tensor values pass through unchanged."""
        torch_dict = {
            "tensor": torch.rand((3, 4)),
            "list": [torch.rand(6, 4, 2) for _ in range(10)],
            "dict": {
                "list": [torch.rand(6, 4, 2) for _ in range(10)],
                "tensor": torch.rand((3, 4)),
            },
            "int": 2,
            "string": "test string",
        }
        res = torch2paddle(torch_dict)
        self.assertIsInstance(res, dict)
        self.check_paddle_tensor(res["tensor"], "cpu", True)
        self.assertIsInstance(res["list"], list)
        for t in res["list"]:
            self.check_paddle_tensor(t, "cpu", True)
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_paddle_tensor(t, "cpu", True)
        self.check_paddle_tensor(res["dict"]["tensor"], "cpu", True)
  ############################################################################
#
# Tests for the jittor -> torch conversion (jittor2torch)
#
############################################################################
class Jittor2TorchTestCase(unittest.TestCase):
    """Tests for ``jittor2torch``.

    Tests that target ``cuda:1``/``cuda:2`` are skipped when torch sees too few
    CUDA devices instead of erroring out.
    """

    # Read once while the class body executes so the decorators below can use it.
    _TORCH_GPUS = torch.cuda.device_count()

    def check_torch_tensor(self, tensor, device, requires_grad):
        """Assert that ``tensor`` is a ``torch.Tensor`` on ``device`` with the given ``requires_grad``.

        For ``"cpu"`` only "not on CUDA" is checked; any other value must match
        ``tensor.device`` exactly.
        """
        self.assertIsInstance(tensor, torch.Tensor)
        if device == "cpu":
            self.assertFalse(tensor.is_cuda)
        else:
            self.assertEqual(tensor.device, torch.device(device))
        self.assertEqual(tensor.requires_grad, requires_grad)

    @unittest.skipIf(_TORCH_GPUS < 3, "requires at least 3 CUDA devices (uses cuda:2)")
    def test_var_transfer(self):
        """Device placement and requires_grad of a single converted jittor Var."""
        jittor_var = jittor.rand((3, 4, 5))
        res = jittor2torch(jittor_var)
        self.check_torch_tensor(res, "cpu", True)
        res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=None)
        self.check_torch_tensor(res, "cuda:2", True)
        res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=True)
        self.check_torch_tensor(res, "cuda:2", False)
        res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=False)
        self.check_torch_tensor(res, "cuda:2", True)

    @unittest.skipIf(_TORCH_GPUS < 2, "requires at least 2 CUDA devices (uses cuda:1)")
    def test_var_list_transfer(self):
        """A list of Vars converts element-wise and stays a list."""
        jittor_list = [jittor.rand((6, 4, 2)) for _ in range(10)]
        res = jittor2torch(jittor_list)
        self.assertIsInstance(res, list)
        for t in res:
            self.check_torch_tensor(t, "cpu", True)
        res = jittor2torch(jittor_list, target_device="cuda:1", no_gradient=False)
        self.assertIsInstance(res, list)
        for t in res:
            self.check_torch_tensor(t, "cuda:1", True)

    def test_var_tuple_transfer(self):
        """A tuple of Vars converts element-wise and stays a tuple."""
        jittor_list = [jittor.rand((6, 4, 2)) for _ in range(10)]
        jittor_tuple = tuple(jittor_list)
        res = jittor2torch(jittor_tuple, target_device="cpu")
        self.assertIsInstance(res, tuple)
        for t in res:
            self.check_torch_tensor(t, "cpu", True)

    def test_dict_transfer(self):
        """A nested dict is converted recursively; non-Var values pass through unchanged."""
        jittor_dict = {
            "tensor": jittor.rand((3, 4)),
            "list": [jittor.rand((6, 4, 2)) for _ in range(10)],
            "dict": {
                "list": [jittor.rand((6, 4, 2)) for _ in range(10)],
                "tensor": jittor.rand((3, 4)),
            },
            "int": 2,
            "string": "test string",
        }
        res = jittor2torch(jittor_dict)
        self.assertIsInstance(res, dict)
        self.check_torch_tensor(res["tensor"], "cpu", True)
        self.assertIsInstance(res["list"], list)
        for t in res["list"]:
            self.check_torch_tensor(t, "cpu", True)
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_torch_tensor(t, "cpu", True)
        self.check_torch_tensor(res["dict"]["tensor"], "cpu", True)
  ############################################################################
#
# Tests for the torch -> jittor conversion (torch2jittor)
#
############################################################################
class Torch2JittorTestCase(unittest.TestCase):
    """Tests for ``torch2jittor``; no explicit target device is used here."""

    def check_jittor_var(self, var, requires_grad):
        """Assert ``var`` is a ``jittor.Var`` whose ``requires_grad`` equals the expected flag."""
        self.assertIsInstance(var, jittor.Var)
        self.assertEqual(var.requires_grad, requires_grad)

    def test_gradient(self):
        """jittor.grad of 3 * y**2 with respect to the converted Var equals 6 * y."""
        source = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
        converted = torch2jittor(source)
        loss = 3 * (converted ** 2)
        gradient = jittor.grad(loss, converted)
        expected = [6.0, 12.0, 18.0, 24.0, 30.0]
        self.assertListEqual(gradient.tolist(), expected)

    def test_tensor_transfer(self):
        """A single tensor converts to a Var whose requires_grad follows ``no_gradient``."""
        source = torch.rand((3, 4, 5))
        # (extra keyword arguments, expected requires_grad). Only an explicit
        # no_gradient=False is expected to enable gradients.
        cases = (
            ({}, False),
            ({"no_gradient": None}, False),
            ({"no_gradient": True}, False),
            ({"no_gradient": False}, True),
        )
        for kwargs, expected in cases:
            self.check_jittor_var(torch2jittor(source, **kwargs), expected)

    def test_tensor_list_transfer(self):
        """A list of tensors converts element-wise and stays a list."""
        sources = [torch.rand((6, 4, 2)) for _ in range(10)]
        for kwargs, expected in (({}, False), ({"no_gradient": False}, True)):
            converted = torch2jittor(sources, **kwargs)
            self.assertIsInstance(converted, list)
            for var in converted:
                self.check_jittor_var(var, expected)

    def test_tensor_tuple_transfer(self):
        """A tuple of tensors converts element-wise and stays a tuple."""
        sources = tuple(torch.rand((6, 4, 2)) for _ in range(10))
        converted = torch2jittor(sources)
        self.assertIsInstance(converted, tuple)
        for var in converted:
            self.check_jittor_var(var, False)

    def test_dict_transfer(self):
        """A nested dict is converted recursively; non-tensor values pass through unchanged."""
        def random_batch():
            return [torch.rand(6, 4, 2) for _ in range(10)]

        source = {
            "tensor": torch.rand((3, 4)),
            "list": random_batch(),
            "dict": {"list": random_batch(), "tensor": torch.rand((3, 4))},
            "int": 2,
            "string": "test string",
        }
        converted = torch2jittor(source)
        self.assertIsInstance(converted, dict)
        self.check_jittor_var(converted["tensor"], False)
        self.assertIsInstance(converted["list"], list)
        for var in converted["list"]:
            self.check_jittor_var(var, False)
        self.assertIsInstance(converted["int"], int)
        self.assertIsInstance(converted["string"], str)
        nested = converted["dict"]
        self.assertIsInstance(nested, dict)
        self.assertIsInstance(nested["list"], list)
        for var in nested["list"]:
            self.check_jittor_var(var, False)
        self.check_jittor_var(nested["tensor"], False)