You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train_lstm_att.py 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
# First, add the following paths to the environment variables. The download
# service is currently only open for internal testing, so the base URL and the
# cache directory must be declared manually here (side effect: mutates the
# process environment before fastNLP is imported).
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'

import torch.nn as nn
from data.IMDBLoader import IMDBLoader
from fastNLP.modules.encoder.embedding import StaticEmbedding
from model.lstm_self_attention import BiLSTM_SELF_ATTENTION
from fastNLP.core.const import Const as C
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import Trainer, Tester
from torch.optim import Adam
from fastNLP.io.model_io import ModelLoader, ModelSaver
import argparse
  15. class Config():
  16. train_epoch= 10
  17. lr=0.001
  18. num_classes=2
  19. hidden_dim=256
  20. num_layers=1
  21. attention_unit=256
  22. attention_hops=1
  23. nfc=128
  24. task_name = "IMDB"
  25. datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"}
  26. save_model_path="./result_IMDB_test/"
  27. opt=Config()
  28. # load data
  29. dataloader=IMDBLoader()
  30. datainfo=dataloader.process(opt.datapath)
  31. # print(datainfo.datasets["train"])
  32. # print(datainfo)
  33. # define model
  34. vocab=datainfo.vocabs['words']
  35. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
  36. model=BiLSTM_SELF_ATTENTION(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, attention_unit=opt.attention_unit, attention_hops=opt.attention_hops, nfc=opt.nfc)
  37. # define loss_function and metrics
  38. loss=CrossEntropyLoss()
  39. metrics=AccuracyMetric()
  40. optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr)
  41. def train(datainfo, model, optimizer, loss, metrics, opt):
  42. trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
  43. metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
  44. n_epochs=opt.train_epoch, save_path=opt.save_model_path)
  45. trainer.train()
if __name__ == "__main__":
    # Kick off training with the module-level objects constructed above.
    train(datainfo, model, optimizer, loss, metrics, opt)