You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

matching_cntn.py 3.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import argparse
  2. import torch
  3. from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const, CrossEntropyLoss
  4. from fastNLP.embeddings import StaticEmbedding
  5. from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe
  6. from reproduction.matching.model.cntn import CNTNModel
  7. # define hyper-parameters
  8. argument = argparse.ArgumentParser()
  9. argument.add_argument('--embedding', choices=['glove', 'word2vec'], default='glove')
  10. argument.add_argument('--batch-size-per-gpu', type=int, default=256)
  11. argument.add_argument('--n-epochs', type=int, default=200)
  12. argument.add_argument('--lr', type=float, default=1e-5)
  13. argument.add_argument('--save-dir', type=str, default=None)
  14. argument.add_argument('--cntn-depth', type=int, default=1)
  15. argument.add_argument('--cntn-ns', type=int, default=200)
  16. argument.add_argument('--cntn-k-top', type=int, default=10)
  17. argument.add_argument('--cntn-r', type=int, default=5)
  18. argument.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli')
  19. arg = argument.parse_args()
  20. # dataset dict
  21. dev_dict = {
  22. 'qnli': 'dev',
  23. 'rte': 'dev',
  24. 'snli': 'dev',
  25. 'mnli': 'dev_matched',
  26. }
  27. test_dict = {
  28. 'qnli': 'dev',
  29. 'rte': 'dev',
  30. 'snli': 'test',
  31. 'mnli': 'dev_matched',
  32. }
  33. # set num_labels
  34. if arg.dataset == 'qnli' or arg.dataset == 'rte':
  35. num_labels = 2
  36. else:
  37. num_labels = 3
  38. # load data set
  39. if arg.dataset == 'snli':
  40. data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file()
  41. elif arg.dataset == 'rte':
  42. data_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file()
  43. elif arg.dataset == 'qnli':
  44. data_bundle = QNLIPipe(lower=True, tokenizer='raw').process_from_file()
  45. elif arg.dataset == 'mnli':
  46. data_bundle = MNLIPipe(lower=True, tokenizer='raw').process_from_file()
  47. else:
  48. raise RuntimeError(f'NOT support {arg.task} task yet!')
  49. print(data_bundle) # print details in data_bundle
  50. # load embedding
  51. if arg.embedding == 'word2vec':
  52. embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-word2vec-300',
  53. requires_grad=True)
  54. elif arg.embedding == 'glove':
  55. embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d',
  56. requires_grad=True)
  57. else:
  58. raise ValueError(f'now we only support word2vec or glove embedding for cntn model!')
  59. # define model
  60. model = CNTNModel(embedding, ns=arg.cntn_ns, k_top=arg.cntn_k_top, num_labels=num_labels, depth=arg.cntn_depth,
  61. r=arg.cntn_r)
  62. print(model)
  63. # define trainer
  64. trainer = Trainer(train_data=data_bundle.datasets['train'], model=model,
  65. optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
  66. loss=CrossEntropyLoss(),
  67. batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
  68. n_epochs=arg.n_epochs, print_every=-1,
  69. dev_data=data_bundle.datasets[dev_dict[arg.dataset]],
  70. metrics=AccuracyMetric(), metric_key='acc',
  71. device=[i for i in range(torch.cuda.device_count())],
  72. check_code_level=-1)
  73. # train model
  74. trainer.train(load_best_model=True)
  75. # define tester
  76. tester = Tester(
  77. data=data_bundle.datasets[test_dict[arg.dataset]],
  78. model=model,
  79. metrics=AccuracyMetric(),
  80. batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
  81. device=[i for i in range(torch.cuda.device_count())]
  82. )
  83. # test model
  84. tester.test()