test_fp16.py

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
import torch.nn as nn

from mmcv.runner import auto_fp16, force_fp32
from mmcv.runner.fp16_utils import cast_tensor_type


def test_cast_tensor_type():
    inputs = torch.FloatTensor([5.])
    src_type = torch.float32
    dst_type = torch.int32
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, torch.Tensor)
    assert outputs.dtype == dst_type

    inputs = 'tensor'
    src_type = str
    dst_type = str
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, str)

    inputs = np.array([5.])
    src_type = np.ndarray
    dst_type = np.ndarray
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, np.ndarray)

    inputs = dict(
        tensor_a=torch.FloatTensor([1.]), tensor_b=torch.FloatTensor([2.]))
    src_type = torch.float32
    dst_type = torch.int32
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, dict)
    assert outputs['tensor_a'].dtype == dst_type
    assert outputs['tensor_b'].dtype == dst_type

    inputs = [torch.FloatTensor([1.]), torch.FloatTensor([2.])]
    src_type = torch.float32
    dst_type = torch.int32
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, list)
    assert outputs[0].dtype == dst_type
    assert outputs[1].dtype == dst_type

    inputs = 5
    outputs = cast_tensor_type(inputs, None, None)
    assert isinstance(outputs, int)
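

# The cases above exercise each container type separately. An illustrative
# addition (not in the original file): `cast_tensor_type` also recurses
# through nested containers, assuming the recursive dict/iterable handling
# of mmcv's implementation.
def test_cast_tensor_type_nested():
    inputs = dict(pair=[torch.FloatTensor([1.]), torch.FloatTensor([2.])])
    outputs = cast_tensor_type(inputs, torch.float32, torch.int32)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['pair'], list)
    # every leaf tensor is converted element-wise
    assert all(t.dtype == torch.int32 for t in outputs['pair'])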


def test_auto_fp16():
    with pytest.raises(TypeError):
        # ExampleObject is not a subclass of nn.Module
        class ExampleObject:

            @auto_fp16()
            def __call__(self, x):
                return x

        model = ExampleObject()
        input_x = torch.ones(1, dtype=torch.float32)
        model(input_x)

    # apply to all input args
    class ExampleModule(nn.Module):

        @auto_fp16()
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)

    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half

    # apply to specified input args
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', ))
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)

    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.float32

    # apply to optional input args
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', 'y'))
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)
    input_z = torch.ones(1, dtype=torch.float32)

    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half
        assert output_z.dtype == torch.float32

    # out_fp32=True
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', 'y'), out_fp32=True)
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.float32)
    input_z = torch.ones(1, dtype=torch.float32)

    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32
        assert output_z.dtype == torch.float32
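

# Illustrative sketch (not part of the original tests): `out_fp32=True` is
# typically used when a half-precision forward feeds a numerically sensitive
# op such as a loss. `ToyHead` is a hypothetical module relying only on the
# decorator behaviour verified above.
def test_auto_fp16_out_fp32_pattern():
    class ToyHead(nn.Module):

        @auto_fp16(apply_to=('feats', ), out_fp32=True)
        def forward(self, feats):
            # computed in fp16 when fp16_enabled is True ...
            return feats * 2

    head = ToyHead()
    head.fp16_enabled = True
    out = head(torch.ones(1, dtype=torch.float32))
    # ... but cast back to fp32 before returning
    assert out.dtype == torch.float32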


def test_force_fp32():
    with pytest.raises(TypeError):
        # ExampleObject is not a subclass of nn.Module
        class ExampleObject:

            @force_fp32()
            def __call__(self, x):
                return x

        model = ExampleObject()
        input_x = torch.ones(1, dtype=torch.float32)
        model(input_x)

    # apply to all input args
    class ExampleModule(nn.Module):

        @force_fp32()
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)

    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32

    # apply to specified input args
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', ))
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)

    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.half

    # apply to optional input args
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', 'y'))
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)
    input_z = torch.ones(1, dtype=torch.half)

    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32
        assert output_z.dtype == torch.half

    # out_fp16=True
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', 'y'), out_fp16=True)
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.half)
    input_z = torch.ones(1, dtype=torch.half)

    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half
        assert output_z.dtype == torch.half
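

# Illustrative sketch (not part of the original tests): in real MMCV training
# the `fp16_enabled` flag is normally set by `wrap_fp16_model` rather than
# assigned by hand as above. A minimal example, assuming the standard
# `mmcv.runner.wrap_fp16_model` API; `ToyModule` is hypothetical.
def test_wrap_fp16_model_sets_flag():
    from mmcv.runner import wrap_fp16_model

    class ToyModule(nn.Module):

        def __init__(self):
            super().__init__()
            # modules opt in by defining the flag; wrap_fp16_model flips it
            self.fp16_enabled = False

        @auto_fp16(apply_to=('x', ))
        def forward(self, x):
            return x

    model = ToyModule()
    wrap_fp16_model(model)
    assert model.fp16_enabled
    # inputs are now cast to half by the decorator
    assert model(torch.ones(1, dtype=torch.float32)).dtype == torch.half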
