You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

matching_bert.py 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import random
  2. import numpy as np
  3. import torch
  4. from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
  5. from fastNLP.core.callback import WarmupCallback, EvaluateCallback
  6. from fastNLP.core.optimizer import AdamW
  7. from fastNLP.embeddings import BertEmbedding
  8. from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe,\
  9. QNLIBertPipe, QuoraBertPipe
  10. from fastNLP.models.bert import BertForSentenceMatching
  11. # define hyper-parameters
  12. class BERTConfig:
  13. task = 'snli'
  14. batch_size_per_gpu = 6
  15. n_epochs = 6
  16. lr = 2e-5
  17. warm_up_rate = 0.1
  18. seed = 42
  19. save_path = None # 模型存储的位置,None表示不存储模型。
  20. train_dataset_name = 'train'
  21. dev_dataset_name = 'dev'
  22. test_dataset_name = 'test'
  23. to_lower = True # 忽略大小写
  24. tokenizer = 'spacy' # 使用spacy进行分词
  25. bert_model_dir_or_name = 'bert-base-uncased'
  26. arg = BERTConfig()
  27. # set random seed
  28. random.seed(arg.seed)
  29. np.random.seed(arg.seed)
  30. torch.manual_seed(arg.seed)
  31. n_gpu = torch.cuda.device_count()
  32. if n_gpu > 0:
  33. torch.cuda.manual_seed_all(arg.seed)
  34. # load data set
  35. if arg.task == 'snli':
  36. data_bundle = SNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
  37. elif arg.task == 'rte':
  38. data_bundle = RTEBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
  39. elif arg.task == 'qnli':
  40. data_bundle = QNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
  41. elif arg.task == 'mnli':
  42. data_bundle = MNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
  43. elif arg.task == 'quora':
  44. data_bundle = QuoraBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
  45. else:
  46. raise RuntimeError(f'NOT support {arg.task} task yet!')
  47. print(data_bundle) # print details in data_bundle
  48. # load embedding
  49. embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name)
  50. # define model
  51. model = BertForSentenceMatching(embed, num_labels=len(data_bundle.vocabs[Const.TARGET]))
  52. # define optimizer and callback
  53. optimizer = AdamW(lr=arg.lr, params=model.parameters())
  54. callbacks = [WarmupCallback(warmup=arg.warm_up_rate, schedule='linear'), ]
  55. if arg.task in ['snli']:
  56. callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name]))
  57. # evaluate test set in every epoch if task is snli.
  58. # define trainer
  59. trainer = Trainer(train_data=data_bundle.get_dataset(arg.train_dataset_name), model=model,
  60. optimizer=optimizer,
  61. batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
  62. n_epochs=arg.n_epochs, print_every=-1,
  63. dev_data=data_bundle.get_dataset(arg.dev_dataset_name),
  64. metrics=AccuracyMetric(), metric_key='acc',
  65. device=[i for i in range(torch.cuda.device_count())],
  66. check_code_level=-1,
  67. save_path=arg.save_path,
  68. callbacks=callbacks)
  69. # train model
  70. trainer.train(load_best_model=True)
  71. # define tester
  72. tester = Tester(
  73. data=data_bundle.get_dataset(arg.test_dataset_name),
  74. model=model,
  75. metrics=AccuracyMetric(),
  76. batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
  77. device=[i for i in range(torch.cuda.device_count())],
  78. )
  79. # test model
  80. tester.test()