
matching_esim.py 2.3 kB

import argparse

import torch

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding

from reproduction.matching.data.MatchingDataLoader import SNLILoader
from reproduction.matching.model.esim import ESIMModel


argument = argparse.ArgumentParser()
argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove')
argument.add_argument('--batch-size-per-gpu', type=int, default=128)
argument.add_argument('--n-epochs', type=int, default=100)
argument.add_argument('--lr', type=float, default=1e-4)
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len')
argument.add_argument('--save-dir', type=str, default=None)
arg = argument.parse_args()

bert_dirs = 'path/to/bert/dir'

# load data set
data_info = SNLILoader().process(
    paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
    get_index=True, concat=False,
)

# load embedding
if arg.embedding == 'elmo':
    embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
elif arg.embedding == 'glove':
    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
else:
    raise ValueError('only elmo or glove embeddings are supported for the ESIM model!')

# define model
model = ESIMModel(embedding)

# define trainer
trainer = Trainer(train_data=data_info.datasets['train'], model=model,
                  optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
                  batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_info.datasets['dev'],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=[i for i in range(torch.cuda.device_count())],
                  check_code_level=-1,
                  save_path=arg.save_dir)

# train model
trainer.train(load_best_model=True)

# define tester
tester = Tester(
    data=data_info.datasets['test'],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
    device=[i for i in range(torch.cuda.device_count())],
)

# test model
tester.test()
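A possible invocation, based on the argparse flags defined above (the SNLI data path and, if needed, the BERT path placeholders inside the script must be edited to real directories first; the exact values here are only illustrative):

python matching_esim.py --embedding elmo --batch-size-per-gpu 64 --n-epochs 10 --lr 1e-4 --save-dir ./checkpoints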