
train_lstm_att.py 3.2 kB
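Training/evaluation script for a bidirectional LSTM with multi-hop self-attention (BiLSTM_SELF_ATTENTION) on the IMDB, Yelp, or SST text-classification tasks, built on fastNLP.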

# First, add the following paths to the environment variables. The download service is
# currently open to internal testing only, so the paths have to be declared manually.
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'

import torch.nn as nn
from data.SSTLoader import SSTLoader
from data.IMDBLoader import IMDBLoader
from data.yelpLoader import yelpLoader
from fastNLP.modules.encoder.embedding import StaticEmbedding
from model.lstm_self_attention import BiLSTM_SELF_ATTENTION
from fastNLP.core.const import Const as C
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import Trainer, Tester
from torch.optim import Adam
from fastNLP.io.model_io import ModelLoader, ModelSaver
import argparse

class Config:
    # hyperparameters and paths for the experiment
    train_epoch = 10
    lr = 0.001
    num_classes = 2
    hidden_dim = 256
    num_layers = 1
    attention_unit = 256
    attention_hops = 1
    nfc = 128
    task_name = "IMDB"
    datapath = {"train": "IMDB_data/train.csv", "test": "IMDB_data/test.csv"}
    load_model_path = "./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51"
    save_model_path = "./result_IMDB_test/"

opt = Config()

# load data
dataloaders = {
    "IMDB": IMDBLoader(),
    "YELP": yelpLoader(),
    "SST-5": SSTLoader(subtree=True, fine_grained=True),
    "SST-3": SSTLoader(subtree=True, fine_grained=False)
}

if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
    raise ValueError("task name must be one of ['IMDB', 'YELP', 'SST-5', 'SST-3']")

dataloader = dataloaders[opt.task_name]
datainfo = dataloader.process(opt.datapath)
# print(datainfo.datasets["train"])
# print(datainfo)

# define model
vocab = datainfo.vocabs['words']
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
model = BiLSTM_SELF_ATTENTION(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim,
                              num_layers=opt.num_layers, attention_unit=opt.attention_unit,
                              attention_hops=opt.attention_hops, nfc=opt.nfc)

# define loss function, metrics and optimizer
loss = CrossEntropyLoss()
metrics = AccuracyMetric()
optimizer = Adam([param for param in model.parameters() if param.requires_grad], lr=opt.lr)


def train(datainfo, model, optimizer, loss, metrics, opt):
    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
                      metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
                      n_epochs=opt.train_epoch, save_path=opt.save_model_path)
    trainer.train()


def test(datainfo, metrics, opt):
    # load model
    model = ModelLoader.load_pytorch_model(opt.load_model_path)
    print("model loaded!")
    # Tester
    tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
    acc = tester.test()
    print("acc=", acc)


parser = argparse.ArgumentParser()
parser.add_argument('--mode', required=True, dest="mode", help="run mode: 'train' or 'test'")
args = parser.parse_args()

if args.mode == 'train':
    train(datainfo, model, optimizer, loss, metrics, opt)
elif args.mode == 'test':
    test(datainfo, metrics, opt)
else:
    print('unrecognized mode:', args.mode)
    parser.print_help()
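
Usage: a minimal invocation sketch, assuming the IMDB CSVs referenced in Config.datapath are in place and the internal fastNLP file server is reachable:

    python train_lstm_att.py --mode train
    python train_lstm_att.py --mode test

In train mode the Trainer saves the best checkpoint under save_model_path; in test mode the script loads the checkpoint at load_model_path and reports accuracy on the test set.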