
train_lstm.py

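"""Train a BiLSTM sentiment classifier on IMDB with fastNLP.

Loads the IMDB dataset through IMDBPipe, builds GloVe-initialized static
embeddings, and trains with cross-entropy loss, reporting accuracy on the
test split.
"""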
import sys
sys.path.append('../..')

from fastNLP.io.pipe.classification import IMDBPipe
from fastNLP.embeddings import StaticEmbedding
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import Trainer
from torch.optim import Adam

from model.lstm import BiLSTMSentiment


class Config:
    train_epoch = 10
    lr = 0.001
    num_classes = 2
    hidden_dim = 256
    num_layers = 1
    nfc = 128
    task_name = "IMDB"
    datapath = {"train": "IMDB_data/train.csv", "test": "IMDB_data/test.csv"}
    save_model_path = "./result_IMDB_test/"


opt = Config()

# load data; process_from_file is an instance method, so instantiate the pipe first
data_bundle = IMDBPipe().process_from_file(opt.datapath)
# print(data_bundle.datasets["train"])
# print(data_bundle)

# define the model: GloVe-initialized static embeddings feeding a BiLSTM classifier
vocab = data_bundle.vocabs['words']
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
model = BiLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim,
                        num_layers=opt.num_layers, nfc=opt.nfc)

# define the loss function, metric, and optimizer (only trainable parameters)
loss = CrossEntropyLoss()
metrics = AccuracyMetric()
optimizer = Adam([param for param in model.parameters() if param.requires_grad], lr=opt.lr)


def train(data_bundle, model, optimizer, loss, metrics, opt):
    trainer = Trainer(data_bundle.datasets['train'], model, optimizer=optimizer, loss=loss,
                      metrics=metrics, dev_data=data_bundle.datasets['test'], device=0,
                      check_code_level=-1, n_epochs=opt.train_epoch,
                      save_path=opt.save_model_path)
    trainer.train()


if __name__ == "__main__":
    train(data_bundle, model, optimizer, loss, metrics, opt)