You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_callbacks.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import unittest
  2. import numpy as np
  3. import torch
  4. from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \
  5. LRFinder, TensorboardCallback
  6. from fastNLP import DataSet
  7. from fastNLP import Instance
  8. from fastNLP import BCELoss
  9. from fastNLP import AccuracyMetric
  10. from fastNLP import SGD
  11. from fastNLP import Trainer
  12. from fastNLP.models.base_model import NaiveClassifier
  13. def prepare_env():
  14. def prepare_fake_dataset():
  15. mean = np.array([-3, -3])
  16. cov = np.array([[1, 0], [0, 1]])
  17. class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
  18. mean = np.array([3, 3])
  19. cov = np.array([[1, 0], [0, 1]])
  20. class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
  21. data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
  22. [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
  23. return data_set
  24. data_set = prepare_fake_dataset()
  25. data_set.set_input("x")
  26. data_set.set_target("y")
  27. model = NaiveClassifier(2, 1)
  28. return data_set, model
  29. class TestCallback(unittest.TestCase):
  30. def test_gradient_clip(self):
  31. data_set, model = prepare_env()
  32. trainer = Trainer(data_set, model,
  33. loss=BCELoss(pred="predict", target="y"),
  34. n_epochs=20,
  35. batch_size=32,
  36. print_every=50,
  37. optimizer=SGD(lr=0.1),
  38. check_code_level=2,
  39. use_tqdm=False,
  40. dev_data=data_set,
  41. metrics=AccuracyMetric(pred="predict", target="y"),
  42. callbacks=[GradientClipCallback(model.parameters(), clip_value=2)])
  43. trainer.train()
  44. def test_early_stop(self):
  45. data_set, model = prepare_env()
  46. trainer = Trainer(data_set, model,
  47. loss=BCELoss(pred="predict", target="y"),
  48. n_epochs=20,
  49. batch_size=32,
  50. print_every=50,
  51. optimizer=SGD(lr=0.01),
  52. check_code_level=2,
  53. use_tqdm=False,
  54. dev_data=data_set,
  55. metrics=AccuracyMetric(pred="predict", target="y"),
  56. callbacks=[EarlyStopCallback(5)])
  57. trainer.train()
  58. def test_lr_scheduler(self):
  59. data_set, model = prepare_env()
  60. optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  61. trainer = Trainer(data_set, model,
  62. loss=BCELoss(pred="predict", target="y"),
  63. n_epochs=5,
  64. batch_size=32,
  65. print_every=50,
  66. optimizer=optimizer,
  67. check_code_level=2,
  68. use_tqdm=False,
  69. dev_data=data_set,
  70. metrics=AccuracyMetric(pred="predict", target="y"),
  71. callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))])
  72. trainer.train()
  73. def test_KeyBoardInterrupt(self):
  74. data_set, model = prepare_env()
  75. trainer = Trainer(data_set, model,
  76. loss=BCELoss(pred="predict", target="y"),
  77. n_epochs=5,
  78. batch_size=32,
  79. print_every=50,
  80. optimizer=SGD(lr=0.1),
  81. check_code_level=2,
  82. use_tqdm=False,
  83. callbacks=[ControlC(False)])
  84. trainer.train()
  85. def test_LRFinder(self):
  86. data_set, model = prepare_env()
  87. trainer = Trainer(data_set, model,
  88. loss=BCELoss(pred="predict", target="y"),
  89. n_epochs=5,
  90. batch_size=32,
  91. print_every=50,
  92. optimizer=SGD(lr=0.1),
  93. check_code_level=2,
  94. use_tqdm=False,
  95. callbacks=[LRFinder(len(data_set) // 32)])
  96. trainer.train()
  97. def test_TensorboardCallback(self):
  98. data_set, model = prepare_env()
  99. trainer = Trainer(data_set, model,
  100. loss=BCELoss(pred="predict", target="y"),
  101. n_epochs=5,
  102. batch_size=32,
  103. print_every=50,
  104. optimizer=SGD(lr=0.1),
  105. check_code_level=2,
  106. use_tqdm=False,
  107. dev_data=data_set,
  108. metrics=AccuracyMetric(pred="predict", target="y"),
  109. callbacks=[TensorboardCallback("loss", "metric")])
  110. trainer.train()
  111. def test_readonly_property(self):
  112. from fastNLP.core.callback import Callback
  113. passed_epochs = []
  114. total_epochs = 5
  115. class MyCallback(Callback):
  116. def __init__(self):
  117. super(MyCallback, self).__init__()
  118. def on_epoch_begin(self):
  119. passed_epochs.append(self.epoch)
  120. print(self.n_epochs, self.n_steps, self.batch_size)
  121. print(self.model)
  122. print(self.optimizer)
  123. data_set, model = prepare_env()
  124. trainer = Trainer(data_set, model,
  125. loss=BCELoss(pred="predict", target="y"),
  126. n_epochs=total_epochs,
  127. batch_size=32,
  128. print_every=50,
  129. optimizer=SGD(lr=0.1),
  130. check_code_level=2,
  131. use_tqdm=False,
  132. dev_data=data_set,
  133. metrics=AccuracyMetric(pred="predict", target="y"),
  134. callbacks=[MyCallback()])
  135. trainer.train()
  136. assert passed_epochs == list(range(1, total_epochs + 1))