You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_mix_module.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. import unittest
  2. import os
  3. from itertools import chain
  4. import torch
  5. import paddle
  6. from paddle.io import Dataset, DataLoader
  7. import numpy as np
  8. from fastNLP.modules.mix_modules.mix_module import MixModule
  9. from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle
  10. from fastNLP.core import synchronize_safe_rm
############################################################################
#
# Tests for the basic functionality of the classes
#
############################################################################
  16. class TestMixModule(MixModule):
  17. def __init__(self):
  18. super(TestMixModule, self).__init__()
  19. self.torch_fc1 = torch.nn.Linear(10, 10)
  20. self.torch_softmax = torch.nn.Softmax(0)
  21. self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
  22. self.torch_tensor = torch.ones(3, 3)
  23. self.torch_param = torch.nn.Parameter(torch.ones(4, 4))
  24. self.paddle_fc1 = paddle.nn.Linear(10, 10)
  25. self.paddle_softmax = paddle.nn.Softmax(0)
  26. self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3)
  27. self.paddle_tensor = paddle.ones((4, 4))
  28. class TestTorchModule(torch.nn.Module):
  29. def __init__(self):
  30. super(TestTorchModule, self).__init__()
  31. self.torch_fc1 = torch.nn.Linear(10, 10)
  32. self.torch_softmax = torch.nn.Softmax(0)
  33. self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
  34. self.torch_tensor = torch.ones(3, 3)
  35. self.torch_param = torch.nn.Parameter(torch.ones(4, 4))
  36. class TestPaddleModule(paddle.nn.Layer):
  37. def __init__(self):
  38. super(TestPaddleModule, self).__init__()
  39. self.paddle_fc1 = paddle.nn.Linear(10, 10)
  40. self.paddle_softmax = paddle.nn.Softmax(0)
  41. self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3)
  42. self.paddle_tensor = paddle.ones((4, 4))
  43. class TorchPaddleMixModuleTestCase(unittest.TestCase):
  44. def setUp(self):
  45. self.model = TestMixModule()
  46. self.torch_model = TestTorchModule()
  47. self.paddle_model = TestPaddleModule()
  48. def test_to(self):
  49. """
  50. 测试混合模型的to函数
  51. """
  52. self.model.to("cuda")
  53. self.torch_model.to("cuda")
  54. self.paddle_model.to("gpu")
  55. self.if_device_correct("cuda")
  56. self.model.to("cuda:2")
  57. self.torch_model.to("cuda:2")
  58. self.paddle_model.to("gpu:2")
  59. self.if_device_correct("cuda:2")
  60. self.model.to("gpu:1")
  61. self.torch_model.to("cuda:1")
  62. self.paddle_model.to("gpu:1")
  63. self.if_device_correct("cuda:1")
  64. self.model.to("cpu")
  65. self.torch_model.to("cpu")
  66. self.paddle_model.to("cpu")
  67. self.if_device_correct("cpu")
  68. def test_train_eval(self):
  69. """
  70. 测试train和eval函数
  71. """
  72. self.model.eval()
  73. self.if_training_correct(False)
  74. self.model.train()
  75. self.if_training_correct(True)
  76. def test_parameters(self):
  77. """
  78. 测试parameters()函数,由于初始化是随机的,目前仅比较得到结果的长度
  79. """
  80. mix_params = []
  81. params = []
  82. for value in self.model.named_parameters():
  83. mix_params.append(value)
  84. for value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()):
  85. params.append(value)
  86. self.assertEqual(len(params), len(mix_params))
  87. def test_named_parameters(self):
  88. """
  89. 测试named_parameters函数
  90. """
  91. mix_param_names = []
  92. param_names = []
  93. for name, value in self.model.named_parameters():
  94. mix_param_names.append(name)
  95. for name, value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()):
  96. param_names.append(name)
  97. self.assertListEqual(sorted(param_names), sorted(mix_param_names))
  98. def test_torch_named_parameters(self):
  99. """
  100. 测试对torch参数的提取
  101. """
  102. mix_param_names = []
  103. param_names = []
  104. for name, value in self.model.named_parameters(backend="torch"):
  105. mix_param_names.append(name)
  106. for name, value in self.torch_model.named_parameters():
  107. param_names.append(name)
  108. self.assertListEqual(sorted(param_names), sorted(mix_param_names))
  109. def test_paddle_named_parameters(self):
  110. """
  111. 测试对paddle参数的提取
  112. """
  113. mix_param_names = []
  114. param_names = []
  115. for name, value in self.model.named_parameters(backend="paddle"):
  116. mix_param_names.append(name)
  117. for name, value in self.paddle_model.named_parameters():
  118. param_names.append(name)
  119. self.assertListEqual(sorted(param_names), sorted(mix_param_names))
  120. def test_torch_state_dict(self):
  121. """
  122. 测试提取torch的state dict
  123. """
  124. torch_dict = self.torch_model.state_dict()
  125. mix_dict = self.model.state_dict(backend="torch")
  126. self.assertListEqual(sorted(torch_dict.keys()), sorted(mix_dict.keys()))
  127. def test_paddle_state_dict(self):
  128. """
  129. 测试提取paddle的state dict
  130. """
  131. paddle_dict = self.paddle_model.state_dict()
  132. mix_dict = self.model.state_dict(backend="paddle")
  133. # TODO 测试程序会显示passed后显示paddle的异常退出信息
  134. self.assertListEqual(sorted(paddle_dict.keys()), sorted(mix_dict.keys()))
  135. def test_state_dict(self):
  136. """
  137. 测试提取所有的state dict
  138. """
  139. all_dict = self.torch_model.state_dict()
  140. all_dict.update(self.paddle_model.state_dict())
  141. mix_dict = self.model.state_dict()
  142. # TODO 测试程序会显示passed后显示paddle的异常退出信息
  143. self.assertListEqual(sorted(all_dict.keys()), sorted(mix_dict.keys()))
  144. def test_load_state_dict(self):
  145. """
  146. 测试load_state_dict函数
  147. """
  148. state_dict = self.model.state_dict()
  149. new_model = TestMixModule()
  150. new_model.load_state_dict(state_dict)
  151. new_state_dict = new_model.state_dict()
  152. for name, value in state_dict.items():
  153. state_dict[name] = value.tolist()
  154. for name, value in new_state_dict.items():
  155. new_state_dict[name] = value.tolist()
  156. self.assertDictEqual(state_dict, new_state_dict)
  157. def test_save_and_load_state_dict(self):
  158. """
  159. 测试save_state_dict_to_file和load_state_dict_from_file函数
  160. """
  161. path = "model"
  162. try:
  163. self.model.save_state_dict_to_file(path)
  164. new_model = TestMixModule()
  165. new_model.load_state_dict_from_file(path)
  166. state_dict = self.model.state_dict()
  167. new_state_dict = new_model.state_dict()
  168. for name, value in state_dict.items():
  169. state_dict[name] = value.tolist()
  170. for name, value in new_state_dict.items():
  171. new_state_dict[name] = value.tolist()
  172. self.assertDictEqual(state_dict, new_state_dict)
  173. finally:
  174. synchronize_safe_rm(path)
  175. def if_device_correct(self, device):
  176. self.assertEqual(self.model.torch_fc1.weight.device, self.torch_model.torch_fc1.weight.device)
  177. self.assertEqual(self.model.torch_conv2d1.weight.device, self.torch_model.torch_fc1.bias.device)
  178. self.assertEqual(self.model.torch_conv2d1.bias.device, self.torch_model.torch_conv2d1.bias.device)
  179. self.assertEqual(self.model.torch_tensor.device, self.torch_model.torch_tensor.device)
  180. self.assertEqual(self.model.torch_param.device, self.torch_model.torch_param.device)
  181. if device == "cpu":
  182. self.assertTrue(self.model.paddle_fc1.weight.place.is_cpu_place())
  183. self.assertTrue(self.model.paddle_fc1.bias.place.is_cpu_place())
  184. self.assertTrue(self.model.paddle_conv2d1.weight.place.is_cpu_place())
  185. self.assertTrue(self.model.paddle_conv2d1.bias.place.is_cpu_place())
  186. self.assertTrue(self.model.paddle_tensor.place.is_cpu_place())
  187. elif device.startswith("cuda"):
  188. self.assertTrue(self.model.paddle_fc1.weight.place.is_gpu_place())
  189. self.assertTrue(self.model.paddle_fc1.bias.place.is_gpu_place())
  190. self.assertTrue(self.model.paddle_conv2d1.weight.place.is_gpu_place())
  191. self.assertTrue(self.model.paddle_conv2d1.bias.place.is_gpu_place())
  192. self.assertTrue(self.model.paddle_tensor.place.is_gpu_place())
  193. self.assertEqual(self.model.paddle_fc1.weight.place.gpu_device_id(), self.paddle_model.paddle_fc1.weight.place.gpu_device_id())
  194. self.assertEqual(self.model.paddle_fc1.bias.place.gpu_device_id(), self.paddle_model.paddle_fc1.bias.place.gpu_device_id())
  195. self.assertEqual(self.model.paddle_conv2d1.weight.place.gpu_device_id(), self.paddle_model.paddle_conv2d1.weight.place.gpu_device_id())
  196. self.assertEqual(self.model.paddle_conv2d1.bias.place.gpu_device_id(), self.paddle_model.paddle_conv2d1.bias.place.gpu_device_id())
  197. self.assertEqual(self.model.paddle_tensor.place.gpu_device_id(), self.paddle_model.paddle_tensor.place.gpu_device_id())
  198. else:
  199. raise NotImplementedError
  200. def if_training_correct(self, training):
  201. self.assertEqual(self.model.torch_fc1.training, training)
  202. self.assertEqual(self.model.torch_softmax.training, training)
  203. self.assertEqual(self.model.torch_conv2d1.training, training)
  204. self.assertEqual(self.model.paddle_fc1.training, training)
  205. self.assertEqual(self.model.paddle_softmax.training, training)
  206. self.assertEqual(self.model.paddle_conv2d1.training, training)
############################################################################
#
# Tests for the performance on the MNIST dataset
#
############################################################################
  212. class MNISTDataset(Dataset):
  213. def __init__(self, dataset):
  214. self.dataset = [
  215. (
  216. np.array(img).astype('float32').reshape(-1),
  217. label
  218. ) for img, label in dataset
  219. ]
  220. def __getitem__(self, idx):
  221. return self.dataset[idx]
  222. def __len__(self):
  223. return len(self.dataset)
  224. class MixMNISTModel(MixModule):
  225. def __init__(self):
  226. super(MixMNISTModel, self).__init__()
  227. self.fc1 = paddle.nn.Linear(784, 64)
  228. self.fc2 = paddle.nn.Linear(64, 32)
  229. self.fc3 = torch.nn.Linear(32, 10)
  230. self.fc4 = torch.nn.Linear(10, 10)
  231. def forward(self, x):
  232. paddle_out = self.fc1(x)
  233. paddle_out = self.fc2(paddle_out)
  234. torch_in = paddle2torch(paddle_out)
  235. torch_out = self.fc3(torch_in)
  236. torch_out = self.fc4(torch_out)
  237. return torch_out
  238. class TestMNIST(unittest.TestCase):
  239. @classmethod
  240. def setUpClass(self):
  241. self.train_dataset = paddle.vision.datasets.MNIST(mode='train')
  242. self.test_dataset = paddle.vision.datasets.MNIST(mode='test')
  243. self.train_dataset = MNISTDataset(self.train_dataset)
  244. self.lr = 0.0003
  245. self.epochs = 20
  246. self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True)
  247. def setUp(self):
  248. self.model = MixMNISTModel().to("cuda")
  249. self.torch_loss_func = torch.nn.CrossEntropyLoss()
  250. self.torch_opt = torch.optim.Adam(self.model.parameters(backend="torch"), self.lr)
  251. self.paddle_opt = paddle.optimizer.Adam(parameters=self.model.parameters(backend="paddle"), learning_rate=self.lr)
  252. def test_case1(self):
  253. # 开始训练
  254. for epoch in range(self.epochs):
  255. epoch_loss, batch = 0, 0
  256. for batch, (img, label) in enumerate(self.dataloader):
  257. img = paddle.to_tensor(img).cuda()
  258. torch_out = self.model(img)
  259. label = torch.from_numpy(label.numpy()).reshape(-1)
  260. loss = self.torch_loss_func(torch_out.cpu(), label)
  261. epoch_loss += loss.item()
  262. loss.backward()
  263. self.torch_opt.step()
  264. self.paddle_opt.step()
  265. self.torch_opt.zero_grad()
  266. self.paddle_opt.clear_grad()
  267. else:
  268. self.assertLess(epoch_loss / (batch + 1), 0.3)
  269. # 开始测试
  270. correct = 0
  271. for img, label in self.test_dataset:
  272. img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1))
  273. torch_out = self.model(img)
  274. res = torch_out.softmax(-1).argmax().item()
  275. label = label.item()
  276. if res == label:
  277. correct += 1
  278. acc = correct / len(self.test_dataset)
  279. self.assertGreater(acc, 0.85)
############################################################################
#
# Tests for the performance on the ERNIE Chinese dataset
#
############################################################################