
test_utils.py

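"""Tests for the tensor conversion helpers in ``fastNLP.modules.mix_modules.utils``:
``paddle2torch``, ``torch2paddle``, ``jittor2torch`` and ``torch2jittor``.

The tests cover single tensors/Vars, lists, tuples and nested dicts, as well as
device placement and the ``no_gradient`` flag.
"""
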
import pytest

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
from fastNLP.modules.mix_modules.utils import (
    paddle2torch,
    torch2paddle,
    jittor2torch,
    torch2jittor,
)

if _NEED_IMPORT_TORCH:
    import torch
if _NEED_IMPORT_PADDLE:
    import paddle
if _NEED_IMPORT_JITTOR:
    import jittor

############################################################################
#
# Test conversion from paddle to torch
#
############################################################################

@pytest.mark.torchpaddle
class TestPaddle2Torch:

    def check_torch_tensor(self, tensor, device, requires_grad):
        """
        Helper that checks the device and gradient status of a torch tensor.
        """
        assert isinstance(tensor, torch.Tensor)
        assert tensor.device == torch.device(device)
        assert tensor.requires_grad == requires_grad

    def test_gradient(self):
        """
        Check that backpropagation works correctly after the conversion.
        """
        x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0, 5.0], stop_gradient=False)
        y = paddle2torch(x)
        z = 3 * (y ** 2)
        z.sum().backward()
        assert y.grad.tolist() == [6, 12, 18, 24, 30]

    def test_tensor_transfer(self):
        """
        Check device and gradient handling when converting a single tensor.
        """
        paddle_tensor = paddle.rand((3, 4, 5)).cpu()
        res = paddle2torch(paddle_tensor)
        self.check_torch_tensor(res, "cpu", not paddle_tensor.stop_gradient)
        res = paddle2torch(paddle_tensor, device="cuda:2", no_gradient=None)
        self.check_torch_tensor(res, "cuda:2", not paddle_tensor.stop_gradient)
        res = paddle2torch(paddle_tensor, device="cuda:1", no_gradient=True)
        self.check_torch_tensor(res, "cuda:1", False)
        res = paddle2torch(paddle_tensor, device="cuda:1", no_gradient=False)
        self.check_torch_tensor(res, "cuda:1", True)

    def test_list_transfer(self):
        """
        Check conversion of a list of tensors.
        """
        paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for i in range(10)]
        res = paddle2torch(paddle_list)
        assert isinstance(res, list)
        for t in res:
            self.check_torch_tensor(t, "cuda:1", False)
        res = paddle2torch(paddle_list, device="cpu", no_gradient=False)
        assert isinstance(res, list)
        for t in res:
            self.check_torch_tensor(t, "cpu", True)

    def test_tensor_tuple_transfer(self):
        """
        Check conversion of a tuple of tensors.
        """
        paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for i in range(10)]
        paddle_tuple = tuple(paddle_list)
        res = paddle2torch(paddle_tuple)
        assert isinstance(res, tuple)
        for t in res:
            self.check_torch_tensor(t, "cuda:1", False)

    def test_dict_transfer(self):
        """
        Check conversion of a dict containing nested structures.
        """
        paddle_dict = {
            "tensor": paddle.rand((3, 4)).cuda(0),
            "list": [paddle.rand((6, 4, 2)).cuda(0) for i in range(10)],
            "dict": {
                "list": [paddle.rand((6, 4, 2)).cuda(0) for i in range(10)],
                "tensor": paddle.rand((3, 4)).cuda(0)
            },
            "int": 2,
            "string": "test string"
        }
        res = paddle2torch(paddle_dict)
        assert isinstance(res, dict)
        self.check_torch_tensor(res["tensor"], "cuda:0", False)
        assert isinstance(res["list"], list)
        for t in res["list"]:
            self.check_torch_tensor(t, "cuda:0", False)
        assert isinstance(res["int"], int)
        assert isinstance(res["string"], str)
        assert isinstance(res["dict"], dict)
        assert isinstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_torch_tensor(t, "cuda:0", False)
        self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", False)

############################################################################
#
# Test conversion from torch to paddle
#
############################################################################

@pytest.mark.torchpaddle
class TestTorch2Paddle:

    def check_paddle_tensor(self, tensor, device, stop_gradient):
        """
        Helper that checks the device and gradient status of the resulting paddle tensor.
        """
        assert isinstance(tensor, paddle.Tensor)
        if device == "cpu":
            assert tensor.place.is_cpu_place()
        elif device.startswith("gpu"):
            paddle_device = paddle.device._convert_to_place(device)
            assert tensor.place.is_gpu_place()
            if hasattr(tensor.place, "gpu_device_id"):
                # paddle has two kinds of Place:
                # paddle.fluid.core.Place is the type used when a Tensor is created;
                # it provides gpu_device_id() to query the device.
                assert tensor.place.gpu_device_id() == paddle_device.get_device_id()
            else:
                # _convert_to_place returns a paddle.CUDAPlace,
                # which provides get_device_id() to query the device.
                assert tensor.place.get_device_id() == paddle_device.get_device_id()
        else:
            raise NotImplementedError
        assert tensor.stop_gradient == stop_gradient

    def test_gradient(self):
        """
        Check gradient backpropagation after the conversion.
        """
        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
        y = torch2paddle(x)
        z = 3 * (y ** 2)
        z.sum().backward()
        assert y.grad.tolist() == [6, 12, 18, 24, 30]

    def test_tensor_transfer(self):
        """
        Check conversion of a single tensor.
        """
        torch_tensor = torch.rand((3, 4, 5))
        res = torch2paddle(torch_tensor)
        self.check_paddle_tensor(res, "cpu", True)
        res = torch2paddle(torch_tensor, device="gpu:2", no_gradient=None)
        self.check_paddle_tensor(res, "gpu:2", True)
        res = torch2paddle(torch_tensor, device="gpu:2", no_gradient=True)
        self.check_paddle_tensor(res, "gpu:2", True)
        res = torch2paddle(torch_tensor, device="gpu:2", no_gradient=False)
        self.check_paddle_tensor(res, "gpu:2", False)

    def test_tensor_list_transfer(self):
        """
        Check conversion of a list of tensors.
        """
        torch_list = [torch.rand(6, 4, 2) for i in range(10)]
        res = torch2paddle(torch_list)
        assert isinstance(res, list)
        for t in res:
            self.check_paddle_tensor(t, "cpu", True)
        res = torch2paddle(torch_list, device="gpu:1", no_gradient=False)
        assert isinstance(res, list)
        for t in res:
            self.check_paddle_tensor(t, "gpu:1", False)

    def test_tensor_tuple_transfer(self):
        """
        Check conversion of a tuple of tensors.
        """
        torch_list = [torch.rand(6, 4, 2) for i in range(10)]
        torch_tuple = tuple(torch_list)
        res = torch2paddle(torch_tuple, device="cpu")
        assert isinstance(res, tuple)
        for t in res:
            self.check_paddle_tensor(t, "cpu", True)

    def test_dict_transfer(self):
        """
        Check conversion of a dict with a nested structure.
        """
        torch_dict = {
            "tensor": torch.rand((3, 4)),
            "list": [torch.rand(6, 4, 2) for i in range(10)],
            "dict": {
                "list": [torch.rand(6, 4, 2) for i in range(10)],
                "tensor": torch.rand((3, 4))
            },
            "int": 2,
            "string": "test string"
        }
        res = torch2paddle(torch_dict)
        assert isinstance(res, dict)
        self.check_paddle_tensor(res["tensor"], "cpu", True)
        assert isinstance(res["list"], list)
        for t in res["list"]:
            self.check_paddle_tensor(t, "cpu", True)
        assert isinstance(res["int"], int)
        assert isinstance(res["string"], str)
        assert isinstance(res["dict"], dict)
        assert isinstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_paddle_tensor(t, "cpu", True)
        self.check_paddle_tensor(res["dict"]["tensor"], "cpu", True)

############################################################################
#
# Test conversion from jittor to torch
#
############################################################################

@pytest.mark.torchjittor
class TestJittor2Torch:

    def check_torch_tensor(self, tensor, device, requires_grad):
        """
        Helper that checks the resulting torch tensor.
        """
        assert isinstance(tensor, torch.Tensor)
        if device == "cpu":
            assert not tensor.is_cuda
        else:
            assert tensor.device == torch.device(device)
        assert tensor.requires_grad == requires_grad

    def test_var_transfer(self):
        """
        Check conversion of a single jittor Var.
        """
        jittor_var = jittor.rand((3, 4, 5))
        res = jittor2torch(jittor_var)
        self.check_torch_tensor(res, "cpu", True)
        res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None)
        self.check_torch_tensor(res, "cuda:2", True)
        res = jittor2torch(jittor_var, device="cuda:2", no_gradient=True)
        self.check_torch_tensor(res, "cuda:2", False)
        res = jittor2torch(jittor_var, device="cuda:2", no_gradient=False)
        self.check_torch_tensor(res, "cuda:2", True)

    def test_var_list_transfer(self):
        """
        Check conversion of a list of jittor Vars.
        """
        jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)]
        res = jittor2torch(jittor_list)
        assert isinstance(res, list)
        for t in res:
            self.check_torch_tensor(t, "cpu", True)
        res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False)
        assert isinstance(res, list)
        for t in res:
            self.check_torch_tensor(t, "cuda:1", True)

    def test_var_tuple_transfer(self):
        """
        Check conversion of a tuple of jittor Vars.
        """
        jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)]
        jittor_tuple = tuple(jittor_list)
        res = jittor2torch(jittor_tuple, device="cpu")
        assert isinstance(res, tuple)
        for t in res:
            self.check_torch_tensor(t, "cpu", True)

    def test_dict_transfer(self):
        """
        Check conversion of a dict structure.
        """
        jittor_dict = {
            "tensor": jittor.rand((3, 4)),
            "list": [jittor.rand(6, 4, 2) for i in range(10)],
            "dict": {
                "list": [jittor.rand(6, 4, 2) for i in range(10)],
                "tensor": jittor.rand((3, 4))
            },
            "int": 2,
            "string": "test string"
        }
        res = jittor2torch(jittor_dict)
        assert isinstance(res, dict)
        self.check_torch_tensor(res["tensor"], "cpu", True)
        assert isinstance(res["list"], list)
        for t in res["list"]:
            self.check_torch_tensor(t, "cpu", True)
        assert isinstance(res["int"], int)
        assert isinstance(res["string"], str)
        assert isinstance(res["dict"], dict)
        assert isinstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_torch_tensor(t, "cpu", True)
        self.check_torch_tensor(res["dict"]["tensor"], "cpu", True)

############################################################################
#
# Test conversion from torch to jittor
#
############################################################################

@pytest.mark.torchjittor
class TestTorch2Jittor:

    def check_jittor_var(self, var, requires_grad):
        """
        Helper that checks the gradient status of the resulting jittor Var.
        """
        assert isinstance(var, jittor.Var)
        assert var.requires_grad == requires_grad

    def test_gradient(self):
        """
        Check gradients obtained from backpropagation.
        """
        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
        y = torch2jittor(x)
        z = 3 * (y ** 2)
        grad = jittor.grad(z, y)
        assert grad.tolist() == [6.0, 12.0, 18.0, 24.0, 30.0]

    def test_tensor_transfer(self):
        """
        Check conversion of a single tensor to a jittor Var.
        """
        torch_tensor = torch.rand((3, 4, 5))
        res = torch2jittor(torch_tensor)
        self.check_jittor_var(res, False)
        res = torch2jittor(torch_tensor, no_gradient=None)
        self.check_jittor_var(res, False)
        res = torch2jittor(torch_tensor, no_gradient=True)
        self.check_jittor_var(res, False)
        res = torch2jittor(torch_tensor, no_gradient=False)
        self.check_jittor_var(res, True)

    def test_tensor_list_transfer(self):
        """
        Check conversion of a list of tensors.
        """
        torch_list = [torch.rand((6, 4, 2)) for i in range(10)]
        res = torch2jittor(torch_list)
        assert isinstance(res, list)
        for t in res:
            self.check_jittor_var(t, False)
        res = torch2jittor(torch_list, no_gradient=False)
        assert isinstance(res, list)
        for t in res:
            self.check_jittor_var(t, True)

    def test_tensor_tuple_transfer(self):
        """
        Check conversion of a tuple of tensors.
        """
        torch_list = [torch.rand((6, 4, 2)) for i in range(10)]
        torch_tuple = tuple(torch_list)
        res = torch2jittor(torch_tuple)
        assert isinstance(res, tuple)
        for t in res:
            self.check_jittor_var(t, False)

    def test_dict_transfer(self):
        """
        Check conversion of a dict structure.
        """
        torch_dict = {
            "tensor": torch.rand((3, 4)),
            "list": [torch.rand(6, 4, 2) for i in range(10)],
            "dict": {
                "list": [torch.rand(6, 4, 2) for i in range(10)],
                "tensor": torch.rand((3, 4))
            },
            "int": 2,
            "string": "test string"
        }
        res = torch2jittor(torch_dict)
        assert isinstance(res, dict)
        self.check_jittor_var(res["tensor"], False)
        assert isinstance(res["list"], list)
        for t in res["list"]:
            self.check_jittor_var(t, False)
        assert isinstance(res["int"], int)
        assert isinstance(res["string"], str)
        assert isinstance(res["dict"], dict)
        assert isinstance(res["dict"]["list"], list)
        for t in res["dict"]["list"]:
            self.check_jittor_var(t, False)
        self.check_jittor_var(res["dict"]["tensor"], False)
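

# ---------------------------------------------------------------------------
# Usage sketch (not a test): a minimal, hedged example of the helpers
# exercised above, round-tripping a CPU tensor between paddle and torch.
# It only runs when this file is executed directly and both frameworks are
# available; device and gradient semantics are covered by the tests above.
# ---------------------------------------------------------------------------
if __name__ == "__main__" and _NEED_IMPORT_TORCH and _NEED_IMPORT_PADDLE:
    paddle_tensor = paddle.rand((2, 3)).cpu()            # paddle.Tensor on CPU
    torch_tensor = paddle2torch(paddle_tensor)           # -> torch.Tensor
    restored = torch2paddle(torch_tensor, device="cpu")  # -> paddle.Tensor
    print(type(torch_tensor), type(restored))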