You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_test_mix_module.py 13 kB


  1. import pytest
  2. import os
  3. from itertools import chain
  4. import torch
  5. import paddle
  6. from paddle.io import Dataset, DataLoader
  7. import numpy as np
  8. from fastNLP.modules.mix_modules.mix_module import MixModule
  9. from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle
  10. from fastNLP.core import rank_zero_rm
############################################################################
#
# Test the basic functionality of the classes
#
############################################################################
  16. class MixModuleForTest(MixModule):
  17. def __init__(self):
  18. super(MixModuleForTest, self).__init__()
  19. self.torch_fc1 = torch.nn.Linear(10, 10)
  20. self.torch_softmax = torch.nn.Softmax(0)
  21. self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
  22. self.torch_tensor = torch.ones(3, 3)
  23. self.torch_param = torch.nn.Parameter(torch.ones(4, 4))
  24. self.paddle_fc1 = paddle.nn.Linear(10, 10)
  25. self.paddle_softmax = paddle.nn.Softmax(0)
  26. self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3)
  27. self.paddle_tensor = paddle.ones((4, 4))
  28. class TorchModuleForTest(torch.nn.Module):
  29. def __init__(self):
  30. super(TorchModuleForTest, self).__init__()
  31. self.torch_fc1 = torch.nn.Linear(10, 10)
  32. self.torch_softmax = torch.nn.Softmax(0)
  33. self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
  34. self.torch_tensor = torch.ones(3, 3)
  35. self.torch_param = torch.nn.Parameter(torch.ones(4, 4))
  36. class PaddleModuleForTest(paddle.nn.Layer):
  37. def __init__(self):
  38. super(PaddleModuleForTest, self).__init__()
  39. self.paddle_fc1 = paddle.nn.Linear(10, 10)
  40. self.paddle_softmax = paddle.nn.Softmax(0)
  41. self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3)
  42. self.paddle_tensor = paddle.ones((4, 4))
  43. @pytest.mark.torchpaddle
  44. class TestTorchPaddleMixModule:
  45. def setup_method(self):
  46. self.model = MixModuleForTest()
  47. self.torch_model = TorchModuleForTest()
  48. self.paddle_model = PaddleModuleForTest()
  49. def test_to(self):
  50. """
  51. 测试混合模型的to函数
  52. """
  53. self.model.to("cuda")
  54. self.torch_model.to("cuda")
  55. self.paddle_model.to("gpu")
  56. self.if_device_correct("cuda")
  57. self.model.to("cuda:2")
  58. self.torch_model.to("cuda:2")
  59. self.paddle_model.to("gpu:2")
  60. self.if_device_correct("cuda:2")
  61. self.model.to("gpu:1")
  62. self.torch_model.to("cuda:1")
  63. self.paddle_model.to("gpu:1")
  64. self.if_device_correct("cuda:1")
  65. self.model.to("cpu")
  66. self.torch_model.to("cpu")
  67. self.paddle_model.to("cpu")
  68. self.if_device_correct("cpu")
  69. def test_train_eval(self):
  70. """
  71. 测试train和eval函数
  72. """
  73. self.model.eval()
  74. self.if_training_correct(False)
  75. self.model.train()
  76. self.if_training_correct(True)
  77. def test_parameters(self):
  78. """
  79. 测试parameters()函数,由于初始化是随机的,目前仅比较得到结果的长度
  80. """
  81. mix_params = []
  82. params = []
  83. for value in self.model.named_parameters():
  84. mix_params.append(value)
  85. for value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()):
  86. params.append(value)
  87. assert len(params) == len(mix_params)
  88. def test_named_parameters(self):
  89. """
  90. 测试named_parameters函数
  91. """
  92. mix_param_names = []
  93. param_names = []
  94. for name, value in self.model.named_parameters():
  95. mix_param_names.append(name)
  96. for name, value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()):
  97. param_names.append(name)
  98. assert sorted(param_names) == sorted(mix_param_names)
  99. def test_torch_named_parameters(self):
  100. """
  101. 测试对torch参数的提取
  102. """
  103. mix_param_names = []
  104. param_names = []
  105. for name, value in self.model.named_parameters(backend="torch"):
  106. mix_param_names.append(name)
  107. for name, value in self.torch_model.named_parameters():
  108. param_names.append(name)
  109. assert sorted(param_names) == sorted(mix_param_names)
  110. def test_paddle_named_parameters(self):
  111. """
  112. 测试对paddle参数的提取
  113. """
  114. mix_param_names = []
  115. param_names = []
  116. for name, value in self.model.named_parameters(backend="paddle"):
  117. mix_param_names.append(name)
  118. for name, value in self.paddle_model.named_parameters():
  119. param_names.append(name)
  120. assert sorted(param_names) == sorted(mix_param_names)
  121. def test_torch_state_dict(self):
  122. """
  123. 测试提取torch的state dict
  124. """
  125. torch_dict = self.torch_model.state_dict()
  126. mix_dict = self.model.state_dict(backend="torch")
  127. assert sorted(torch_dict.keys()) == sorted(mix_dict.keys())
  128. def test_paddle_state_dict(self):
  129. """
  130. 测试提取paddle的state dict
  131. """
  132. paddle_dict = self.paddle_model.state_dict()
  133. mix_dict = self.model.state_dict(backend="paddle")
  134. # TODO 测试程序会显示passed后显示paddle的异常退出信息
  135. assert sorted(paddle_dict.keys()) == sorted(mix_dict.keys())
  136. def test_state_dict(self):
  137. """
  138. 测试提取所有的state dict
  139. """
  140. all_dict = self.torch_model.state_dict()
  141. all_dict.update(self.paddle_model.state_dict())
  142. mix_dict = self.model.state_dict()
  143. # TODO 测试程序会显示passed后显示paddle的异常退出信息
  144. assert sorted(all_dict.keys()) == sorted(mix_dict.keys())
  145. def test_load_state_dict(self):
  146. """
  147. 测试load_state_dict函数
  148. """
  149. state_dict = self.model.state_dict()
  150. new_model = MixModuleForTest()
  151. new_model.load_state_dict(state_dict)
  152. new_state_dict = new_model.state_dict()
  153. for name, value in state_dict.items():
  154. state_dict[name] = value.tolist()
  155. for name, value in new_state_dict.items():
  156. new_state_dict[name] = value.tolist()
  157. # self.assertDictEqual(state_dict, new_state_dict)
  158. def test_save_and_load_state_dict(self):
  159. """
  160. 测试save_state_dict_to_file和load_state_dict_from_file函数
  161. """
  162. path = "model"
  163. try:
  164. self.model.save_state_dict_to_file(path)
  165. new_model = MixModuleForTest()
  166. new_model.load_state_dict_from_file(path)
  167. state_dict = self.model.state_dict()
  168. new_state_dict = new_model.state_dict()
  169. for name, value in state_dict.items():
  170. state_dict[name] = value.tolist()
  171. for name, value in new_state_dict.items():
  172. new_state_dict[name] = value.tolist()
  173. # self.assertDictEqual(state_dict, new_state_dict)
  174. finally:
  175. rank_zero_rm(path)
  176. def if_device_correct(self, device):
  177. assert self.model.torch_fc1.weight.device == self.torch_model.torch_fc1.weight.device
  178. assert self.model.torch_conv2d1.weight.device == self.torch_model.torch_fc1.bias.device
  179. assert self.model.torch_conv2d1.bias.device == self.torch_model.torch_conv2d1.bias.device
  180. assert self.model.torch_tensor.device == self.torch_model.torch_tensor.device
  181. assert self.model.torch_param.device == self.torch_model.torch_param.device
  182. if device == "cpu":
  183. assert self.model.paddle_fc1.weight.place.is_cpu_place()
  184. assert self.model.paddle_fc1.bias.place.is_cpu_place()
  185. assert self.model.paddle_conv2d1.weight.place.is_cpu_place()
  186. assert self.model.paddle_conv2d1.bias.place.is_cpu_place()
  187. assert self.model.paddle_tensor.place.is_cpu_place()
  188. elif device.startswith("cuda"):
  189. assert self.model.paddle_fc1.weight.place.is_gpu_place()
  190. assert self.model.paddle_fc1.bias.place.is_gpu_place()
  191. assert self.model.paddle_conv2d1.weight.place.is_gpu_place()
  192. assert self.model.paddle_conv2d1.bias.place.is_gpu_place()
  193. assert self.model.paddle_tensor.place.is_gpu_place()
  194. assert self.model.paddle_fc1.weight.place.gpu_device_id() == self.paddle_model.paddle_fc1.weight.place.gpu_device_id()
  195. assert self.model.paddle_fc1.bias.place.gpu_device_id() == self.paddle_model.paddle_fc1.bias.place.gpu_device_id()
  196. assert self.model.paddle_conv2d1.weight.place.gpu_device_id() == self.paddle_model.paddle_conv2d1.weight.place.gpu_device_id()
  197. assert self.model.paddle_conv2d1.bias.place.gpu_device_id() == self.paddle_model.paddle_conv2d1.bias.place.gpu_device_id()
  198. assert self.model.paddle_tensor.place.gpu_device_id() == self.paddle_model.paddle_tensor.place.gpu_device_id()
  199. else:
  200. raise NotImplementedError
  201. def if_training_correct(self, training):
  202. assert self.model.torch_fc1.training == training
  203. assert self.model.torch_softmax.training == training
  204. assert self.model.torch_conv2d1.training == training
  205. assert self.model.paddle_fc1.training == training
  206. assert self.model.paddle_softmax.training == training
  207. assert self.model.paddle_conv2d1.training == training
############################################################################
#
# Test performance on the MNIST dataset
#
############################################################################
  213. class MNISTDataset(Dataset):
  214. def __init__(self, dataset):
  215. self.dataset = [
  216. (
  217. np.array(img).astype('float32').reshape(-1),
  218. label
  219. ) for img, label in dataset
  220. ]
  221. def __getitem__(self, idx):
  222. return self.dataset[idx]
  223. def __len__(self):
  224. return len(self.dataset)
  225. class MixMNISTModel(MixModule):
  226. def __init__(self):
  227. super(MixMNISTModel, self).__init__()
  228. self.fc1 = paddle.nn.Linear(784, 64)
  229. self.fc2 = paddle.nn.Linear(64, 32)
  230. self.fc3 = torch.nn.Linear(32, 10)
  231. self.fc4 = torch.nn.Linear(10, 10)
  232. def forward(self, x):
  233. paddle_out = self.fc1(x)
  234. paddle_out = self.fc2(paddle_out)
  235. torch_in = paddle2torch(paddle_out)
  236. torch_out = self.fc3(torch_in)
  237. torch_out = self.fc4(torch_out)
  238. return torch_out
  239. @pytest.mark.torchpaddle
  240. class TestMNIST:
  241. @classmethod
  242. def setup_class(self):
  243. self.train_dataset = paddle.vision.datasets.MNIST(mode='train')
  244. self.test_dataset = paddle.vision.datasets.MNIST(mode='test')
  245. self.train_dataset = MNISTDataset(self.train_dataset)
  246. self.lr = 0.0003
  247. self.epochs = 20
  248. self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True)
  249. def setup_method(self):
  250. self.model = MixMNISTModel().to("cuda")
  251. self.torch_loss_func = torch.nn.CrossEntropyLoss()
  252. self.torch_opt = torch.optim.Adam(self.model.parameters(backend="torch"), self.lr)
  253. self.paddle_opt = paddle.optimizer.Adam(parameters=self.model.parameters(backend="paddle"), learning_rate=self.lr)
  254. def test_case1(self):
  255. # 开始训练
  256. for epoch in range(self.epochs):
  257. epoch_loss, batch = 0, 0
  258. for batch, (img, label) in enumerate(self.dataloader):
  259. img = paddle.to_tensor(img).cuda()
  260. torch_out = self.model(img)
  261. label = torch.from_numpy(label.numpy()).reshape(-1)
  262. loss = self.torch_loss_func(torch_out.cpu(), label)
  263. epoch_loss += loss.item()
  264. loss.backward()
  265. self.torch_opt.step()
  266. self.paddle_opt.step()
  267. self.torch_opt.zero_grad()
  268. self.paddle_opt.clear_grad()
  269. else:
  270. assert epoch_loss / (batch + 1) < 0.3
  271. # 开始测试
  272. correct = 0
  273. for img, label in self.test_dataset:
  274. img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1))
  275. torch_out = self.model(img)
  276. res = torch_out.softmax(-1).argmax().item()
  277. label = label.item()
  278. if res == label:
  279. correct += 1
  280. acc = correct / len(self.test_dataset)
  281. assert acc > 0.85
############################################################################
#
# Test performance on the ERNIE Chinese dataset
#
############################################################################