You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_callback_events.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import pytest
  2. from functools import reduce
  3. from fastNLP.core.callbacks.callback_events import Events, Filter
  4. class TestFilter:
  5. def test_params_check(self):
  6. # 顺利通过
  7. _filter1 = Filter(every=10)
  8. _filter2 = Filter(once=10)
  9. _filter3 = Filter(filter_fn=lambda: None)
  10. # 触发 ValueError
  11. with pytest.raises(ValueError) as e:
  12. _filter4 = Filter()
  13. exec_msg = e.value.args[0]
  14. assert exec_msg == "If you mean your decorated function should be called every time, you do not need this filter."
  15. # 触发 ValueError
  16. with pytest.raises(ValueError) as e:
  17. _filter5 = Filter(every=10, once=10)
  18. exec_msg = e.value.args[0]
  19. assert exec_msg == "These three values should be only set one."
  20. # 触发 TypeError
  21. with pytest.raises(ValueError) as e:
  22. _filter6 = Filter(every="heihei")
  23. exec_msg = e.value.args[0]
  24. assert exec_msg == "Argument every should be integer and greater than zero"
  25. # 触发 TypeError
  26. with pytest.raises(ValueError) as e:
  27. _filter7 = Filter(once="heihei")
  28. exec_msg = e.value.args[0]
  29. assert exec_msg == "Argument once should be integer and positive"
  30. # 触发 TypeError
  31. with pytest.raises(TypeError) as e:
  32. _filter7 = Filter(filter_fn="heihei")
  33. exec_msg = e.value.args[0]
  34. assert exec_msg == "Argument event_filter should be a callable"
  35. def test_every_filter(self):
  36. # every = 10
  37. @Filter(every=10)
  38. def _fn(data):
  39. return data
  40. _res = []
  41. for i in range(100):
  42. cu_res = _fn(i)
  43. if cu_res is not None:
  44. _res.append(cu_res)
  45. assert _res == [w-1 for w in range(10, 101, 10)]
  46. # every = 1
  47. @Filter(every=1)
  48. def _fn(data):
  49. return data
  50. _res = []
  51. for i in range(100):
  52. cu_res = _fn(i)
  53. if cu_res is not None:
  54. _res.append(cu_res)
  55. assert _res == list(range(100))
  56. def test_once_filter(self):
  57. # once = 10
  58. @Filter(once=10)
  59. def _fn(data):
  60. return data
  61. _res = []
  62. for i in range(100):
  63. cu_res = _fn(i)
  64. if cu_res is not None:
  65. _res.append(cu_res)
  66. assert _res == [9]
  67. def test_filter_fn(self):
  68. from torch.optim import SGD
  69. from torch.utils.data import DataLoader
  70. from fastNLP.core.controllers.trainer import Trainer
  71. from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
  72. from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
  73. model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
  74. optimizer = SGD(model.parameters(), lr=0.0001)
  75. dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
  76. dataloader = DataLoader(dataset=dataset, batch_size=4)
  77. trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
  78. def filter_fn(filter, trainer):
  79. if trainer.__heihei_test__ == 10:
  80. return True
  81. return False
  82. @Filter(filter_fn=filter_fn)
  83. def _fn(trainer, data):
  84. return data
  85. _res = []
  86. for i in range(100):
  87. trainer.__heihei_test__ = i
  88. cu_res = _fn(trainer, i)
  89. if cu_res is not None:
  90. _res.append(cu_res)
  91. assert _res == [10]
  92. def test_extract_filter_from_fn(self):
  93. @Filter(every=10)
  94. def _fn(data):
  95. return data
  96. _filter_num_called = []
  97. _filter_num_executed = []
  98. for i in range(100):
  99. cu_res = _fn(i)
  100. _filter = _fn.__fastNLP_filter__
  101. _filter_num_called.append(_filter.num_called)
  102. _filter_num_executed.append(_filter.num_executed)
  103. assert _filter_num_called == list(range(1, 101))
  104. assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10]
  105. def _fn(data):
  106. return data
  107. assert not hasattr(_fn, "__fastNLP_filter__")
  108. def test_filter_state_dict(self):
  109. # every = 10
  110. @Filter(every=10)
  111. def _fn(data):
  112. return data
  113. _res = []
  114. for i in range(50):
  115. cu_res = _fn(i)
  116. if cu_res is not None:
  117. _res.append(cu_res)
  118. assert _res == [w - 1 for w in range(10, 51, 10)]
  119. # 保存状态
  120. state = _fn.__fastNLP_filter__.state_dict()
  121. # 加载状态
  122. _fn.__fastNLP_filter__.load_state_dict(state)
  123. _res = []
  124. for i in range(50, 100):
  125. cu_res = _fn(i)
  126. if cu_res is not None:
  127. _res.append(cu_res)
  128. assert _res == [w - 1 for w in range(60, 101, 10)]