
train_origin.py

#!/usr/bin/python
# -*- coding: utf-8 -*-

# __author__="Danqing Wang"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Train Model1: baseline model"""

import os
import sys
import time
import copy
import pickle
import datetime
import argparse
import logging

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from rouge import Rouge

sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/')

from fastNLP.core.batch import DataSetIter
from fastNLP.core.const import Const
from fastNLP.io.model_io import ModelLoader, ModelSaver
from fastNLP.core.sampler import BucketSampler

from tools import utils
from tools.logger import *
from data.dataloader import SummarizationLoader
from model.TForiginal import TransformerModel
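
# Note: `logger` and `formatter` used throughout this script are presumably provided
# by the wildcard import from the project-local `tools.logger` module; they are not
# defined in this file.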

def setup_training(model, train_loader, valid_loader, hps):
    """Does setup before starting training (run_training)"""

    train_dir = os.path.join(hps.save_root, "train")
    if not os.path.exists(train_dir): os.makedirs(train_dir)

    if hps.restore_model != 'None':
        logger.info("[INFO] Restoring %s for training...", hps.restore_model)
        bestmodel_file = os.path.join(train_dir, hps.restore_model)
        loader = ModelLoader()
        loader.load_pytorch(model, bestmodel_file)
    else:
        logger.info("[INFO] Create new model for training...")

    try:
        run_training(model, train_loader, valid_loader, hps)  # this is an infinite loop until interrupted
    except KeyboardInterrupt:
        logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
        save_file = os.path.join(train_dir, "earlystop.pkl")
        saver = ModelSaver(save_file)
        saver.save_pytorch(model)
        logger.info('[INFO] Saving early stop model to %s', save_file)

def run_training(model, train_loader, valid_loader, hps):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    logger.info("[INFO] Starting run_training")

    train_dir = os.path.join(hps.save_root, "train")
    if not os.path.exists(train_dir): os.makedirs(train_dir)

    lr = hps.lr
    # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.98),
    #                              eps=1e-09)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    best_train_loss = None
    best_train_F = None
    best_loss = None
    best_F = None
    step_num = 0
    non_descent_cnt = 0

    for epoch in range(1, hps.n_epochs + 1):
        epoch_loss = 0.0
        train_loss = 0.0
        total_example_num = 0
        match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0
        epoch_start_time = time.time()
        for i, (batch_x, batch_y) in enumerate(train_loader):
            # if i > 10:
            #     break

            model.train()

            iter_start_time = time.time()

            input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN]
            label = batch_y[Const.TARGET]

            # logger.info(batch_x["text"][0])
            # logger.info(input[0,:,:])
            # logger.info(input_len[0:5,:])
            # logger.info(batch_y["summary"][0:5])
            # logger.info(label[0:5,:])
            # logger.info((len(batch_x["text"][0]), sum(input[0].sum(-1) != 0)))

            batch_size, N, seq_len = input.size()

            if hps.cuda:
                input = input.cuda()  # [batch, N, seq_len]
                label = label.cuda()
                input_len = input_len.cuda()

            input = Variable(input)
            label = Variable(label)
            input_len = Variable(input_len)

            model_outputs = model.forward(input, input_len)  # [batch, N, 2]
            outputs = model_outputs["p_sent"].view(-1, 2)

            label = label.view(-1)
            loss = criterion(outputs, label)  # [batch_size, doc_max_timesteps]
            # input_len = input_len.float().view(-1)
            loss = loss.view(batch_size, -1)
            loss = loss.masked_fill(input_len.eq(0), 0)
            loss = loss.sum(1).mean()
            logger.debug("loss %f", loss)
            # The original check `(np.isfinite(loss.data)).numpy()` breaks because np.isfinite
            # already returns a NumPy value; test the Python float instead.
            if not np.isfinite(loss.item()):
                logger.error("train Loss is not finite. Stopping.")
                logger.info(loss)
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        logger.info(name)
                        logger.info(param.grad.data.sum())
                raise Exception("train Loss is not finite. Stopping.")
            optimizer.zero_grad()
            loss.backward()
            if hps.grad_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), hps.max_grad_norm)
            optimizer.step()

            step_num += 1

            train_loss += float(loss.data)
            epoch_loss += float(loss.data)

            if i % 100 == 0:
                # start debugger
                # import pdb; pdb.set_trace()
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        logger.debug(name)
                        logger.debug(param.grad.data.sum())
                logger.info(' | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | '
                            .format(i, (time.time() - iter_start_time), float(train_loss / 100)))
                train_loss = 0.0

            # calculate the precision, recall and F
            prediction = outputs.max(1)[1]
            prediction = prediction.data
            label = label.data
            pred += prediction.sum()
            true += label.sum()
            match_true += ((prediction == label) & (prediction == 1)).sum()
            match += (prediction == label).sum()
            total_example_num += int(batch_size * N)

        if hps.lr_descent:
            # new_lr = pow(hps.hidden_size, -0.5) * min(pow(step_num, -0.5),
            #                                           step_num * pow(hps.warmup_steps, -1.5))
            new_lr = max(5e-6, lr / (epoch + 1))
            for param_group in list(optimizer.param_groups):
                param_group['lr'] = new_lr
            logger.info("[INFO] The learning rate now is %f", new_lr)

        epoch_avg_loss = epoch_loss / len(train_loader)
        logger.info(' | end of epoch {:3d} | time: {:5.2f}s | epoch train loss {:5.4f} | '
                    .format(epoch, (time.time() - epoch_start_time), float(epoch_avg_loss)))

        logger.info("[INFO] Trainset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match)
        accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match)
        logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", total_example_num / hps.doc_max_timesteps, accu, precision, recall, F)

        if not best_train_loss or epoch_avg_loss < best_train_loss:
            save_file = os.path.join(train_dir, "bestmodel.pkl")
            logger.info('[INFO] Found new best model with %.3f running_train_loss. Saving to %s', float(epoch_avg_loss), save_file)
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            best_train_loss = epoch_avg_loss
        elif epoch_avg_loss > best_train_loss:
            logger.error("[Error] training loss does not descent. Stopping supervisor...")
            save_file = os.path.join(train_dir, "earlystop.pkl")
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            logger.info('[INFO] Saving early stop model to %s', save_file)
            return

        if not best_train_F or F > best_train_F:
            save_file = os.path.join(train_dir, "bestFmodel.pkl")
            logger.info('[INFO] Found new best model with %.3f F score. Saving to %s', float(F), save_file)
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            best_train_F = F

        best_loss, best_F, non_descent_cnt = run_eval(model, valid_loader, hps, best_loss, best_F, non_descent_cnt)

        if non_descent_cnt >= 3:
            logger.error("[Error] val loss does not descent for three times. Stopping supervisor...")
            save_file = os.path.join(train_dir, "earlystop.pkl")
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            logger.info('[INFO] Saving early stop model to %s', save_file)
            return

def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    logger.info("[INFO] Starting eval for this model ...")
    eval_dir = os.path.join(hps.save_root, "eval")  # make a subdir of the root dir for eval data
    if not os.path.exists(eval_dir): os.makedirs(eval_dir)

    model.eval()

    running_loss = 0.0
    match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0
    pairs = {}
    pairs["hyps"] = []
    pairs["refer"] = []
    total_example_num = 0
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    iter_start_time = time.time()

    with torch.no_grad():
        for i, (batch_x, batch_y) in enumerate(loader):
            # if i > 10:
            #     break

            input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN]
            label = batch_y[Const.TARGET]

            if hps.cuda:
                input = input.cuda()  # [batch, N, seq_len]
                label = label.cuda()
                input_len = input_len.cuda()

            batch_size, N, _ = input.size()

            input = Variable(input, requires_grad=False)
            label = Variable(label)
            input_len = Variable(input_len, requires_grad=False)

            model_outputs = model.forward(input, input_len)  # [batch, N, 2]
            outputs = model_outputs["p_sent"]
            prediction = model_outputs["prediction"]

            outputs = outputs.view(-1, 2)  # [batch * N, 2]
            label = label.view(-1)  # [batch * N]
            loss = criterion(outputs, label)
            loss = loss.view(batch_size, -1)
            loss = loss.masked_fill(input_len.eq(0), 0)
            loss = loss.sum(1).mean()
            logger.debug("loss %f", loss)
            running_loss += float(loss.data)

            label = label.data.view(batch_size, -1)
            pred += prediction.sum()
            true += label.sum()
            match_true += ((prediction == label) & (prediction == 1)).sum()
            match += (prediction == label).sum()
            total_example_num += batch_size * N

            # rouge
            prediction = prediction.view(batch_size, -1)
            for j in range(batch_size):
                original_article_sents = batch_x["text"][j]
                sent_max_number = len(original_article_sents)
                refer = "\n".join(batch_x["summary"][j])
                hyps = "\n".join(original_article_sents[id] for id in range(len(prediction[j])) if prediction[j][id] == 1 and id < sent_max_number)
                if sent_max_number < hps.m and len(hyps) <= 1:
                    logger.error("sent_max_number is too short %d, Skip!", sent_max_number)
                    continue

                if len(hyps) >= 1 and hyps != '.':
                    # logger.debug(prediction[j])
                    pairs["hyps"].append(hyps)
                    pairs["refer"].append(refer)
                elif refer == "." or refer == "":
                    logger.error("Refer is None!")
                    logger.debug("label:")
                    logger.debug(label[j])
                    logger.debug(refer)
                elif hyps == "." or hyps == "":
                    logger.error("hyps is None!")
                    logger.debug("sent_max_number:%d", sent_max_number)
                    logger.debug("prediction:")
                    logger.debug(prediction[j])
                    logger.debug(hyps)
                else:
                    logger.error("Do not select any sentences!")
                    logger.debug("sent_max_number:%d", sent_max_number)
                    logger.debug(original_article_sents)
                    logger.debug("label:")
                    logger.debug(label[j])
                    continue
    running_avg_loss = running_loss / len(loader)

    if hps.use_pyrouge:
        logger.info("The number of pairs is %d", len(pairs["hyps"]))
        logging.getLogger('global').setLevel(logging.WARNING)
        if not len(pairs["hyps"]):
            logger.error("During testing, no hyps is selected!")
            # return the running bests so the caller can keep unpacking the tuple
            return best_loss, best_F, non_descent_cnt
        if isinstance(pairs["refer"][0], list):
            logger.info("Multi Reference summaries!")
            scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"])
        else:
            scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"])
    else:
        if len(pairs["hyps"]) == 0 or len(pairs["refer"]) == 0:
            logger.error("During testing, no hyps is selected!")
            # return the running bests so the caller can keep unpacking the tuple
            return best_loss, best_F, non_descent_cnt
        rouge = Rouge()
        scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)
        # try:
        #     scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)
        # except ValueError as e:
        #     logger.error(repr(e))
        #     scores_all = []
        #     for idx in range(len(pairs["hyps"])):
        #         try:
        #             scores = rouge.get_scores(pairs["hyps"][idx], pairs["refer"][idx])[0]
        #             scores_all.append(scores)
        #         except ValueError as e:
        #             logger.error(repr(e))
        #             logger.debug("HYPS:\t%s", pairs["hyps"][idx])
        #             logger.debug("REFER:\t%s", pairs["refer"][idx])
        # finally:
        #     logger.error("During testing, some errors happen!")
        #     logger.error(len(scores_all))
        #     exit(1)
    logger.info('[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | '
                .format((time.time() - iter_start_time), float(running_avg_loss)))

    logger.info("[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match)
    accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match)
    logger.info("[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f",
                total_example_num / hps.doc_max_timesteps, accu, precision, recall, F)

    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \
          + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \
          + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f'])
    logger.info(res)

    # If running_avg_loss is best so far, save this checkpoint (early stopping).
    # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
    if best_loss is None or running_avg_loss < best_loss:
        bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl')  # this is where checkpoints of best models are saved
        if best_loss is not None:
            logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s', float(running_avg_loss), float(best_loss), bestmodel_save_path)
        else:
            logger.info('[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s', float(running_avg_loss), bestmodel_save_path)
        saver = ModelSaver(bestmodel_save_path)
        saver.save_pytorch(model)
        best_loss = running_avg_loss
        non_descent_cnt = 0
    else:
        non_descent_cnt += 1

    if best_F is None or best_F < F:
        bestmodel_save_path = os.path.join(eval_dir, 'bestFmodel.pkl')  # this is where checkpoints of best models are saved
        if best_F is not None:
            logger.info('[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s', float(F), float(best_F), bestmodel_save_path)
        else:
            logger.info('[INFO] Found new best model with %.6f F. The original F is None, Saving to %s', float(F), bestmodel_save_path)
        saver = ModelSaver(bestmodel_save_path)
        saver.save_pytorch(model)
        best_F = F

    return best_loss, best_F, non_descent_cnt

def run_test(model, loader, hps, limited=False):
    """Repeatedly runs test iterations, logging to screen and writing the decoded summaries and ROUGE scores."""
    test_dir = os.path.join(hps.save_root, "test")  # make a subdir of the root dir for test data
    eval_dir = os.path.join(hps.save_root, "eval")
    if not os.path.exists(test_dir): os.makedirs(test_dir)
    if not os.path.exists(eval_dir):
        logger.exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir)
        raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % (eval_dir))

    if hps.test_model == "evalbestmodel":
        bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl')  # this is where checkpoints of best models are saved
    elif hps.test_model == "evalbestFmodel":
        bestmodel_load_path = os.path.join(eval_dir, 'bestFmodel.pkl')
    elif hps.test_model == "trainbestmodel":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'bestmodel.pkl')
    elif hps.test_model == "trainbestFmodel":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'bestFmodel.pkl')
    elif hps.test_model == "earlystop":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl')
    else:
        logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
        raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
    logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path)

    modelloader = ModelLoader()
    modelloader.load_pytorch(model, bestmodel_load_path)

    nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')  # current timestamp
    if hps.save_label:
        log_dir = os.path.join(test_dir, hps.data_path.split("/")[-1])
        resfile = open(log_dir, "w")
    else:
        log_dir = os.path.join(test_dir, nowTime)
        resfile = open(log_dir, "wb")
    logger.info("[INFO] Write the Evaluation into %s", log_dir)

    model.eval()

    match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0
    total_example_num = 0.0
    pairs = {}
    pairs["hyps"] = []
    pairs["refer"] = []
    pred_list = []
    iter_start_time = time.time()

    with torch.no_grad():
        for i, (batch_x, batch_y) in enumerate(loader):
            input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN]
            label = batch_y[Const.TARGET]

            if hps.cuda:
                input = input.cuda()  # [batch, N, seq_len]
                label = label.cuda()
                input_len = input_len.cuda()

            batch_size, N, _ = input.size()

            input = Variable(input)
            input_len = Variable(input_len, requires_grad=False)

            model_outputs = model.forward(input, input_len)  # [batch, N, 2]
            prediction = model_outputs["prediction"]

            if hps.save_label:
                pred_list.extend(model_outputs["pred_idx"].data.cpu().view(-1).tolist())
                continue

            pred += prediction.sum()
            true += label.sum()
            match_true += ((prediction == label) & (prediction == 1)).sum()
            match += (prediction == label).sum()
            total_example_num += batch_size * N

            for j in range(batch_size):
                original_article_sents = batch_x["text"][j]
                sent_max_number = len(original_article_sents)
                refer = "\n".join(batch_x["summary"][j])
                hyps = "\n".join(original_article_sents[id].replace("\n", "") for id in range(len(prediction[j])) if prediction[j][id] == 1 and id < sent_max_number)
                if limited:
                    k = len(refer.split())
                    hyps = " ".join(hyps.split()[:k])
                    logger.info((len(refer.split()), len(hyps.split())))

                resfile.write(b"Original_article:")
                resfile.write("\n".join(batch_x["text"][j]).encode('utf-8'))
                resfile.write(b"\n")
                resfile.write(b"Reference:")
                if isinstance(refer, list):
                    for ref in refer:
                        resfile.write(ref.encode('utf-8'))
                        resfile.write(b"\n")
                        resfile.write(b'*' * 40)
                        resfile.write(b"\n")
                else:
                    resfile.write(refer.encode('utf-8'))
                resfile.write(b"\n")
                resfile.write(b"hypothesis:")
                resfile.write(hyps.encode('utf-8'))
                resfile.write(b"\n")

                if hps.use_pyrouge:
                    pairs["hyps"].append(hyps)
                    pairs["refer"].append(refer)
                else:
                    try:
                        scores = utils.rouge_all(hyps, refer)
                        pairs["hyps"].append(hyps)
                        pairs["refer"].append(refer)
                    except ValueError:
                        logger.error("Do not select any sentences!")
                        logger.debug("sent_max_number:%d", sent_max_number)
                        logger.debug(original_article_sents)
                        logger.debug("label:")
                        logger.debug(label[j])
                        continue

                    # single example res writer
                    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']) \
                          + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']) \
                          + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'])
                    resfile.write(res.encode('utf-8'))

                resfile.write(b'-' * 89)
                resfile.write(b"\n")

    if hps.save_label:
        import json
        json.dump(pred_list, resfile)
        logger.info(' | end of test | time: {:5.2f}s | '.format((time.time() - iter_start_time)))
        return

    resfile.write(b"\n")
    resfile.write(b'=' * 89)
    resfile.write(b"\n")

    if hps.use_pyrouge:
        logger.info("The number of pairs is %d", len(pairs["hyps"]))
        if not len(pairs["hyps"]):
            logger.error("During testing, no hyps is selected!")
            return
        if isinstance(pairs["refer"][0], list):
            logger.info("Multi Reference summaries!")
            scores_all = utils.pyrouge_score_all_multi(pairs["hyps"], pairs["refer"])
        else:
            scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"])
    else:
        logger.info("The number of pairs is %d", len(pairs["hyps"]))
        if not len(pairs["hyps"]):
            logger.error("During testing, no hyps is selected!")
            return
        rouge = Rouge()
        scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)

    # the whole model res writer
    resfile.write(b"The total testset is:")
    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \
          + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \
          + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f'])
    resfile.write(res.encode("utf-8"))
    logger.info(res)

    logger.info(' | end of test | time: {:5.2f}s | '.format((time.time() - iter_start_time)))

    # label prediction
    logger.info("match_true %d, pred %d, true %d, total %d, match %d", match_true, pred, true, total_example_num, match)
    accu, precision, recall, F = utils.eval_label(match_true, pred, true, total_example_num, match)
    res = "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f" % (total_example_num / hps.doc_max_timesteps, accu, precision, recall, F)
    resfile.write(res.encode('utf-8'))
    logger.info("The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f", len(loader), accu, precision, recall, F)

def main():
    parser = argparse.ArgumentParser(description='Transformer Model')

    # Where to find data
    parser.add_argument('--data_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/train.label.jsonl', help='Path expression to pickle datafiles.')
    parser.add_argument('--valid_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/val.label.jsonl', help='Path expression to pickle valid datafiles.')
    parser.add_argument('--vocab_path', type=str, default='/remote-home/dqwang/Datasets/CNNDM/vocab', help='Path expression to text vocabulary file.')
    parser.add_argument('--embedding_path', type=str, default='/remote-home/dqwang/Glove/glove.42B.300d.txt', help='Path expression to external word embedding.')

    # Important settings
    parser.add_argument('--mode', type=str, default='train', help='must be one of train/test')
    parser.add_argument('--restore_model', type=str, default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]')
    parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose different model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]')
    parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use_pyrouge')

    # Where to save output
    parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all model.')
    parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.')

    # Hyperparameters
    parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. For cpu, set -1 [default: -1]')
    parser.add_argument('--cuda', action='store_true', default=False, help='use cuda')
    parser.add_argument('--vocab_size', type=int, default=100000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.')
    parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]')
    parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 128]')
    parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding')
    parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 200]')
    parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]')
    parser.add_argument('--min_kernel_size', type=int, default=1, help='kernel min length for CNN [default: 1]')
    parser.add_argument('--max_kernel_size', type=int, default=7, help='kernel max length for CNN [default: 7]')
    parser.add_argument('--output_channel', type=int, default=50, help='output channel: repeated times for one kernel')
    parser.add_argument('--n_layers', type=int, default=12, help='Number of deeplstm layers')
    parser.add_argument('--hidden_size', type=int, default=512, help='hidden size [default: 512]')
    parser.add_argument('--ffn_inner_hidden_size', type=int, default=2048, help='PositionwiseFeedForward inner hidden size [default: 2048]')
    parser.add_argument('--n_head', type=int, default=8, help='multihead attention number [default: 8]')
    parser.add_argument('--recurrent_dropout_prob', type=float, default=0.1, help='recurrent dropout prob [default: 0.1]')
    parser.add_argument('--atten_dropout_prob', type=float, default=0.1, help='attention dropout prob [default: 0.1]')
    parser.add_argument('--ffn_dropout_prob', type=float, default=0.1, help='PositionwiseFeedForward dropout prob [default: 0.1]')
    parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthnormal init for lstm [default: true]')
    parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)')
    parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)')
    parser.add_argument('--save_label', action='store_true', default=False, help='require multihead attention')

    # Training
    parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
    parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent')
    parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup_steps')
    parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='for gradient clipping max gradient normalization')

    parser.add_argument('-m', type=int, default=3, help='decode summary length')
    parser.add_argument('--limited', action='store_true', default=False, help='limited decode summary length')

    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.set_printoptions(threshold=50000)

    hps = args

    # File paths
    DATA_FILE = args.data_path
    VALID_FILE = args.valid_path
    VOCAL_FILE = args.vocab_path
    LOG_PATH = args.log_root

    # train_log setting
    if not os.path.exists(LOG_PATH):
        if hps.mode == "train":
            os.makedirs(LOG_PATH)
        else:
            logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH)
            raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH))
    nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    log_path = os.path.join(LOG_PATH, hps.mode + "_" + nowTime)
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    logger.info("Pytorch %s", torch.__version__)
    logger.info(args)

    sum_loader = SummarizationLoader()

    if hps.mode == 'test':
        paths = {"test": DATA_FILE}
        hps.recurrent_dropout_prob = 0.0
        hps.atten_dropout_prob = 0.0
        hps.ffn_dropout_prob = 0.0
        logger.info(hps)
    else:
        paths = {"train": DATA_FILE, "valid": VALID_FILE}

    dataInfo = sum_loader.process(paths=paths, vocab_size=hps.vocab_size, vocab_path=VOCAL_FILE, sent_max_len=hps.sent_max_len, doc_max_timesteps=hps.doc_max_timesteps, load_vocab=os.path.exists(VOCAL_FILE))

    vocab = dataInfo.vocabs["vocab"]
    model = TransformerModel(hps, vocab)

    if len(hps.gpu) > 1:
        gpuid = hps.gpu.split(',')
        gpuid = [int(s) for s in gpuid]
        model = nn.DataParallel(model, device_ids=gpuid)
        logger.info("[INFO] Use Multi-gpu: %s", hps.gpu)
    if hps.cuda:
        model = model.cuda()
        logger.info("[INFO] Use cuda")

    if hps.mode == 'train':
        trainset = dataInfo.datasets["train"]
        train_sampler = BucketSampler(batch_size=hps.batch_size, seq_len_field_name=Const.INPUT)
        train_batch = DataSetIter(dataset=trainset, batch_size=hps.batch_size, sampler=train_sampler)
        validset = dataInfo.datasets["valid"]
        validset.set_input("text", "summary")
        valid_batch = DataSetIter(dataset=validset, batch_size=hps.batch_size)
        setup_training(model, train_batch, valid_batch, hps)
    elif hps.mode == 'test':
        logger.info("[INFO] Decoding...")
        testset = dataInfo.datasets["test"]
        testset.set_input("text", "summary")
        test_batch = DataSetIter(dataset=testset, batch_size=hps.batch_size)
        run_test(model, test_batch, hps, limited=hps.limited)
    else:
        logger.error("The 'mode' flag must be one of train/eval/test")
        raise ValueError("The 'mode' flag must be one of train/eval/test")


if __name__ == '__main__':
    main()
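
# Example invocations (a sketch based on the argparse flags above; the CNNDM paths are the
# repo defaults and test.label.jsonl is a hypothetical path — point the flags at your own data):
#
#   python train_origin.py --mode train --cuda --gpu 0 \
#       --data_path /remote-home/dqwang/Datasets/CNNDM/train.label.jsonl \
#       --valid_path /remote-home/dqwang/Datasets/CNNDM/val.label.jsonl \
#       --vocab_path /remote-home/dqwang/Datasets/CNNDM/vocab \
#       --batch_size 32 --save_root save/ --log_root log/
#
#   python train_origin.py --mode test --cuda --gpu 0 --test_model evalbestmodel \
#       --data_path /remote-home/dqwang/Datasets/CNNDM/test.label.jsonl \
#       --vocab_path /remote-home/dqwang/Datasets/CNNDM/vocab \
#       --save_root save/ --log_root log/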