
train.py
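Training script for the joint Chinese word segmentation and dependency parsing model (CharParser) reproduced under reproduction/joint_cws_parse in fastNLP. It loads a CTB-style dataset, builds pretrained character n-gram embeddings, and trains with gradient accumulation, learning-rate decay, and gradient clipping.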

import sys
sys.path.append('../..')
import os
from functools import partial

from torch import nn
from torch import optim
from torch.optim.lr_scheduler import StepLR

from fastNLP import BucketSampler, Trainer, Tester
from fastNLP import GradientClipCallback, LRScheduler
from fastNLP import cache_results
from fastNLP.embeddings.static_embedding import StaticEmbedding

from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader
from reproduction.joint_cws_parse.models.CharParser import CharParser
from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric
from reproduction.joint_cws_parse.models.callbacks import DevCallback
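
# Seed Python, NumPy, and PyTorch (CPU and current CUDA device) RNGs for reproducibility.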
def set_random_seed(random_seed=666):
    import random, numpy, torch
    random.seed(random_seed)
    numpy.random.seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.random.manual_seed(random_seed)

# Note: despite the name, this initializer draws from a normal distribution with std=0.02.
uniform_init = partial(nn.init.normal_, std=0.02)

###################################################
# Hyperparameters that are likely to need tuning
lr = 0.002               # 0.01~0.001
dropout = 0.33           # 0.3~0.6
weight_decay = 0         # 1e-5, 1e-6, 0
arc_mlp_size = 500       # 200, 300
rnn_hidden_size = 400    # 200, 300, 400
rnn_layers = 3           # 2, 3
encoder = 'var-lstm'     # var-lstm, lstm
emb_size = 100           # 64, 100
label_mlp_size = 100
batch_size = 32
update_every = 4
n_epochs = 100
data_name = 'new_ctb7'
####################################################
# Folder containing the data; it should hold three files: train, dev, test.
data_folder = f'/remote-home/hyan01/exps/JointCwsPosParser/data/{data_name}/output'
# Folder with the pretrained vectors; it should hold three files:
# 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
vector_folder = '/remote-home/hyan01/exps/CWS/pretrain/vectors'

set_random_seed(1234)
device = 0
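
# cache_results caches the return value of get_data() to caches/<data_name>.pkl,
# so subsequent runs skip the (slow) preprocessing step.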
@cache_results('caches/{}.pkl'.format(data_name))
def get_data():
    data = CTBxJointLoader().process(data_folder)

    char_labels_vocab = data.vocabs['char_labels']

    pre_chars_vocab = data.vocabs['pre_chars']
    pre_bigrams_vocab = data.vocabs['pre_bigrams']
    pre_trigrams_vocab = data.vocabs['pre_trigrams']

    chars_vocab = data.vocabs['chars']
    bigrams_vocab = data.vocabs['bigrams']
    trigrams_vocab = data.vocabs['trigrams']
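
    # Load pretrained 1/2/3-gram embeddings and rescale each weight matrix to unit standard deviation.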
    pre_chars_embed = StaticEmbedding(pre_chars_vocab,
                                      model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
                                      init_method=uniform_init, normalize=False)
    pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data / pre_chars_embed.embedding.weight.data.std()

    pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
                                        model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
                                        init_method=uniform_init, normalize=False)
    pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data / pre_bigrams_embed.embedding.weight.data.std()

    pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
                                         model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
                                         init_method=uniform_init, normalize=False)
    pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data / pre_trigrams_embed.embedding.weight.data.std()

    return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data

chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()
print(data)
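
# Character-level joint segmentation and parsing model. 'APP' appears to be the
# intra-word "append" label, i.e. the arc label linking a character to the previous
# character of the same word, which is how the model encodes segmentation.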
model = CharParser(char_vocab_size=len(chars_vocab),
                   emb_dim=emb_size,
                   bigram_vocab_size=len(bigrams_vocab),
                   trigram_vocab_size=len(trigrams_vocab),
                   num_label=len(char_labels_vocab),
                   rnn_layers=rnn_layers,
                   rnn_hidden_size=rnn_hidden_size,
                   arc_mlp_size=arc_mlp_size,
                   label_mlp_size=label_mlp_size,
                   dropout=dropout,
                   encoder=encoder,
                   use_greedy_infer=False,
                   app_index=char_labels_vocab['APP'],
                   pre_chars_embed=pre_chars_embed,
                   pre_bigrams_embed=pre_bigrams_embed,
                   pre_trigrams_embed=pre_trigrams_embed)
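
# Two evaluation metrics: a character-level parsing F1 metric and a word-segmentation
# (CWS) metric; both take the index of the APP label.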
metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP'])
metric2 = CWSMetric(char_labels_vocab['APP'])
metrics = [metric1, metric2]

# beta2 = 0.9 (rather than the Adam default of 0.999) follows the biaffine-parser training setup.
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr,
                       weight_decay=weight_decay, betas=(0.9, 0.9))
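
# BucketSampler batches sequences of similar length together to reduce padding.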
sampler = BucketSampler(seq_len_field_name='seq_lens')
callbacks = []
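
# fastNLP's Callback base class exposes self.step (the global batch counter) and
# self.update_every; stepping the scheduler only on every update_every-th backward pass
# keeps the learning-rate decay aligned with gradient accumulation.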
from fastNLP.core.callback import Callback
from torch.optim.lr_scheduler import LambdaLR

class SchedulerCallback(Callback):
    def __init__(self, scheduler):
        super().__init__()
        self.scheduler = scheduler

    def on_backward_end(self):
        if self.step % self.update_every == 0:
            self.scheduler.step()

# Decay the learning rate by a factor of 0.75 every 5000 optimizer updates.
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 0.75 ** (step // 5000))
# scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
scheduler_callback = SchedulerCallback(scheduler)
# scheduler_callback = LRScheduler(scheduler)

callbacks.append(scheduler_callback)
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))
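
# DevCallback is the reproduction's own callback; it presumably runs this Tester on the
# test set during training so test scores can be tracked alongside dev scores
# (verbose=0 silences the Tester's own output).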
tester = Tester(data=data.datasets['test'], model=model, metrics=metrics,
                batch_size=64, device=device, verbose=0)
dev_callback = DevCallback(tester)
callbacks.append(dev_callback)
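
# update_every=4 accumulates gradients over 4 batches (effective batch size 32 * 4 = 128);
# validate_every=-1 validates once per epoch, and metric_key='u_f1' picks the best model
# by the unlabeled F1 reported by SegAppCharParseF1Metric.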
trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs,
                  batch_size=batch_size, print_every=3,
                  validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer,
                  check_code_level=0, metric_key='u_f1', sampler=sampler, num_workers=2, use_tqdm=True,
                  device=device, callbacks=callbacks, update_every=update_every)
trainer.train()