You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

matching_mwan.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import sys
  2. import os
  3. import random
  4. import numpy as np
  5. import torch
  6. from torch.optim import Adadelta, SGD
  7. from torch.optim.lr_scheduler import StepLR
  8. from tqdm import tqdm
  9. from fastNLP import CrossEntropyLoss
  10. from fastNLP import cache_results
  11. from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
  12. from fastNLP.core.predictor import Predictor
  13. from fastNLP.core.callback import GradientClipCallback, LRScheduler, FitlogCallback
  14. from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding
  15. from fastNLP.io.data_loader import MNLILoader, QNLILoader, QuoraLoader, SNLILoader, RTELoader
  16. from reproduction.matching.model.mwan import MwanModel
  17. import fitlog
  18. fitlog.debug()
  19. import argparse
  20. argument = argparse.ArgumentParser()
  21. argument.add_argument('--task' , choices = ['snli', 'rte', 'qnli', 'mnli'],default = 'snli')
  22. argument.add_argument('--batch-size' , type = int , default = 128)
  23. argument.add_argument('--n-epochs' , type = int , default = 50)
  24. argument.add_argument('--lr' , type = float , default = 1)
  25. argument.add_argument('--testset-name' , type = str , default = 'test')
  26. argument.add_argument('--devset-name' , type = str , default = 'dev')
  27. argument.add_argument('--seed' , type = int , default = 42)
  28. argument.add_argument('--hidden-size' , type = int , default = 150)
  29. argument.add_argument('--dropout' , type = float , default = 0.3)
  30. arg = argument.parse_args()
  31. random.seed(arg.seed)
  32. np.random.seed(arg.seed)
  33. torch.manual_seed(arg.seed)
  34. n_gpu = torch.cuda.device_count()
  35. if n_gpu > 0:
  36. torch.cuda.manual_seed_all(arg.seed)
  37. print (n_gpu)
  38. for k in arg.__dict__:
  39. print(k, arg.__dict__[k], type(arg.__dict__[k]))
  40. # load data set
  41. if arg.task == 'snli':
  42. @cache_results(f'snli_mwan.pkl')
  43. def read_snli():
  44. data_info = SNLILoader().process(
  45. paths='path/to/snli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
  46. get_index=True, concat=False, extra_split=['/','%','-'],
  47. )
  48. return data_info
  49. data_info = read_snli()
  50. elif arg.task == 'rte':
  51. @cache_results(f'rte_mwan.pkl')
  52. def read_rte():
  53. data_info = RTELoader().process(
  54. paths='path/to/rte/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
  55. get_index=True, concat=False, extra_split=['/','%','-'],
  56. )
  57. return data_info
  58. data_info = read_rte()
  59. elif arg.task == 'qnli':
  60. data_info = QNLILoader().process(
  61. paths='path/to/qnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
  62. get_index=True, concat=False , cut_text=512, extra_split=['/','%','-'],
  63. )
  64. elif arg.task == 'mnli':
  65. @cache_results(f'mnli_v0.9_mwan.pkl')
  66. def read_mnli():
  67. data_info = MNLILoader().process(
  68. paths='path/to/mnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
  69. get_index=True, concat=False, extra_split=['/','%','-'],
  70. )
  71. return data_info
  72. data_info = read_mnli()
  73. else:
  74. raise RuntimeError(f'NOT support {arg.task} task yet!')
  75. print(data_info)
  76. print(len(data_info.vocabs['words']))
  77. model = MwanModel(
  78. num_class = len(data_info.vocabs[Const.TARGET]),
  79. EmbLayer = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=False, normalize=False),
  80. ElmoLayer = None,
  81. args_of_imm = {
  82. "input_size" : 300 ,
  83. "hidden_size" : arg.hidden_size ,
  84. "dropout" : arg.dropout ,
  85. "use_allennlp" : False ,
  86. } ,
  87. )
  88. optimizer = Adadelta(lr=arg.lr, params=model.parameters())
  89. scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
  90. callbacks = [
  91. LRScheduler(scheduler),
  92. ]
  93. if arg.task in ['snli']:
  94. callbacks.append(FitlogCallback(data_info.datasets[arg.testset_name], verbose=1))
  95. elif arg.task == 'mnli':
  96. callbacks.append(FitlogCallback({'dev_matched': data_info.datasets['dev_matched'],
  97. 'dev_mismatched': data_info.datasets['dev_mismatched']},
  98. verbose=1))
  99. trainer = Trainer(
  100. train_data = data_info.datasets['train'],
  101. model = model,
  102. optimizer = optimizer,
  103. num_workers = 0,
  104. batch_size = arg.batch_size,
  105. n_epochs = arg.n_epochs,
  106. print_every = -1,
  107. dev_data = data_info.datasets[arg.devset_name],
  108. metrics = AccuracyMetric(pred = "pred" , target = "target"),
  109. metric_key = 'acc',
  110. device = [i for i in range(torch.cuda.device_count())],
  111. check_code_level = -1,
  112. callbacks = callbacks,
  113. loss = CrossEntropyLoss(pred = "pred" , target = "target")
  114. )
  115. trainer.train(load_best_model=True)
  116. tester = Tester(
  117. data=data_info.datasets[arg.testset_name],
  118. model=model,
  119. metrics=AccuracyMetric(),
  120. batch_size=arg.batch_size,
  121. device=[i for i in range(torch.cuda.device_count())],
  122. )
  123. tester.test()