
train.py 14 kB

#!/usr/bin/python
# -*- coding: utf-8 -*-
# __author__="Danqing Wang"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train Model1: baseline model"""

import os
import sys
import json
import argparse
import datetime

import torch
import torch.nn

os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/')
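# NOTE: the FASTNLP_* environment variables and the sys.path entry above point at
# site-specific servers and directories; they will likely need to be adjusted or
# removed to run this script outside the original environment.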
from fastNLP.core.const import Const
from fastNLP.core.trainer import Trainer, Tester
from fastNLP.io.model_io import ModelLoader, ModelSaver
from fastNLP.io.embed_loader import EmbedLoader

from tools.logger import *
from data.dataloader import SummarizationLoader
# from model.TransformerModel import TransformerModel
from model.TForiginal import TransformerModel
from model.Metric import LabelFMetric, FastRougeMetric, PyRougeMetric
from model.Loss import MyCrossEntropyLoss
from tools.Callback import TrainCallback


def setup_training(model, train_loader, valid_loader, hps):
    """Does setup before starting training (run_training)"""

    train_dir = os.path.join(hps.save_root, "train")
    if not os.path.exists(train_dir): os.makedirs(train_dir)

    if hps.restore_model != 'None':
        logger.info("[INFO] Restoring %s for training...", hps.restore_model)
        bestmodel_file = os.path.join(train_dir, hps.restore_model)
        loader = ModelLoader()
        loader.load_pytorch(model, bestmodel_file)
    else:
        logger.info("[INFO] Create new model for training...")

    try:
        run_training(model, train_loader, valid_loader, hps)  # this is an infinite loop until interrupted
    except KeyboardInterrupt:
        logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
        save_file = os.path.join(train_dir, "earlystop.pkl")
        saver = ModelSaver(save_file)
        saver.save_pytorch(model)
        logger.info('[INFO] Saving early stop model to %s', save_file)


def run_training(model, train_loader, valid_loader, hps):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    logger.info("[INFO] Starting run_training")

    train_dir = os.path.join(hps.save_root, "train")
    if not os.path.exists(train_dir): os.makedirs(train_dir)
    eval_dir = os.path.join(hps.save_root, "eval")  # make a subdir of the root dir for eval data
    if not os.path.exists(eval_dir): os.makedirs(eval_dir)

    lr = hps.lr
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    criterion = MyCrossEntropyLoss(pred="p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none')
    # criterion = torch.nn.CrossEntropyLoss(reduce="none")
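    # MyCrossEntropyLoss is a project-specific loss; judging from the mask=Const.INPUT_LEN
    # argument it presumably zeroes out the loss on padded sentence positions before
    # reduction (assumption -- see model/Loss.py for the actual behaviour).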
    trainer = Trainer(model=model, train_data=train_loader, optimizer=optimizer, loss=criterion,
                      n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader,
                      metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")],
                      metric_key="f", validate_every=-1, save_path=eval_dir,
                      callbacks=[TrainCallback(hps, patience=5)], use_tqdm=False)

    train_info = trainer.train(load_best_model=True)
    logger.info(' | end of Train | time: {:5.2f}s | '.format(train_info["seconds"]))
    logger.info('[INFO] best eval model in epoch %d and iter %d', train_info["best_epoch"], train_info["best_step"])
    logger.info(train_info["best_eval"])

    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl')  # this is where checkpoints of best models are saved
    saver = ModelSaver(bestmodel_save_path)
    saver.save_pytorch(model)
    logger.info('[INFO] Saving eval best model to %s', bestmodel_save_path)


def run_test(model, loader, hps, limited=False):
    """Load the selected checkpoint and run evaluation on the given data, logging metric results."""
    test_dir = os.path.join(hps.save_root, "test")  # make a subdir of the root dir for test data
    eval_dir = os.path.join(hps.save_root, "eval")
    if not os.path.exists(test_dir):
        os.makedirs(test_dir)
    if not os.path.exists(eval_dir):
        logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir)
        raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir))

    if hps.test_model == "evalbestmodel":
        bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl')  # this is where checkpoints of best models are saved
    elif hps.test_model == "earlystop":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl')
    else:
        logger.error("No such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
        raise ValueError("No such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
    logger.info("[INFO] Restoring %s for testing... The path is %s", hps.test_model, bestmodel_load_path)

    modelloader = ModelLoader()
    modelloader.load_pytorch(model, bestmodel_load_path)
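    # PyRougeMetric presumably wraps the external ROUGE-1.5.5 / pyrouge toolkit (slower but
    # standard), while FastRougeMetric is a lightweight approximation for quick runs
    # (assumption based on the metric names -- see model/Metric.py).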
    if hps.use_pyrouge:
        logger.info("[INFO] Use PyRougeMetric for testing")
        tester = Tester(data=loader, model=model,
                        metrics=[LabelFMetric(pred="prediction"), PyRougeMetric(hps, pred="prediction")],
                        batch_size=hps.batch_size)
    else:
        logger.info("[INFO] Use FastRougeMetric for testing")
        tester = Tester(data=loader, model=model,
                        metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")],
                        batch_size=hps.batch_size)
    test_info = tester.test()
    logger.info(test_info)


def main():
    parser = argparse.ArgumentParser(description='Summarization Model')

    # Where to find data
    parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path to the training data file.')
    parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path to the validation data file.')
    parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path to the text vocabulary file.')

    # Important settings
    parser.add_argument('--mode', choices=['train', 'test'], default='train', help='must be one of train/test')
    parser.add_argument('--embedding', type=str, default='glove', choices=['word2vec', 'glove', 'elmo', 'bert'], help='must be one of word2vec/glove/elmo/bert')
    parser.add_argument('--sentence_encoder', type=str, default='transformer', choices=['bilstm', 'deeplstm', 'transformer'], help='must be one of bilstm/deeplstm/transformer')
    parser.add_argument('--sentence_decoder', type=str, default='SeqLab', choices=['PN', 'SeqLab'], help='must be one of PN/SeqLab')
    parser.add_argument('--restore_model', type=str, default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]')

    # Where to save output
    parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all models.')
    parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.')

    # Hyperparameters
    parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. For cpu, set -1 [default: 0]')
    parser.add_argument('--cuda', action='store_true', default=False, help='use cuda')
    parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.')
    parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]')
    parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 32]')
    parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding')
    parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path to external word embedding.')
    parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 300]')
    parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]')
    parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default: 1]')
    parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default: 7]')
    parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel')
    parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthonormal init for lstm [default: True]')
    parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)')
    parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)')
    parser.add_argument('--save_label', action='store_true', default=False, help='whether to save the predicted labels')

    # Training
    parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
    parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent')
    parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup steps')
    parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping')
    parser.add_argument('--max_grad_norm', type=float, default=10, help='max gradient norm for gradient clipping')

    # Test
    parser.add_argument('-m', type=int, default=3, help='decode summary length')
    parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length')
    parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]')
    parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use pyrouge for ROUGE evaluation')

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.set_printoptions(threshold=50000)

    # File paths
    DATA_FILE = args.data_path
    VALID_FILE = args.valid_path
    VOCAL_FILE = args.vocab_path
    LOG_PATH = args.log_root

    # train_log setting
    if not os.path.exists(LOG_PATH):
        if args.mode == "train":
            os.makedirs(LOG_PATH)
        else:
            logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH)
            raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH))
    nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    log_path = os.path.join(LOG_PATH, args.mode + "_" + nowTime)
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    logger.info("Pytorch %s", torch.__version__)

    sum_loader = SummarizationLoader()
    hps = args

    if hps.mode == 'test':
        paths = {"test": DATA_FILE}
        hps.recurrent_dropout_prob = 0.0
        hps.atten_dropout_prob = 0.0
        hps.ffn_dropout_prob = 0.0
        logger.info(hps)
    else:
        paths = {"train": DATA_FILE, "valid": VALID_FILE}

    dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE))

    if args.embedding == "glove":
        vocab = dataInfo.vocabs["vocab"]
        embed = torch.nn.Embedding(len(vocab), hps.word_emb_dim)
        if hps.word_embedding:
            embed_loader = EmbedLoader()
            pretrained_weight = embed_loader.load_with_vocab(hps.embedding_path, vocab)  # words not found are randomly initialized
            embed.weight.data.copy_(torch.from_numpy(pretrained_weight))
            embed.weight.requires_grad = hps.embed_train
    else:
        logger.error("[ERROR] embedding To Be Continued!")
        sys.exit(1)

    if args.sentence_encoder == "transformer" and args.sentence_decoder == "SeqLab":
        model_param = json.load(open("config/transformer.config", "rb"))
        hps.__dict__.update(model_param)
        model = TransformerModel(hps, embed)
    else:
        logger.error("[ERROR] Model To Be Continued!")
        sys.exit(1)
    logger.info(hps)

    if hps.cuda:
        model = model.cuda()
        logger.info("[INFO] Use cuda")

    if hps.mode == 'train':
        dataInfo.datasets["valid"].set_target("text", "summary")
        setup_training(model, dataInfo.datasets["train"], dataInfo.datasets["valid"], hps)
    elif hps.mode == 'test':
        logger.info("[INFO] Decoding...")
        dataInfo.datasets["test"].set_target("text", "summary")
        run_test(model, dataInfo.datasets["test"], hps, limited=hps.limited)
    else:
        logger.error("The 'mode' flag must be one of train/test")
        raise ValueError("The 'mode' flag must be one of train/test")


if __name__ == '__main__':
    main()
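
# Example invocations (a sketch; the file names below are placeholders and the default
# data/embedding paths above are site-specific, so they would normally be overridden):
#   python train.py --mode train --cuda --gpu 0 --data_path train.label.jsonl --valid_path val.label.jsonl --vocab_path vocab --embedding_path glove.42B.300d.txt --save_root save/ --log_root log/
#   python train.py --mode test --cuda --gpu 0 --data_path test.label.jsonl --vocab_path vocab --save_root save/ --log_root log/ --test_model evalbestmodel --use_pyrouge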