You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_test_mix_module.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. import pytest
  2. import os
  3. from itertools import chain
  4. import torch
  5. import paddle
  6. from paddle.io import Dataset, DataLoader
  7. import numpy as np
  8. from fastNLP.modules.mix_modules.mix_module import MixModule
  9. from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle
  10. from fastNLP.core import rank_zero_rm
  11. ############################################################################
  12. #
  13. # 测试类的基本功能
  14. #
  15. ############################################################################
  16. class MixModuleForTest(MixModule):
  17. def __init__(self):
  18. super(MixModuleForTest, self).__init__()
  19. self.torch_fc1 = torch.nn.Linear(10, 10)
  20. self.torch_softmax = torch.nn.Softmax(0)
  21. self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
  22. self.torch_tensor = torch.ones(3, 3)
  23. self.torch_param = torch.nn.Parameter(torch.ones(4, 4))
  24. self.paddle_fc1 = paddle.nn.Linear(10, 10)
  25. self.paddle_softmax = paddle.nn.Softmax(0)
  26. self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3)
  27. self.paddle_tensor = paddle.ones((4, 4))
  28. class TorchModuleForTest(torch.nn.Module):
  29. def __init__(self):
  30. super(TorchModuleForTest, self).__init__()
  31. self.torch_fc1 = torch.nn.Linear(10, 10)
  32. self.torch_softmax = torch.nn.Softmax(0)
  33. self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
  34. self.torch_tensor = torch.ones(3, 3)
  35. self.torch_param = torch.nn.Parameter(torch.ones(4, 4))
  36. class PaddleModuleForTest(paddle.nn.Layer):
  37. def __init__(self):
  38. super(PaddleModuleForTest, self).__init__()
  39. self.paddle_fc1 = paddle.nn.Linear(10, 10)
  40. self.paddle_softmax = paddle.nn.Softmax(0)
  41. self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3)
  42. self.paddle_tensor = paddle.ones((4, 4))
  43. @pytest.mark.torchpaddle
  44. class TestTorchPaddleMixModule:
  45. def setup_method(self):
  46. self.model = MixModuleForTest()
  47. self.torch_model = TorchModuleForTest()
  48. self.paddle_model = PaddleModuleForTest()
  49. def test_to(self):
  50. """
  51. 测试混合模型的to函数
  52. """
  53. self.model.to("cuda")
  54. self.torch_model.to("cuda")
  55. self.paddle_model.to("gpu")
  56. self.if_device_correct("cuda")
  57. self.model.to("cuda:2")
  58. self.torch_model.to("cuda:2")
  59. self.paddle_model.to("gpu:2")
  60. self.if_device_correct("cuda:2")
  61. self.model.to("gpu:1")
  62. self.torch_model.to("cuda:1")
  63. self.paddle_model.to("gpu:1")
  64. self.if_device_correct("cuda:1")
  65. self.model.to("cpu")
  66. self.torch_model.to("cpu")
  67. self.paddle_model.to("cpu")
  68. self.if_device_correct("cpu")
  69. def test_train_eval(self):
  70. """
  71. 测试train和eval函数
  72. """
  73. self.model.eval()
  74. self.if_training_correct(False)
  75. self.model.train()
  76. self.if_training_correct(True)
  77. def test_parameters(self):
  78. """
  79. 测试parameters()函数,由于初始化是随机的,目前仅比较得到结果的长度
  80. """
  81. mix_params = []
  82. params = []
  83. for value in self.model.named_parameters():
  84. mix_params.append(value)
  85. for value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()):
  86. params.append(value)
  87. assert len(params) == len(mix_params)
  88. def test_named_parameters(self):
  89. """
  90. 测试named_parameters函数
  91. """
  92. mix_param_names = []
  93. param_names = []
  94. for name, value in self.model.named_parameters():
  95. mix_param_names.append(name)
  96. for name, value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()):
  97. param_names.append(name)
  98. assert sorted(param_names) == sorted(mix_param_names)
  99. def test_torch_named_parameters(self):
  100. """
  101. 测试对torch参数的提取
  102. """
  103. mix_param_names = []
  104. param_names = []
  105. for name, value in self.model.named_parameters(backend="torch"):
  106. mix_param_names.append(name)
  107. for name, value in self.torch_model.named_parameters():
  108. param_names.append(name)
  109. assert sorted(param_names) == sorted(mix_param_names)
  110. def test_paddle_named_parameters(self):
  111. """
  112. 测试对paddle参数的提取
  113. """
  114. mix_param_names = []
  115. param_names = []
  116. for name, value in self.model.named_parameters(backend="paddle"):
  117. mix_param_names.append(name)
  118. for name, value in self.paddle_model.named_parameters():
  119. param_names.append(name)
  120. assert sorted(param_names) == sorted(mix_param_names)
  121. def test_torch_state_dict(self):
  122. """
  123. 测试提取torch的state dict
  124. """
  125. torch_dict = self.torch_model.state_dict()
  126. mix_dict = self.model.state_dict(backend="torch")
  127. assert sorted(torch_dict.keys()) == sorted(mix_dict.keys())
  128. def test_paddle_state_dict(self):
  129. """
  130. 测试提取paddle的state dict
  131. """
  132. paddle_dict = self.paddle_model.state_dict()
  133. mix_dict = self.model.state_dict(backend="paddle")
  134. # TODO 测试程序会显示passed后显示paddle的异常退出信息
  135. assert sorted(paddle_dict.keys()) == sorted(mix_dict.keys())
  136. def test_state_dict(self):
  137. """
  138. 测试提取所有的state dict
  139. """
  140. all_dict = self.torch_model.state_dict()
  141. all_dict.update(self.paddle_model.state_dict())
  142. mix_dict = self.model.state_dict()
  143. # TODO 测试程序会显示passed后显示paddle的异常退出信息
  144. assert sorted(all_dict.keys()) == sorted(mix_dict.keys())
  145. def test_load_state_dict(self):
  146. """
  147. 测试load_state_dict函数
  148. """
  149. state_dict = self.model.state_dict()
  150. new_model = MixModuleForTest()
  151. new_model.load_state_dict(state_dict)
  152. new_state_dict = new_model.state_dict()
  153. for name, value in state_dict.items():
  154. state_dict[name] = value.tolist()
  155. for name, value in new_state_dict.items():
  156. new_state_dict[name] = value.tolist()
  157. # self.assertDictEqual(state_dict, new_state_dict)
  158. def test_save_and_load_state_dict(self):
  159. """
  160. 测试save_state_dict_to_file和load_state_dict_from_file函数
  161. """
  162. path = "model"
  163. try:
  164. self.model.save_state_dict_to_file(path)
  165. new_model = MixModuleForTest()
  166. new_model.load_state_dict_from_file(path)
  167. state_dict = self.model.state_dict()
  168. new_state_dict = new_model.state_dict()
  169. for name, value in state_dict.items():
  170. state_dict[name] = value.tolist()
  171. for name, value in new_state_dict.items():
  172. new_state_dict[name] = value.tolist()
  173. # self.assertDictEqual(state_dict, new_state_dict)
  174. finally:
  175. rank_zero_rm(path)
  176. def if_device_correct(self, device):
  177. assert self.model.torch_fc1.weight.device == self.torch_model.torch_fc1.weight.device
  178. assert self.model.torch_conv2d1.weight.device == self.torch_model.torch_fc1.bias.device
  179. assert self.model.torch_conv2d1.bias.device == self.torch_model.torch_conv2d1.bias.device
  180. assert self.model.torch_tensor.device == self.torch_model.torch_tensor.device
  181. assert self.model.torch_param.device == self.torch_model.torch_param.device
  182. if device == "cpu":
  183. assert self.model.paddle_fc1.weight.place.is_cpu_place()
  184. assert self.model.paddle_fc1.bias.place.is_cpu_place()
  185. assert self.model.paddle_conv2d1.weight.place.is_cpu_place()
  186. assert self.model.paddle_conv2d1.bias.place.is_cpu_place()
  187. assert self.model.paddle_tensor.place.is_cpu_place()
  188. elif device.startswith("cuda"):
  189. assert self.model.paddle_fc1.weight.place.is_gpu_place()
  190. assert self.model.paddle_fc1.bias.place.is_gpu_place()
  191. assert self.model.paddle_conv2d1.weight.place.is_gpu_place()
  192. assert self.model.paddle_conv2d1.bias.place.is_gpu_place()
  193. assert self.model.paddle_tensor.place.is_gpu_place()
  194. assert self.model.paddle_fc1.weight.place.gpu_device_id() == self.paddle_model.paddle_fc1.weight.place.gpu_device_id()
  195. assert self.model.paddle_fc1.bias.place.gpu_device_id() == self.paddle_model.paddle_fc1.bias.place.gpu_device_id()
  196. assert self.model.paddle_conv2d1.weight.place.gpu_device_id() == self.paddle_model.paddle_conv2d1.weight.place.gpu_device_id()
  197. assert self.model.paddle_conv2d1.bias.place.gpu_device_id() == self.paddle_model.paddle_conv2d1.bias.place.gpu_device_id()
  198. assert self.model.paddle_tensor.place.gpu_device_id() == self.paddle_model.paddle_tensor.place.gpu_device_id()
  199. else:
  200. raise NotImplementedError
  201. def if_training_correct(self, training):
  202. assert self.model.torch_fc1.training == training
  203. assert self.model.torch_softmax.training == training
  204. assert self.model.torch_conv2d1.training == training
  205. assert self.model.paddle_fc1.training == training
  206. assert self.model.paddle_softmax.training == training
  207. assert self.model.paddle_conv2d1.training == training
  208. ############################################################################
  209. #
  210. # 测试在MNIST数据集上的表现
  211. #
  212. ############################################################################
  213. class MNISTDataset(Dataset):
  214. def __init__(self, dataset):
  215. self.dataset = [
  216. (
  217. np.array(img).astype('float32').reshape(-1),
  218. label
  219. ) for img, label in dataset
  220. ]
  221. def __getitem__(self, idx):
  222. return self.dataset[idx]
  223. def __len__(self):
  224. return len(self.dataset)
  225. class MixMNISTModel(MixModule):
  226. def __init__(self):
  227. super(MixMNISTModel, self).__init__()
  228. self.fc1 = paddle.nn.Linear(784, 64)
  229. self.fc2 = paddle.nn.Linear(64, 32)
  230. self.fc3 = torch.nn.Linear(32, 10)
  231. self.fc4 = torch.nn.Linear(10, 10)
  232. def forward(self, x):
  233. paddle_out = self.fc1(x)
  234. paddle_out = self.fc2(paddle_out)
  235. torch_in = paddle2torch(paddle_out)
  236. torch_out = self.fc3(torch_in)
  237. torch_out = self.fc4(torch_out)
  238. return torch_out
  239. @pytest.mark.torchpaddle
  240. class TestMNIST:
  241. @classmethod
  242. def setup_class(self):
  243. self.train_dataset = paddle.vision.datasets.MNIST(mode='train')
  244. self.test_dataset = paddle.vision.datasets.MNIST(mode='test')
  245. self.train_dataset = MNISTDataset(self.train_dataset)
  246. self.lr = 0.0003
  247. self.epochs = 20
  248. self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True)
  249. def setup_method(self):
  250. self.model = MixMNISTModel().to("cuda")
  251. self.torch_loss_func = torch.nn.CrossEntropyLoss()
  252. self.torch_opt = torch.optim.Adam(self.model.parameters(backend="torch"), self.lr)
  253. self.paddle_opt = paddle.optimizer.Adam(parameters=self.model.parameters(backend="paddle"), learning_rate=self.lr)
  254. def test_case1(self):
  255. # 开始训练
  256. for epoch in range(self.epochs):
  257. epoch_loss, batch = 0, 0
  258. for batch, (img, label) in enumerate(self.dataloader):
  259. img = paddle.to_tensor(img).cuda()
  260. torch_out = self.model(img)
  261. label = torch.from_numpy(label.numpy()).reshape(-1)
  262. loss = self.torch_loss_func(torch_out.cpu(), label)
  263. epoch_loss += loss.item()
  264. loss.backward()
  265. self.torch_opt.step()
  266. self.paddle_opt.step()
  267. self.torch_opt.zero_grad()
  268. self.paddle_opt.clear_grad()
  269. else:
  270. assert epoch_loss / (batch + 1) < 0.3
  271. # 开始测试
  272. correct = 0
  273. for img, label in self.test_dataset:
  274. img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1))
  275. torch_out = self.model(img)
  276. res = torch_out.softmax(-1).argmax().item()
  277. label = label.item()
  278. if res == label:
  279. correct += 1
  280. acc = correct / len(self.test_dataset)
  281. assert acc > 0.85
  282. ############################################################################
  283. #
  284. # 测试在ERNIE中文数据集上的表现
  285. #
  286. ############################################################################