You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

matching_bert.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import random
  2. import numpy as np
  3. import torch
  4. from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam
  5. from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
  6. MNLILoader, QNLILoader, QuoraLoader
  7. from reproduction.matching.model.bert import BertForNLI
  8. # define hyper-parameters
  9. class BERTConfig:
  10. task = 'snli'
  11. batch_size_per_gpu = 6
  12. n_epochs = 6
  13. lr = 2e-5
  14. seq_len_type = 'bert'
  15. seed = 42
  16. train_dataset_name = 'train'
  17. dev_dataset_name = 'dev'
  18. test_dataset_name = 'test'
  19. save_path = None # 模型存储的位置,None表示不存储模型。
  20. bert_dir = 'path/to/bert/dir' # 预训练BERT参数文件的文件夹
  21. arg = BERTConfig()
  22. # set random seed
  23. random.seed(arg.seed)
  24. np.random.seed(arg.seed)
  25. torch.manual_seed(arg.seed)
  26. n_gpu = torch.cuda.device_count()
  27. if n_gpu > 0:
  28. torch.cuda.manual_seed_all(arg.seed)
  29. # load data set
  30. if arg.task == 'snli':
  31. data_info = SNLILoader().process(
  32. paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type,
  33. bert_tokenizer=arg.bert_dir, cut_text=512,
  34. get_index=True, concat='bert',
  35. )
  36. elif arg.task == 'rte':
  37. data_info = RTELoader().process(
  38. paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type,
  39. bert_tokenizer=arg.bert_dir, cut_text=512,
  40. get_index=True, concat='bert',
  41. )
  42. elif arg.task == 'qnli':
  43. data_info = QNLILoader().process(
  44. paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
  45. bert_tokenizer=arg.bert_dir, cut_text=512,
  46. get_index=True, concat='bert',
  47. )
  48. elif arg.task == 'mnli':
  49. data_info = MNLILoader().process(
  50. paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
  51. bert_tokenizer=arg.bert_dir, cut_text=512,
  52. get_index=True, concat='bert',
  53. )
  54. elif arg.task == 'quora':
  55. data_info = QuoraLoader().process(
  56. paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type,
  57. bert_tokenizer=arg.bert_dir, cut_text=512,
  58. get_index=True, concat='bert',
  59. )
  60. else:
  61. raise RuntimeError(f'NOT support {arg.task} task yet!')
  62. # define model
  63. model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir)
  64. # define trainer
  65. trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
  66. optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
  67. batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
  68. n_epochs=arg.n_epochs, print_every=-1,
  69. dev_data=data_info.datasets[arg.dev_dataset_name],
  70. metrics=AccuracyMetric(), metric_key='acc',
  71. device=[i for i in range(torch.cuda.device_count())],
  72. check_code_level=-1,
  73. save_path=arg.save_path)
  74. # train model
  75. trainer.train(load_best_model=True)
  76. # define tester
  77. tester = Tester(
  78. data=data_info.datasets[arg.test_dataset_name],
  79. model=model,
  80. metrics=AccuracyMetric(),
  81. batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
  82. device=[i for i in range(torch.cuda.device_count())],
  83. )
  84. # test model
  85. tester.test()