You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

train_shift_relay.py 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
import sys

# Make the repository root importable (this script lives three levels deep).
sys.path.append('../../..')

from torch.optim import Adam

from fastNLP import BucketSampler
from fastNLP import EvaluateCallback
from fastNLP import GradientClipCallback
from fastNLP import Trainer
from fastNLP import cache_results
from fastNLP.embeddings import StaticEmbedding

from reproduction.sequence_labelling.cws.data.cws_shift_pipe import CWSShiftRelayPipe
from reproduction.sequence_labelling.cws.model.bilstm_shift_relay import ShiftRelayCWSModel
from reproduction.sequence_labelling.cws.model.metric import RelayMetric
  13. #########hyper
  14. L = 4
  15. hidden_size = 200
  16. num_layers = 1
  17. drop_p = 0.2
  18. lr = 0.008
  19. data_name = 'pku'
  20. #########hyper
  21. device = 0
  22. cache_fp = 'caches/{}.pkl'.format(data_name)
  23. @cache_results(_cache_fp=cache_fp, _refresh=True) # 将结果缓存到cache_fp中,这样下次运行就直接读取,而不需要再次运行
  24. def prepare_data():
  25. data_bundle = CWSShiftRelayPipe(dataset_name=data_name, L=L).process_from_file()
  26. # 预训练的character embedding和bigram embedding
  27. char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.5, word_dropout=0.01,
  28. model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt')
  29. bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.5, min_freq=3, word_dropout=0.01,
  30. model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt')
  31. return data_bundle, char_embed, bigram_embed
  32. data, char_embed, bigram_embed = prepare_data()
  33. model = ShiftRelayCWSModel(char_embed=char_embed, bigram_embed=bigram_embed,
  34. hidden_size=hidden_size, num_layers=num_layers, drop_p=drop_p, L=L)
  35. sampler = BucketSampler()
  36. optimizer = Adam(model.parameters(), lr=lr)
  37. clipper = GradientClipCallback(clip_value=5, clip_type='value') # 截断太大的梯度
  38. evaluator = EvaluateCallback(data.get_dataset('test')) # 额外测试在test集上的效果
  39. callbacks = [clipper, evaluator]
  40. trainer = Trainer(data.get_dataset('train'), model, optimizer=optimizer, loss=None, batch_size=128, sampler=sampler,
  41. update_every=1, n_epochs=10, print_every=5, dev_data=data.get_dataset('dev'), metrics=RelayMetric(),
  42. metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks,
  43. check_code_level=0, num_workers=1)
  44. trainer.train()