
train_awdlstm.py

# This model must be run with pytorch=0.4; weight_drop does not support pytorch 1.0
import sys
sys.path.append('../..')

from fastNLP.io.pipe.classification import IMDBPipe
from fastNLP.embeddings import StaticEmbedding
from model.awd_lstm import AWDLSTMSentiment
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import Trainer
from torch.optim import Adam


class Config():
    train_epoch = 10
    lr = 0.001
    num_classes = 2
    hidden_dim = 256
    num_layers = 1
    nfc = 128
    wdrop = 0.5

    task_name = "IMDB"
    datapath = {"train": "IMDB_data/train.csv", "test": "IMDB_data/test.csv"}
    save_model_path = "./result_IMDB_test/"


opt = Config()

# load data
data_bundle = IMDBPipe.process_from_file(opt.datapath)
# print(data_bundle.datasets["train"])
# print(data_bundle)

# define model
vocab = data_bundle.vocabs['words']
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
model = AWDLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim,
                         num_layers=opt.num_layers, nfc=opt.nfc, wdrop=opt.wdrop)

# define loss_function and metrics
loss = CrossEntropyLoss()
metrics = AccuracyMetric()
optimizer = Adam([param for param in model.parameters() if param.requires_grad], lr=opt.lr)


def train(datainfo, model, optimizer, loss, metrics, opt):
    # train on the IMDB train set, evaluating on the test set after each epoch
    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
                      metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
                      n_epochs=opt.train_epoch, save_path=opt.save_model_path)
    trainer.train()


if __name__ == "__main__":
    train(data_bundle, model, optimizer, loss, metrics, opt)
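

# Optional: standalone evaluation on the test set -- a minimal sketch, not part of the
# original script. It assumes fastNLP's Tester API (Tester(data, model, metrics, device))
# and reuses the data_bundle, model, and metrics objects defined above.
def evaluate(data_bundle, model, metrics):
    from fastNLP import Tester
    tester = Tester(data=data_bundle.datasets['test'], model=model, metrics=metrics, device=0)
    return tester.test()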