You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

train_BertSum.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import sys
  2. import argparse
  3. import os
  4. import json
  5. import torch
  6. from time import time
  7. from datetime import timedelta
  8. from os.path import join, exists
  9. from torch.optim import Adam
  10. from utils import get_data_path, get_rouge_path
  11. from dataloader import BertSumLoader
  12. from model import BertSum
  13. from fastNLP.core.optimizer import AdamW
  14. from metrics import MyBCELoss, LossMetric, RougeMetric
  15. from fastNLP.core.sampler import BucketSampler
  16. from callback import MyCallback, SaveModelCallback
  17. from fastNLP.core.trainer import Trainer
  18. from fastNLP.core.tester import Tester
  19. def configure_training(args):
  20. devices = [int(gpu) for gpu in args.gpus.split(',')]
  21. params = {}
  22. params['label_type'] = args.label_type
  23. params['batch_size'] = args.batch_size
  24. params['accum_count'] = args.accum_count
  25. params['max_lr'] = args.max_lr
  26. params['warmup_steps'] = args.warmup_steps
  27. params['n_epochs'] = args.n_epochs
  28. params['valid_steps'] = args.valid_steps
  29. return devices, params
  30. def train_model(args):
  31. # check if the data_path and save_path exists
  32. data_paths = get_data_path(args.mode, args.label_type)
  33. for name in data_paths:
  34. assert exists(data_paths[name])
  35. if not exists(args.save_path):
  36. os.makedirs(args.save_path)
  37. # load summarization datasets
  38. datasets = BertSumLoader().process(data_paths)
  39. print('Information of dataset is:')
  40. print(datasets)
  41. train_set = datasets.datasets['train']
  42. valid_set = datasets.datasets['val']
  43. # configure training
  44. devices, train_params = configure_training(args)
  45. with open(join(args.save_path, 'params.json'), 'w') as f:
  46. json.dump(train_params, f, indent=4)
  47. print('Devices is:')
  48. print(devices)
  49. # configure model
  50. model = BertSum()
  51. optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0)
  52. callbacks = [MyCallback(args), SaveModelCallback(args.save_path)]
  53. criterion = MyBCELoss()
  54. val_metric = [LossMetric()]
  55. # sampler = BucketSampler(num_buckets=32, batch_size=args.batch_size)
  56. trainer = Trainer(train_data=train_set, model=model, optimizer=optimizer,
  57. loss=criterion, batch_size=args.batch_size, # sampler=sampler,
  58. update_every=args.accum_count, n_epochs=args.n_epochs,
  59. print_every=100, dev_data=valid_set, metrics=val_metric,
  60. metric_key='-loss', validate_every=args.valid_steps,
  61. save_path=args.save_path, device=devices, callbacks=callbacks)
  62. print('Start training with the following hyper-parameters:')
  63. print(train_params)
  64. trainer.train()
  65. def test_model(args):
  66. models = os.listdir(args.save_path)
  67. # load dataset
  68. data_paths = get_data_path(args.mode, args.label_type)
  69. datasets = BertSumLoader().process(data_paths)
  70. print('Information of dataset is:')
  71. print(datasets)
  72. test_set = datasets.datasets['test']
  73. # only need 1 gpu for testing
  74. device = int(args.gpus)
  75. args.batch_size = 1
  76. for cur_model in models:
  77. print('Current model is {}'.format(cur_model))
  78. # load model
  79. model = torch.load(join(args.save_path, cur_model))
  80. # configure testing
  81. original_path, dec_path, ref_path = get_rouge_path(args.label_type)
  82. test_metric = RougeMetric(data_path=original_path, dec_path=dec_path,
  83. ref_path=ref_path, n_total = len(test_set))
  84. tester = Tester(data=test_set, model=model, metrics=[test_metric],
  85. batch_size=args.batch_size, device=device)
  86. tester.test()
  87. if __name__ == '__main__':
  88. parser = argparse.ArgumentParser(
  89. description='training/testing of BertSum(liu et al. 2019)'
  90. )
  91. parser.add_argument('--mode', required=True,
  92. help='training or testing of BertSum', type=str)
  93. parser.add_argument('--label_type', default='greedy',
  94. help='greedy/limit', type=str)
  95. parser.add_argument('--save_path', required=True,
  96. help='root of the model', type=str)
  97. # example for gpus input: '0,1,2,3'
  98. parser.add_argument('--gpus', required=True,
  99. help='available gpus for training(separated by commas)', type=str)
  100. parser.add_argument('--batch_size', default=18,
  101. help='the training batch size', type=int)
  102. parser.add_argument('--accum_count', default=2,
  103. help='number of updates steps to accumulate before performing a backward/update pass.', type=int)
  104. parser.add_argument('--max_lr', default=2e-5,
  105. help='max learning rate for warm up', type=float)
  106. parser.add_argument('--warmup_steps', default=10000,
  107. help='warm up steps for training', type=int)
  108. parser.add_argument('--n_epochs', default=10,
  109. help='total number of training epochs', type=int)
  110. parser.add_argument('--valid_steps', default=1000,
  111. help='number of update steps for checkpoint and validation', type=int)
  112. args = parser.parse_args()
  113. if args.mode == 'train':
  114. print('Training process of BertSum !!!')
  115. train_model(args)
  116. else:
  117. print('Testing process of BertSum !!!')
  118. test_model(args)