You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

matching_esim.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import random
  2. import numpy as np
  3. import torch
  4. from torch.optim import Adamax
  5. from torch.optim.lr_scheduler import StepLR
  6. from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
  7. from fastNLP.core.callback import GradientClipCallback, LRScheduler, EvaluateCallback
  8. from fastNLP.core.losses import CrossEntropyLoss
  9. from fastNLP.embeddings import StaticEmbedding
  10. from fastNLP.embeddings import ElmoEmbedding
  11. from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe
  12. from fastNLP.models.snli import ESIM
  13. # define hyper-parameters
  14. class ESIMConfig:
  15. task = 'snli'
  16. embedding = 'glove'
  17. batch_size_per_gpu = 196
  18. n_epochs = 30
  19. lr = 2e-3
  20. seed = 42
  21. save_path = None # 模型存储的位置,None表示不存储模型。
  22. train_dataset_name = 'train'
  23. dev_dataset_name = 'dev'
  24. test_dataset_name = 'test'
  25. to_lower = True # 忽略大小写
  26. tokenizer = 'spacy' # 使用spacy进行分词
  27. arg = ESIMConfig()
  28. # set random seed
  29. random.seed(arg.seed)
  30. np.random.seed(arg.seed)
  31. torch.manual_seed(arg.seed)
  32. n_gpu = torch.cuda.device_count()
  33. if n_gpu > 0:
  34. torch.cuda.manual_seed_all(arg.seed)
  35. # load data set
  36. if arg.task == 'snli':
  37. data_bundle = SNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
  38. elif arg.task == 'rte':
  39. data_bundle = RTEPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
  40. elif arg.task == 'qnli':
  41. data_bundle = QNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
  42. elif arg.task == 'mnli':
  43. data_bundle = MNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
  44. elif arg.task == 'quora':
  45. data_bundle = QuoraPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
  46. else:
  47. raise RuntimeError(f'NOT support {arg.task} task yet!')
  48. print(data_bundle) # print details in data_bundle
  49. # load embedding
  50. if arg.embedding == 'elmo':
  51. embedding = ElmoEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-medium',
  52. requires_grad=True)
  53. elif arg.embedding == 'glove':
  54. embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d',
  55. requires_grad=True, normalize=False)
  56. else:
  57. raise RuntimeError(f'NOT support {arg.embedding} embedding yet!')
  58. # define model
  59. model = ESIM(embedding, num_labels=len(data_bundle.vocabs[Const.TARGET]))
  60. # define optimizer and callback
  61. optimizer = Adamax(lr=arg.lr, params=model.parameters())
  62. scheduler = StepLR(optimizer, step_size=10, gamma=0.5) # 每10个epoch学习率变为原来的0.5倍
  63. callbacks = [
  64. GradientClipCallback(clip_value=10), # 等价于torch.nn.utils.clip_grad_norm_(10)
  65. LRScheduler(scheduler),
  66. ]
  67. if arg.task in ['snli']:
  68. callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name]))
  69. # evaluate test set in every epoch if task is snli.
  70. # define trainer
  71. trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model,
  72. optimizer=optimizer,
  73. loss=CrossEntropyLoss(),
  74. batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
  75. n_epochs=arg.n_epochs, print_every=-1,
  76. dev_data=data_bundle.datasets[arg.dev_dataset_name],
  77. metrics=AccuracyMetric(), metric_key='acc',
  78. device=[i for i in range(torch.cuda.device_count())],
  79. check_code_level=-1,
  80. save_path=arg.save_path,
  81. callbacks=callbacks)
  82. # train model
  83. trainer.train(load_best_model=True)
  84. # define tester
  85. tester = Tester(
  86. data=data_bundle.datasets[arg.test_dataset_name],
  87. model=model,
  88. metrics=AccuracyMetric(),
  89. batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
  90. device=[i for i in range(torch.cuda.device_count())],
  91. )
  92. # test model
  93. tester.test()