
train_lstm.py 3.1 kB

# First, the following paths need to be added to the environment variables.
# This is currently only open to internal testing, so the paths have to be declared manually.
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'

import argparse

import torch.nn as nn
from torch.optim import Adam

from data.SSTLoader import SSTLoader
from data.IMDBLoader import IMDBLoader
from data.yelpLoader import yelpLoader
from model.lstm import BiLSTMSentiment

from fastNLP.modules.encoder.embedding import StaticEmbedding
from fastNLP.core.const import Const as C
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import Trainer, Tester
from fastNLP.io.model_io import ModelLoader, ModelSaver


class Config():
    # training hyperparameters
    train_epoch = 10
    lr = 0.001

    # model hyperparameters
    num_classes = 2
    hidden_dim = 256
    num_layers = 1
    nfc = 128

    # task and paths
    task_name = "IMDB"
    datapath = {"train": "IMDB_data/train.csv", "test": "IMDB_data/test.csv"}
    load_model_path = "./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51"
    save_model_path = "./result_IMDB_test/"

opt = Config()

# load data
dataloaders = {
    "IMDB": IMDBLoader(),
    "YELP": yelpLoader(),
    "SST-5": SSTLoader(subtree=True, fine_grained=True),
    "SST-3": SSTLoader(subtree=True, fine_grained=False)
}

if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
    raise ValueError("task name must be one of ['IMDB', 'YELP', 'SST-5', 'SST-3']")

dataloader = dataloaders[opt.task_name]
datainfo = dataloader.process(opt.datapath)
# print(datainfo.datasets["train"])
# print(datainfo)

# define model
vocab = datainfo.vocabs['words']
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
model = BiLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim,
                        num_layers=opt.num_layers, nfc=opt.nfc)

# define loss function, metrics and optimizer
loss = CrossEntropyLoss()
metrics = AccuracyMetric()
optimizer = Adam([param for param in model.parameters() if param.requires_grad], lr=opt.lr)


def train(datainfo, model, optimizer, loss, metrics, opt):
    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
                      metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
                      n_epochs=opt.train_epoch, save_path=opt.save_model_path)
    trainer.train()


def test(datainfo, metrics, opt):
    # load the saved model
    model = ModelLoader.load_pytorch_model(opt.load_model_path)
    print("model loaded!")

    # evaluate on the test set
    tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
    acc = tester.test()
    print("acc=", acc)


parser = argparse.ArgumentParser()
parser.add_argument('--mode', required=True, dest="mode", help="set the script's mode: 'train' or 'test'")
args = parser.parse_args()

if args.mode == 'train':
    train(datainfo, model, optimizer, loss, metrics, opt)
elif args.mode == 'test':
    test(datainfo, metrics, opt)
else:
    print('no mode specified for model!')
    parser.print_help()
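
Usage sketch (assuming the dataset files referenced in Config.datapath and, for testing, the checkpoint at Config.load_model_path are already in place): the script is driven entirely by the required --mode flag.

    python train_lstm.py --mode train   # train BiLSTMSentiment on the dataset selected by Config.task_name
    python train_lstm.py --mode test    # load the saved checkpoint and report accuracy on the test set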