You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

matching_mwan.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import random
  2. import numpy as np
  3. import torch
  4. from torch.optim import Adadelta
  5. from torch.optim.lr_scheduler import StepLR
  6. from fastNLP import CrossEntropyLoss
  7. from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
  8. from fastNLP.core.callback import LRScheduler, EvaluateCallback
  9. from fastNLP.embeddings import StaticEmbedding
  10. from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe
  11. from reproduction.matching.model.mwan import MwanModel
# fitlog experiment tracker; debug() switches fitlog into debug mode so the
# script runs without a configured fitlog project directory — TODO confirm
# against fitlog docs.
import fitlog
fitlog.debug()
  14. import argparse
  15. argument = argparse.ArgumentParser()
  16. argument.add_argument('--task' , choices = ['snli', 'rte', 'qnli', 'mnli'],default = 'snli')
  17. argument.add_argument('--batch-size' , type = int , default = 128)
  18. argument.add_argument('--n-epochs' , type = int , default = 50)
  19. argument.add_argument('--lr' , type = float , default = 1)
  20. argument.add_argument('--testset-name' , type = str , default = 'test')
  21. argument.add_argument('--devset-name' , type = str , default = 'dev')
  22. argument.add_argument('--seed' , type = int , default = 42)
  23. argument.add_argument('--hidden-size' , type = int , default = 150)
  24. argument.add_argument('--dropout' , type = float , default = 0.3)
  25. arg = argument.parse_args()
  26. random.seed(arg.seed)
  27. np.random.seed(arg.seed)
  28. torch.manual_seed(arg.seed)
  29. n_gpu = torch.cuda.device_count()
  30. if n_gpu > 0:
  31. torch.cuda.manual_seed_all(arg.seed)
  32. print (n_gpu)
  33. for k in arg.__dict__:
  34. print(k, arg.__dict__[k], type(arg.__dict__[k]))
  35. # load data set
  36. if arg.task == 'snli':
  37. data_bundle = SNLIPipe(lower=True, tokenizer='spacy').process_from_file()
  38. elif arg.task == 'rte':
  39. data_bundle = RTEPipe(lower=True, tokenizer='spacy').process_from_file()
  40. elif arg.task == 'qnli':
  41. data_bundle = QNLIPipe(lower=True, tokenizer='spacy').process_from_file()
  42. elif arg.task == 'mnli':
  43. data_bundle = MNLIPipe(lower=True, tokenizer='spacy').process_from_file()
  44. elif arg.task == 'quora':
  45. data_bundle = QuoraPipe(lower=True, tokenizer='spacy').process_from_file()
  46. else:
  47. raise RuntimeError(f'NOT support {arg.task} task yet!')
  48. print(data_bundle)
  49. print(len(data_bundle.vocabs[Const.INPUTS(0)]))
  50. model = MwanModel(
  51. num_class = len(data_bundle.vocabs[Const.TARGET]),
  52. EmbLayer = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], requires_grad=False, normalize=False),
  53. ElmoLayer = None,
  54. args_of_imm = {
  55. "input_size" : 300 ,
  56. "hidden_size" : arg.hidden_size ,
  57. "dropout" : arg.dropout ,
  58. "use_allennlp" : False ,
  59. } ,
  60. )
  61. optimizer = Adadelta(lr=arg.lr, params=model.parameters())
  62. scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
  63. callbacks = [
  64. LRScheduler(scheduler),
  65. ]
  66. if arg.task in ['snli']:
  67. callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.testset_name]))
  68. elif arg.task == 'mnli':
  69. callbacks.append(EvaluateCallback(data={'dev_matched': data_bundle.datasets['dev_matched'],
  70. 'dev_mismatched': data_bundle.datasets['dev_mismatched']},))
  71. trainer = Trainer(
  72. train_data = data_bundle.datasets['train'],
  73. model = model,
  74. optimizer = optimizer,
  75. num_workers = 0,
  76. batch_size = arg.batch_size,
  77. n_epochs = arg.n_epochs,
  78. print_every = -1,
  79. dev_data = data_bundle.datasets[arg.devset_name],
  80. metrics = AccuracyMetric(pred = "pred" , target = "target"),
  81. metric_key = 'acc',
  82. device = [i for i in range(torch.cuda.device_count())],
  83. check_code_level = -1,
  84. callbacks = callbacks,
  85. loss = CrossEntropyLoss(pred = "pred" , target = "target")
  86. )
  87. trainer.train(load_best_model=True)
  88. tester = Tester(
  89. data=data_bundle.datasets[arg.testset_name],
  90. model=model,
  91. metrics=AccuracyMetric(),
  92. batch_size=arg.batch_size,
  93. device=[i for i in range(torch.cuda.device_count())],
  94. )
  95. tester.test()