
main.py

import argparse
import collections
import logging
import math
import os
import pickle
import random
import sys
import time
from sys import maxsize

import fastNLP
import fastNLP.embeddings
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from fastNLP import BucketSampler, DataSetIter, SequentialSampler, logger
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

import models
import optm
import utils
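
# models, optm and utils are modules local to this repository (model construction,
# the Noam-style optimizer wrapper, and assorted data/evaluation helpers).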

NONE_TAG = "<NONE>"
START_TAG = "<sos>"
END_TAG = "<eos>"

DEFAULT_WORD_EMBEDDING_SIZE = 100
DEBUG_SCALE = 200
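# DEBUG_SCALE caps how many train/dev/test instances are kept when --debug is set (see below).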

# ===-----------------------------------------------------------------------===
# Argument parsing
# ===-----------------------------------------------------------------------===
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", required=True, dest="dataset", help="processed data dir")
parser.add_argument("--word-embeddings", dest="word_embeddings", help="File from which to read in pretrained embeds")
parser.add_argument("--bigram-embeddings", dest="bigram_embeddings", help="File from which to read in pretrained embeds")
parser.add_argument("--crf", dest="crf", action="store_true", help="crf")
# parser.add_argument("--devi", default="0", dest="devi", help="gpu")
parser.add_argument("--step", default=0, dest="step", type=int,
                    help="Initial optimizer step (for resuming the learning-rate schedule)")
parser.add_argument("--num-epochs", default=100, dest="num_epochs", type=int,
                    help="Number of full passes through training set")
parser.add_argument("--batch-size", default=128, dest="batch_size", type=int,
                    help="Minibatch size of training set")
parser.add_argument("--d_model", default=256, dest="d_model", type=int, help="d_model")
parser.add_argument("--d_ff", default=1024, dest="d_ff", type=int, help="d_ff")
parser.add_argument("--N", default=6, dest="N", type=int, help="N")
parser.add_argument("--h", default=4, dest="h", type=int, help="h")
parser.add_argument("--factor", default=2, dest="factor", type=float,
                    help="Scale factor for the Noam learning-rate schedule")
parser.add_argument("--dropout", default=0.2, dest="dropout", type=float,
                    help="Amount of dropout (not keep rate, but drop rate) to apply to embeddings part of graph")
parser.add_argument("--log-dir", default="result", dest="log_dir",
                    help="Directory where to write logs / serialized models")
parser.add_argument("--task-name", default=time.strftime("%Y-%m-%d-%H-%M-%S"), dest="task_name",
                    help="Name for this task, use a comprehensive one")
parser.add_argument("--no-model", dest="no_model", action="store_true", help="Don't serialize model")
parser.add_argument("--always-model", dest="always_model", action="store_true",
                    help="Always serialize model after every epoch")
parser.add_argument("--old-model", dest="old_model", help="Path to old model for incremental training")
parser.add_argument("--skip-dev", dest="skip_dev", action="store_true", help="Skip dev set, would save some time")
parser.add_argument("--freeze", dest="freeze", action="store_true", help="freeze pretrained embedding")
parser.add_argument("--only-task", dest="only_task", action="store_true", help="only train task embedding")
parser.add_argument("--subset", dest="subset", help="Only train and test on a subset of the whole dataset")
parser.add_argument("--seclude", dest="seclude", help="train and test except a subset")
parser.add_argument("--instances", default=None, dest="instances", type=int, help="num of instances of subset")
parser.add_argument("--seed", dest="python_seed", type=int, default=random.randrange(maxsize),
                    help="Random seed of Python and NumPy")
parser.add_argument("--debug", dest="debug", default=False, action="store_true", help="Debug mode")
parser.add_argument("--test", dest="test", action="store_true", help="Test mode")
parser.add_argument("--local_rank", type=int, default=None)
parser.add_argument("--init_method", type=str, default="env://")
# fmt: on
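
# parse_known_args tolerates extra command-line flags (e.g. ones injected by a
# distributed launcher); unknown arguments are only echoed below, never used.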
options, _ = parser.parse_known_args()
print("unknown args", _)

task_name = options.task_name
root_dir = "{}/{}".format(options.log_dir, task_name)
utils.make_sure_path_exists(root_dir)
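
# When launched with torch.distributed.launch, each process receives --local_rank;
# bind the process to its GPU and join the NCCL process group.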
if options.local_rank is not None:
    torch.cuda.set_device(options.local_rank)
    dist.init_process_group("nccl", init_method=options.init_method)


def init_logger():
    if not os.path.exists(root_dir):
        os.mkdir(root_dir)
    log_formatter = logging.Formatter("%(asctime)s - %(message)s")
    logger = logging.getLogger()
    file_handler = logging.FileHandler("{0}/info.log".format(root_dir), mode="w")
    file_handler.setFormatter(log_formatter)
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    logger.addHandler(console_handler)
    if options.local_rank is None or options.local_rank == 0:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.WARNING)
    return logger
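
# NOTE: init_logger is kept here for reference only; the fastNLP logger imported above
# is what the script actually uses (see the commented-out call below).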

# ===-----------------------------------------------------------------------===
# Set up logging
# ===-----------------------------------------------------------------------===
# logger = init_logger()
logger.add_file("{}/info.log".format(root_dir), "INFO")
logger.setLevel(logging.INFO if dist.get_rank() == 0 else logging.WARNING)

# ===-----------------------------------------------------------------------===
# Log some stuff about this run
# ===-----------------------------------------------------------------------===
logger.info(" ".join(sys.argv))
logger.info("")
logger.info(options)

if options.debug:
    logger.info("DEBUG MODE")
    options.num_epochs = 2
    options.batch_size = 20

random.seed(options.python_seed)
np.random.seed(options.python_seed % (2 ** 32 - 1))
torch.cuda.manual_seed_all(options.python_seed)
logger.info("Python random seed: {}".format(options.python_seed))

# ===-----------------------------------------------------------------------===
# Read in dataset
# ===-----------------------------------------------------------------------===
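# total_dataset.pkl is expected to hold the preprocessed fastNLP DataSets and the
# unigram/bigram/task/tag Vocabularies produced by the data-preparation step.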
dataset = pickle.load(open(options.dataset + "/total_dataset.pkl", "rb"))
train_set = dataset["train_set"]
test_set = dataset["test_set"]
uni_vocab = dataset["uni_vocab"]
bi_vocab = dataset["bi_vocab"]
task_vocab = dataset["task_vocab"]
tag_vocab = dataset["tag_vocab"]
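# Compatibility shim (assumption): some fastNLP versions keep the mapping in
# Vocabulary._word2idx, so mirror it onto word2idx before it is used below.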
for v in (bi_vocab, uni_vocab, tag_vocab, task_vocab):
    if hasattr(v, "_word2idx"):
        v.word2idx = v._word2idx
for ds in (train_set, test_set):
    ds.rename_field("ori_words", "words")
logger.info("{} {}".format(bi_vocab.to_word(0), tag_vocab.word2idx))
logger.info(task_vocab.word2idx)

if options.skip_dev:
    dev_set = test_set
else:
    train_set, dev_set = train_set.split(0.1)

logger.info("{} {} {}".format(len(train_set), len(dev_set), len(test_set)))

if options.debug:
    train_set = train_set[0:DEBUG_SCALE]
    dev_set = dev_set[0:DEBUG_SCALE]
    test_set = test_set[0:DEBUG_SCALE]

# ===-----------------------------------------------------------------------===
# Build model and trainer
# ===-----------------------------------------------------------------------===

# ===============================
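# Non-zero ranks wait here so that rank 0 loads and prepares the pretrained
# embeddings first (presumably to warm any on-disk cache); rank 0 releases them
# at the matching barrier below.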
if dist.get_rank() != 0:
    dist.barrier()

if options.word_embeddings is None:
    init_embedding = None
else:
    # logger.info("Load: {}".format(options.word_embeddings))
    # init_embedding = utils.embedding_load_with_cache(options.word_embeddings, options.cache_dir, uni_vocab, normalize=False)
    init_embedding = fastNLP.embeddings.StaticEmbedding(
        uni_vocab, options.word_embeddings, word_drop=0.01
    )

bigram_embedding = None
if options.bigram_embeddings:
    # logger.info("Load: {}".format(options.bigram_embeddings))
    # bigram_embedding = utils.embedding_load_with_cache(options.bigram_embeddings, options.cache_dir, bi_vocab, normalize=False)
    bigram_embedding = fastNLP.embeddings.StaticEmbedding(
        bi_vocab, options.bigram_embeddings
    )

if dist.get_rank() == 0:
    dist.barrier()
# ===============================

# select subset training
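# The first token of every sentence is a dataset tag of the form "<name>";
# --seclude drops that corpus, while --subset keeps only that corpus.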
if options.seclude is not None:
    setname = "<{}>".format(options.seclude)
    logger.info("seclude {}".format(setname))
    train_set.drop(lambda x: x["words"][0] == setname, inplace=True)
    test_set.drop(lambda x: x["words"][0] == setname, inplace=True)
    dev_set.drop(lambda x: x["words"][0] == setname, inplace=True)

if options.subset is not None:
    setname = "<{}>".format(options.subset)
    logger.info("select {}".format(setname))
    train_set.drop(lambda x: x["words"][0] != setname, inplace=True)
    test_set.drop(lambda x: x["words"][0] != setname, inplace=True)
    dev_set.drop(lambda x: x["words"][0] != setname, inplace=True)

# build model and optimizer
i2t = None
if options.crf:
    # i2t=utils.to_id_list(tag_vocab.word2idx)
    i2t = {}
    for x, y in tag_vocab.word2idx.items():
        i2t[y] = x
    logger.info(i2t)

freeze = True if options.freeze else False
model = models.make_CWS(
    d_model=options.d_model,
    N=options.N,
    h=options.h,
    d_ff=options.d_ff,
    dropout=options.dropout,
    word_embedding=init_embedding,
    bigram_embedding=bigram_embedding,
    tag_size=len(tag_vocab),
    task_size=len(task_vocab),
    crf=i2t,
    freeze=freeze,
)
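
# Move the model to its GPU; under a distributed launch, wrap it in
# DistributedDataParallel pinned to this process's device.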
device = "cpu"
if torch.cuda.device_count() > 0:
    if options.local_rank is not None:
        device = "cuda:{}".format(options.local_rank)
        # model=nn.DataParallel(model)
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[options.local_rank], output_device=options.local_rank
        )
    else:
        device = "cuda:0"
        model.to(device)

if options.only_task and options.old_model is not None:
    logger.info("fix para except task embedding")
    for name, para in model.named_parameters():
        if name.find("task_embed") == -1:
            para.requires_grad = False
        else:
            para.requires_grad = True
            logger.info(name)
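
# NoamOpt appears to implement the Transformer warmup schedule (warmup=4000 steps);
# Adam is created with lr=0 because the wrapper sets the learning rate at every step.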
optimizer = optm.NoamOpt(
    options.d_model,
    options.factor,
    4000,
    torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9),
)
optimizer._step = options.step

best_model_file_name = "{}/model.bin".format(root_dir)
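
# Single-process runs bucket batches by sequence length; distributed runs shard the
# training set across ranks instead.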
if options.local_rank is None:
    train_sampler = BucketSampler(
        batch_size=options.batch_size, seq_len_field_name="seq_len"
    )
else:
    train_sampler = DistributedSampler(
        train_set, dist.get_world_size(), dist.get_rank()
    )
dev_sampler = SequentialSampler()

i2t = utils.to_id_list(tag_vocab.word2idx)
i2task = utils.to_id_list(task_vocab.word2idx)
dev_set.set_input("words")
test_set.set_input("words")
test_batch = DataSetIter(test_set, options.batch_size, num_workers=2)

word_dic = pickle.load(open(options.dataset + "/oovdict.pkl", "rb"))


def batch_to_device(batch, device):
    for k, v in batch.items():
        if torch.is_tensor(v):
            batch[k] = v.to(device)
    return batch
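

# tester decodes a DataSetIter and reports per-corpus precision/recall/F1 via
# utils.CWSEvaluator; with write_out=True it also tracks an OOV statistic and
# collects the segmented sentences for writing to disk.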
def tester(model, test_batch, write_out=False):
    res = []
    prf = utils.CWSEvaluator(i2t)
    prf_dataset = {}
    oov_dataset = {}

    logger.info("start evaluation")
    # import ipdb; ipdb.set_trace()
    with torch.no_grad():
        for batch_x, batch_y in test_batch:
            batch_to_device(batch_x, device)
            # batch_to_device(batch_y, device)
            if bigram_embedding is not None:
                out = model(
                    batch_x["task"],
                    batch_x["uni"],
                    batch_x["seq_len"],
                    batch_x["bi1"],
                    batch_x["bi2"],
                )
            else:
                out = model(batch_x["task"], batch_x["uni"], batch_x["seq_len"])
            out = out["pred"]
            # print(out)
            num = out.size(0)
            out = out.detach().cpu().numpy()
            for i in range(num):
                length = int(batch_x["seq_len"][i])
                out_tags = out[i, 1:length].tolist()
                sentence = batch_x["words"][i]
                gold_tags = batch_y["tags"][i][1:length].numpy().tolist()

                dataset_name = sentence[0]
                sentence = sentence[1:]
                # print(out_tags,gold_tags)
                assert utils.is_dataset_tag(dataset_name), dataset_name
                assert len(gold_tags) == len(out_tags) and len(gold_tags) == len(
                    sentence
                )

                if dataset_name not in prf_dataset:
                    prf_dataset[dataset_name] = utils.CWSEvaluator(i2t)
                    oov_dataset[dataset_name] = utils.CWS_OOV(
                        word_dic[dataset_name[1:-1]]
                    )
                prf_dataset[dataset_name].add_instance(gold_tags, out_tags)
                prf.add_instance(gold_tags, out_tags)

                if write_out:
                    gold_strings = utils.to_tag_strings(i2t, gold_tags)
                    obs_strings = utils.to_tag_strings(i2t, out_tags)
                    word_list = utils.bmes_to_words(sentence, obs_strings)
                    oov_dataset[dataset_name].update(
                        utils.bmes_to_words(sentence, gold_strings), word_list
                    )
                    raw_string = " ".join(word_list)
                    res.append(dataset_name + " " + raw_string + " " + dataset_name)
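
    # Aggregate the per-corpus scores into macro averages across datasets.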
    Ap = 0.0
    Ar = 0.0
    Af = 0.0
    Aoov = 0.0
    tot = 0
    nw = 0.0
    for dataset_name, performance in sorted(prf_dataset.items()):
        p = performance.result()
        if write_out:
            nw = oov_dataset[dataset_name].oov()
            # nw = 0
            logger.info(
                "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                    dataset_name, p[0], p[1], p[2], nw
                )
            )
        else:
            logger.info(
                "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                    dataset_name, p[0], p[1], p[2]
                )
            )
        Ap += p[0]
        Ar += p[1]
        Af += p[2]
        Aoov += nw
        tot += 1

    prf = prf.result()
    logger.info(
        "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format("TOT", prf[0], prf[1], prf[2])
    )
    if not write_out:
        logger.info(
            "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                "AVG", Ap / tot, Ar / tot, Af / tot
            )
        )
    else:
        logger.info(
            "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                "AVG", Ap / tot, Ar / tot, Af / tot, Aoov / tot
            )
        )
    return prf[-1], res


# start training
if not options.test:
    if options.old_model:
        # incremental training
        logger.info("Incremental training from old model: {}".format(options.old_model))
        model.load_state_dict(torch.load(options.old_model, map_location="cuda:0"))

    logger.info("Number training instances: {}".format(len(train_set)))
    logger.info("Number dev instances: {}".format(len(dev_set)))

    train_batch = DataSetIter(dataset=train_set, batch_size=options.batch_size, sampler=train_sampler, num_workers=4)
    dev_batch = DataSetIter(dataset=dev_set, batch_size=options.batch_size, sampler=dev_sampler, num_workers=4)
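
    # Every rank trains each epoch; only rank 0 evaluates on dev and saves the best
    # checkpoint, and a barrier keeps the ranks in step at the end of each epoch.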
    best_f1 = 0.0
    for epoch in range(int(options.num_epochs)):
        logger.info("Epoch {} out of {}".format(epoch + 1, options.num_epochs))
        train_loss = 0.0
        model.train()
        tot = 0
        t1 = time.time()
        for batch_x, batch_y in train_batch:
            model.zero_grad()
            if bigram_embedding is not None:
                out = model(
                    batch_x["task"],
                    batch_x["uni"],
                    batch_x["seq_len"],
                    batch_x["bi1"],
                    batch_x["bi2"],
                    batch_y["tags"],
                )
            else:
                out = model(
                    batch_x["task"], batch_x["uni"], batch_x["seq_len"], batch_y["tags"]
                )
            loss = out["loss"]
            train_loss += loss.item()
            tot += 1
            loss.backward()
            # nn.utils.clip_grad_value_(model.parameters(), 1)
            optimizer.step()

        t2 = time.time()
        train_loss = train_loss / tot
        logger.info(
            "time: {} loss: {} step: {}".format(t2 - t1, train_loss, optimizer._step)
        )

        # Evaluate dev data
        if options.skip_dev and dist.get_rank() == 0:
            logger.info("Saving model to {}".format(best_model_file_name))
            torch.save(model.module.state_dict(), best_model_file_name)
            continue

        model.eval()
        if dist.get_rank() == 0:
            f1, _ = tester(model.module, dev_batch)
            if f1 > best_f1:
                best_f1 = f1
                logger.info("- new best score!")
                if not options.no_model:
                    logger.info("Saving model to {}".format(best_model_file_name))
                    torch.save(model.module.state_dict(), best_model_file_name)
            elif options.always_model:
                logger.info("Saving model to {}".format(best_model_file_name))
                torch.save(model.module.state_dict(), best_model_file_name)
        dist.barrier()

# Evaluate test data (once)
logger.info("\nNumber test instances: {}".format(len(test_set)))

if not options.skip_dev:
    if options.test:
        model.module.load_state_dict(
            torch.load(options.old_model, map_location="cuda:0")
        )
    else:
        model.module.load_state_dict(
            torch.load(best_model_file_name, map_location="cuda:0")
        )
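
# Rank 0 also dumps the learned task-embedding matrix for later inspection.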
if dist.get_rank() == 0:
    for name, para in model.named_parameters():
        if name.find("task_embed") != -1:
            tm = para.detach().cpu().numpy()
            logger.info(tm.shape)
            np.save("{}/task.npy".format(root_dir), tm)
            break

_, res = tester(model.module, test_batch, True)

if dist.get_rank() == 0:
    with open("{}/testout.txt".format(root_dir), "w", encoding="utf-8") as raw_writer:
        for sent in res:
            raw_writer.write(sent)
            raw_writer.write("\n")