You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trainer.py 1.9 kB

7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from action.action import Action
  2. from action.tester import Tester
  3. class Trainer(Action):
  4. """
  5. Trainer for common training logic of all models
  6. """
  7. def __init__(self, train_args):
  8. """
  9. :param train_args: namedtuple
  10. """
  11. super(Trainer, self).__init__()
  12. self.train_args = train_args
  13. self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()}
  14. self.n_epochs = self.train_args.epochs
  15. self.validate = True
  16. self.save_when_better = True
  17. def train(self, network, data, dev_data):
  18. X, Y = network.prepare_input(data)
  19. iterations, train_batch_generator = self.batchify(X, Y)
  20. loss_history = list()
  21. network.mode(test=False)
  22. test_args = "..."
  23. evaluator = Tester(test_args)
  24. best_loss = 1e10
  25. for epoch in range(self.n_epochs):
  26. for step in range(iterations):
  27. batch_x, batch_y = train_batch_generator.__next__()
  28. prediction = network.data_forward(batch_x)
  29. loss = network.loss(batch_y, prediction)
  30. network.grad_backward()
  31. loss_history.append(loss)
  32. self.log(self.make_log(epoch, step, loss))
  33. # evaluate over dev set
  34. if self.validate:
  35. evaluator.test(network, dev_data)
  36. self.log(self.make_valid_log(epoch, evaluator.loss))
  37. if evaluator.loss < best_loss:
  38. best_loss = evaluator.loss
  39. if self.save_when_better:
  40. self.save_model(network)
  41. # finish training
  42. def make_log(self, *args):
  43. raise NotImplementedError
  44. def make_valid_log(self, *args):
  45. raise NotImplementedError
  46. def save_model(self, model):
  47. raise NotImplementedError

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。