"""Train and evaluate an ESIM model on the SNLI data set with fastNLP.

The script parses its command-line options, loads SNLI through
``SNLILoader``, builds either an ELMo or a static embedding over the input
vocabulary, trains ``ESIMModel`` with ``Trainer`` (validating on the ``dev``
split with accuracy) and finally runs ``Tester`` on the ``test`` split.

Command-line options:
    --embedding            'glove' (default) or 'elmo'.
    --batch-size-per-gpu   Per-device batch size (default 128).
    --n-epochs             Number of training epochs (default 100).
    --lr                   Learning rate passed to Adam (default 1e-4).
    --seq-len-type         'seq_len' (default) or 'mask'; forwarded to the loader.
    --save-dir             Value handed to the trainer as ``save_path``
                           (default None).
"""
import argparse

import torch

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding

from reproduction.matching.data.MatchingDataLoader import SNLILoader
from reproduction.matching.model.esim import ESIMModel

argument = argparse.ArgumentParser()
argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove')
argument.add_argument('--batch-size-per-gpu', type=int, default=128)
argument.add_argument('--n-epochs', type=int, default=100)
argument.add_argument('--lr', type=float, default=1e-4)
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len')
argument.add_argument('--save-dir', type=str, default=None)
arg = argument.parse_args()

# Placeholder path; it is not referenced anywhere else in this script.
bert_dirs = 'path/to/bert/dir'

# load data set
data_info = SNLILoader().process(
    paths='path/to/snli/data/dir',
    to_lower=True,
    seq_len_type=arg.seq_len_type,
    bert_tokenizer=None,
    get_index=True,
    concat=False,
)

# load embedding
if arg.embedding == 'elmo':
    embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
elif arg.embedding == 'glove':
    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
else:
    # argparse's ``choices`` already rejects other values; kept as a guard.
    raise ValueError('now we only support elmo or glove embedding for esim model!')

# define model
model = ESIMModel(embedding)

# Resolve the devices once so the trainer and the tester always agree.
# Without any CUDA device the original expressions produced an empty device
# list and a batch size of 0, so fall back to the CPU with one "per-gpu" batch.
n_gpus = torch.cuda.device_count()
device = list(range(n_gpus)) if n_gpus > 0 else 'cpu'
batch_size = max(n_gpus, 1) * arg.batch_size_per_gpu

# define trainer
trainer = Trainer(
    train_data=data_info.datasets['train'],
    model=model,
    optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
    batch_size=batch_size,
    n_epochs=arg.n_epochs,
    print_every=-1,
    dev_data=data_info.datasets['dev'],
    metrics=AccuracyMetric(),
    metric_key='acc',
    device=device,
    check_code_level=-1,
    # The option is declared as ``--save-dir``, so argparse stores it under
    # ``save_dir``; reading ``arg.save_path`` raised AttributeError.
    save_path=arg.save_dir,
)

# train model
trainer.train(load_best_model=True)

# define tester
tester = Tester(
    data=data_info.datasets['test'],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=batch_size,
    device=device,
)

# test model
tester.test()