You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_lstm_att.py 1.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import sys
  2. sys.path.append('../..')
  3. from fastNLP.io.pipe.classification import IMDBPipe
  4. from fastNLP.embeddings import StaticEmbedding
  5. from model.lstm_self_attention import BiLSTM_SELF_ATTENTION
  6. from fastNLP import CrossEntropyLoss, AccuracyMetric
  7. from fastNLP import Trainer
  8. from torch.optim import Adam
  9. class Config():
  10. train_epoch= 10
  11. lr=0.001
  12. num_classes=2
  13. hidden_dim=256
  14. num_layers=1
  15. attention_unit=256
  16. attention_hops=1
  17. nfc=128
  18. task_name = "IMDB"
  19. datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"}
  20. save_model_path="./result_IMDB_test/"
  21. opt=Config()
  22. # load data
  23. data_bundle=IMDBPipe.process_from_file(opt.datapath)
  24. # print(data_bundle.datasets["train"])
  25. # print(data_bundle)
  26. # define model
  27. vocab=data_bundle.vocabs['words']
  28. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
  29. model=BiLSTM_SELF_ATTENTION(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, attention_unit=opt.attention_unit, attention_hops=opt.attention_hops, nfc=opt.nfc)
  30. # define loss_function and metrics
  31. loss=CrossEntropyLoss()
  32. metrics=AccuracyMetric()
  33. optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr)
  34. def train(data_bundle, model, optimizer, loss, metrics, opt):
  35. trainer = Trainer(data_bundle.datasets['train'], model, optimizer=optimizer, loss=loss,
  36. metrics=metrics, dev_data=data_bundle.datasets['test'], device=0, check_code_level=-1,
  37. n_epochs=opt.train_epoch, save_path=opt.save_model_path)
  38. trainer.train()
  39. if __name__ == "__main__":
  40. train(data_bundle, model, optimizer, loss, metrics, opt)