You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_callback_events.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import pytest
  2. from functools import reduce
  3. from fastNLP.core.callbacks.callback_events import Events, Filter
  class TestFilter:
      """Tests for the ``Filter`` decorator.

      Covers constructor argument validation, the ``every`` and ``once``
      policies, access to the filter object attached to a decorated function,
      and a ``state_dict`` / ``load_state_dict`` round-trip.
      """

      @staticmethod
      def _collect(fn, inputs):
          """Call ``fn`` on each input in order and return the results that are not ``None``.

          The tests treat a ``None`` result as "the call was filtered out".
          """
          outputs = (fn(item) for item in inputs)
          return [out for out in outputs if out is not None]

      def test_params_check(self):
          """Setting exactly one option succeeds; every other combination raises."""
          # Each of these sets exactly one option and must construct cleanly.
          Filter(every=10)
          Filter(once=10)
          Filter(filter_fn=lambda: None)

          # (constructor kwargs, expected exception type, expected first exception argument)
          invalid_cases = [
              # No option at all.
              ({}, ValueError, "If you mean your decorated function should be called every time, you do not need this filter."),
              # More than one option at once.
              ({"every": 10, "once": 10}, ValueError, "These three values should be only set one."),
              # A non-integer ``every`` / ``once`` is expected as ValueError (not TypeError).
              ({"every": "heihei"}, ValueError, "Argument every should be integer and greater than zero"),
              ({"once": "heihei"}, ValueError, "Argument once should be integer and positive"),
              # Only a non-callable ``filter_fn`` is expected as TypeError.
              ({"filter_fn": "heihei"}, TypeError, "Argument event_filter should be a callable"),
          ]
          for kwargs, expected_exc, expected_msg in invalid_cases:
              with pytest.raises(expected_exc) as exc_info:
                  Filter(**kwargs)
              # The captured exception is only available once the ``with`` block has exited.
              assert exc_info.value.args[0] == expected_msg

      def test_every_filter(self):
          """``every=k`` lets every k-th call through (1-based call count)."""
          @Filter(every=10)
          def _every_tenth(data):
              return data

          # Calls number 10, 20, ..., 100 carry the inputs 9, 19, ..., 99.
          assert self._collect(_every_tenth, range(100)) == list(range(9, 100, 10))

          @Filter(every=1)
          def _every_call(data):
              return data

          assert self._collect(_every_call, range(100)) == list(range(100))

      def test_once_filter(self):
          """``once=10`` lets only the 10th call through."""
          @Filter(once=10)
          def _fn(data):
              return data

          assert self._collect(_fn, range(100)) == [9]

      def test_extract_filter_from_fn(self):
          """A decorated function exposes its filter and the filter's counters."""
          @Filter(every=10)
          def _fn(data):
              return data

          num_called_history = []
          num_executed_history = []
          for i in range(100):
              _fn(i)
              attached = _fn.__fastNLP_filter__
              num_called_history.append(attached.num_called)
              num_executed_history.append(attached.num_executed)

          assert num_called_history == list(range(1, 101))
          # Nine calls with 0 executions, then ten calls each at 1..9 executions,
          # and finally the 100th call, which brings the count to 10.
          expected_executed = [0] * 9 + [w for w in range(1, 10) for _ in range(10)] + [10]
          assert num_executed_history == expected_executed

          # An undecorated function must not carry the filter attribute.
          def _plain(data):
              return data

          assert not hasattr(_plain, "__fastNLP_filter__")

      def test_filter_state_dict(self):
          """Counting continues correctly after saving and re-loading the filter state."""
          @Filter(every=10)
          def _fn(data):
              return data

          assert self._collect(_fn, range(50)) == list(range(9, 50, 10))

          # Save the state, then load it back into the same filter.
          attached = _fn.__fastNLP_filter__
          state = attached.state_dict()
          attached.load_state_dict(state)

          assert self._collect(_fn, range(50, 100)) == list(range(59, 100, 10))
  @pytest.mark.torch
  def test_filter_fn_torch():
      """A custom ``filter_fn`` receives the trainer and decides which calls run.

      The predicate here accepts a call only while ``trainer.__heihei_test__``
      equals 10, so out of 100 calls exactly one result is expected.
      """
      from torch.optim import SGD
      from torch.utils.data import DataLoader
      from fastNLP.core.controllers.trainer import Trainer
      from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
      from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

      # Build a minimal CPU trainer; it is only used as the object handed to the filter.
      net = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
      sgd = SGD(net.parameters(), lr=0.0001)
      data = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
      loader = DataLoader(dataset=data, batch_size=4)
      trainer = Trainer(model=net, driver="torch", device="cpu", train_dataloader=loader, optimizers=sgd)

      def filter_fn(filter, trainer):
          # Accept only while the marker attribute on the trainer is exactly 10.
          return True if trainer.__heihei_test__ == 10 else False

      @Filter(filter_fn=filter_fn)
      def _fn(trainer, data):
          return data

      passed = []
      for step in range(100):
          trainer.__heihei_test__ = step
          outcome = _fn(trainer, step)
          if outcome is None:
              continue
          passed.append(outcome)

      assert passed == [10]
  class TestCallbackEvents:
      """Tests that the options carried by an ``Events`` state drive ``Filter`` as expected."""

      def test_every(self):
          """Check the ``every`` option, both when omitted and when set to 10."""
          # Which event is used does not matter here, because this is tested
          # separately from the Trainer.
          # With no arguments the test expects the behaviour of every=1.
          default_state = Events.on_train_begin()

          @Filter(every=default_state.every, once=default_state.once, filter_fn=default_state.filter_fn)
          def _always(data):
              return data

          kept = [out for out in map(_always, range(100)) if out is not None]
          assert kept == list(range(100))

          tenth_state = Events.on_train_begin(every=10)

          @Filter(every=tenth_state.every, once=tenth_state.once, filter_fn=tenth_state.filter_fn)
          def _tenth(data):
              return data

          kept = [out for out in map(_tenth, range(100)) if out is not None]
          # Calls number 10, 20, ..., 100 carry the inputs 9, 19, ..., 99.
          assert kept == list(range(9, 100, 10))

      def test_once(self):
          """Check that ``once=10`` lets only the 10th call through."""
          once_state = Events.on_train_begin(once=10)

          @Filter(once=once_state.once)
          def _fn(data):
              return data

          kept = [out for out in map(_fn, range(100)) if out is not None]
          assert kept == [9]
  @pytest.mark.torch
  def test_callback_events_torch():
      """A ``filter_fn`` passed through ``Events`` is usable as the ``Filter`` predicate.

      The predicate accepts a call only while ``trainer.__heihei_test__`` equals
      10, so out of 100 calls exactly one result is expected.
      """
      from torch.optim import SGD
      from torch.utils.data import DataLoader
      from fastNLP.core.controllers.trainer import Trainer
      from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
      from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

      # Build a minimal CPU trainer; it is only used as the object handed to the filter.
      net = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
      sgd = SGD(net.parameters(), lr=0.0001)
      data = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
      loader = DataLoader(dataset=data, batch_size=4)
      trainer = Trainer(model=net, driver="torch", device="cpu", train_dataloader=loader, optimizers=sgd)

      def filter_fn(filter, trainer):
          # Accept only while the marker attribute on the trainer is exactly 10.
          return True if trainer.__heihei_test__ == 10 else False

      # Route the predicate through an event state before giving it to Filter.
      event_state = Events.on_train_begin(filter_fn=filter_fn)

      @Filter(filter_fn=event_state.filter_fn)
      def _fn(trainer, data):
          return data

      passed = []
      for step in range(100):
          trainer.__heihei_test__ = step
          outcome = _fn(trainer, step)
          if outcome is None:
              continue
          passed.append(outcome)

      assert passed == [10]