import argparse import torch from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const, CrossEntropyLoss from fastNLP.embeddings import StaticEmbedding from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe from reproduction.matching.model.cntn import CNTNModel # define hyper-parameters argument = argparse.ArgumentParser() argument.add_argument('--embedding', choices=['glove', 'word2vec'], default='glove') argument.add_argument('--batch-size-per-gpu', type=int, default=256) argument.add_argument('--n-epochs', type=int, default=200) argument.add_argument('--lr', type=float, default=1e-5) argument.add_argument('--save-dir', type=str, default=None) argument.add_argument('--cntn-depth', type=int, default=1) argument.add_argument('--cntn-ns', type=int, default=200) argument.add_argument('--cntn-k-top', type=int, default=10) argument.add_argument('--cntn-r', type=int, default=5) argument.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli') arg = argument.parse_args() # dataset dict dev_dict = { 'qnli': 'dev', 'rte': 'dev', 'snli': 'dev', 'mnli': 'dev_matched', } test_dict = { 'qnli': 'dev', 'rte': 'dev', 'snli': 'test', 'mnli': 'dev_matched', } # set num_labels if arg.dataset == 'qnli' or arg.dataset == 'rte': num_labels = 2 else: num_labels = 3 # load data set if arg.dataset == 'snli': data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file() elif arg.dataset == 'rte': data_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file() elif arg.dataset == 'qnli': data_bundle = QNLIPipe(lower=True, tokenizer='raw').process_from_file() elif arg.dataset == 'mnli': data_bundle = MNLIPipe(lower=True, tokenizer='raw').process_from_file() else: raise RuntimeError(f'NOT support {arg.task} task yet!') print(data_bundle) # print details in data_bundle # load embedding if arg.embedding == 'word2vec': embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-word2vec-300', requires_grad=True) elif arg.embedding == 'glove': embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d', requires_grad=True) else: raise ValueError(f'now we only support word2vec or glove embedding for cntn model!') # define model model = CNTNModel(embedding, ns=arg.cntn_ns, k_top=arg.cntn_k_top, num_labels=num_labels, depth=arg.cntn_depth, r=arg.cntn_r) print(model) # define trainer trainer = Trainer(train_data=data_bundle.datasets['train'], model=model, optimizer=Adam(lr=arg.lr, model_params=model.parameters()), loss=CrossEntropyLoss(), batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, n_epochs=arg.n_epochs, print_every=-1, dev_data=data_bundle.datasets[dev_dict[arg.dataset]], metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], check_code_level=-1) # train model trainer.train(load_best_model=True) # define tester tester = Tester( data=data_bundle.datasets[test_dict[arg.dataset]], model=model, metrics=AccuracyMetric(), batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, device=[i for i in range(torch.cuda.device_count())] ) # test model tester.test()