You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_lstm.py 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径
  2. import os
  3. os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
  4. os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
  5. from fastNLP.io.data_loader import IMDBLoader
  6. from fastNLP.embeddings import StaticEmbedding
  7. from model.lstm import BiLSTMSentiment
  8. from fastNLP import CrossEntropyLoss, AccuracyMetric
  9. from fastNLP import Trainer
  10. from torch.optim import Adam
  11. class Config():
  12. train_epoch= 10
  13. lr=0.001
  14. num_classes=2
  15. hidden_dim=256
  16. num_layers=1
  17. nfc=128
  18. task_name = "IMDB"
  19. datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"}
  20. save_model_path="./result_IMDB_test/"
  21. opt=Config()
  22. # load data
  23. dataloader=IMDBLoader()
  24. datainfo=dataloader.process(opt.datapath)
  25. # print(datainfo.datasets["train"])
  26. # print(datainfo)
  27. # define model
  28. vocab=datainfo.vocabs['words']
  29. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
  30. model=BiLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc)
  31. # define loss_function and metrics
  32. loss=CrossEntropyLoss()
  33. metrics=AccuracyMetric()
  34. optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr)
  35. def train(datainfo, model, optimizer, loss, metrics, opt):
  36. trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
  37. metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
  38. n_epochs=opt.train_epoch, save_path=opt.save_model_path)
  39. trainer.train()
  40. if __name__ == "__main__":
  41. train(datainfo, model, optimizer, loss, metrics, opt)