train.py

import sys
sys.path.append('../..')

import torch
from torch.optim import Adam

from fastNLP import cache_results
from fastNLP.core.callback import Callback, GradientClipCallback
from fastNLP.core.trainer import Trainer
from fastNLP.io.pipe.coreference import CoreferencePipe

from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.model_re import Model
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
from reproduction.coreference_resolution.model.metric import CRMetric

# Uncomment for reproducible (but slower) cuDNN behavior:
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

class LRCallback(Callback):
    """Multiply every parameter group's learning rate by (1 - decay_rate) once every 100 steps."""

    def __init__(self, parameters, decay_rate=1e-3):
        super().__init__()
        self.paras = parameters
        self.decay_rate = decay_rate

    def on_step_end(self):
        # self.step is the global step counter maintained by the Trainer.
        if self.step % 100 == 0:
            for para in self.paras:
                para['lr'] = para['lr'] * (1 - self.decay_rate)
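
# Note: multiplying by (1 - decay_rate) once every 100 steps is an exponential
# schedule. A minimal sketch of the resulting effective learning rate
# (illustrative helper only; not part of fastNLP and not used by the training
# loop below):
def effective_lr(lr0, decay_rate, step):
    """Learning rate after `step` training steps under LRCallback's schedule."""
    return lr0 * (1 - decay_rate) ** (step // 100)

# e.g. effective_lr(1e-3, 1e-3, 10_000) ≈ 9.05e-04, i.e. roughly a 10% decay
# over 10k steps with the callback's default decay_rate.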

if __name__ == "__main__":
    config = Config()
    print(config)

    # cache_results stores the processed data bundle in cache.pkl and reloads
    # it on subsequent runs instead of reprocessing the raw files.
    @cache_results('cache.pkl')
    def cache():
        bundle = CoreferencePipe(Config()).process_from_file(
            {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
        return bundle

    data_info = cache()
    print("Dataset split:\ntrain:", str(len(data_info.datasets["train"])),
          "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"])))

    model = Model(data_info.vocabs, config)
    print(model)

    loss = SoftmaxLoss()
    metric = CRMetric()
    optim = Adam(model.parameters(), lr=config.lr)
    lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay)

    trainer = Trainer(model=model, train_data=data_info.datasets["train"], dev_data=data_info.datasets["dev"],
                      loss=loss, metrics=metric, check_code_level=-1, sampler=None,
                      batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f',
                      n_epochs=config.epoch, optimizer=optim,
                      save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save',
                      callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)])
    trainer.train()
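
# Usage: data paths, hyperparameters, and the GPU id are all read from Config;
# with those set, run the script directly:
#     python train.py
# The processed bundle is loaded from cache.pkl on later runs; in recent
# fastNLP versions, cache(_refresh=True) rebuilds it without deleting the file.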