
train_shift_relay.py 2.9 kB

import os
from fastNLP import cache_results
from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader
from reproduction.seqence_labelling.cws.model.model import ShiftRelayCWSModel
from fastNLP.io.embed_loader import EmbeddingOption
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP import Trainer
from torch.optim import Adam
from fastNLP import BucketSampler
from fastNLP import GradientClipCallback
from reproduction.seqence_labelling.cws.model.metric import RelayMetric
# Use fastNLP's automatic result caching; note that it can only cache results under 4 GB.
@cache_results(None)
def prepare_data():
    data = SigHanLoader(target_type='shift_relay').process(file_dir, char_embed_opt=char_embed_opt,
                                                           bigram_vocab_opt=bigram_vocab_opt,
                                                           bigram_embed_opt=bigram_embed_opt,
                                                           L=L)
    return data
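# Hyperparameters. L is presumably the maximum word length (span size) the shift-relay
# decoder considers when scoring segmentation candidates.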
#########hyper
L = 4
hidden_size = 200
num_layers = 1
drop_p = 0.2
lr = 0.02
#########hyper
device = 0
# !!!! Never put absolute paths here: they would expose your username on the server, which
# is risky. Always use relative paths; ideally place the data under your reproduction
# directory and add it to .gitignore.
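# The paths below are placeholders; min_freq=3 should drop bigrams seen fewer than
# 3 times from the vocabulary (they are mapped to the unknown token instead).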
file_dir = '/path/to/pku'
char_embed_path = '/path/to/1grams_t3_m50_corpus.txt'
bigram_embed_path = '/path/to/2grams_t3_m50_corpus.txt'
bigram_vocab_opt = VocabularyOption(min_freq=3)
char_embed_opt = EmbeddingOption(embed_filepath=char_embed_path)
bigram_embed_opt = EmbeddingOption(embed_filepath=bigram_embed_path)
data_name = os.path.basename(file_dir)
cache_fp = 'caches/{}.pkl'.format(data_name)
data = prepare_data(_cache_fp=cache_fp, _refresh=False)
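# Build the shift-relay CWS model on top of the character and bigram embeddings that
# prepare_data() loaded into the data bundle.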
model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'], bigram_embed=data.embeddings['bigrams'],
                           hidden_size=hidden_size, num_layers=num_layers,
                           L=L, num_bigram_per_char=1, drop_p=drop_p)
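# BucketSampler batches sentences of similar length together to reduce padding;
# GradientClipCallback with clip_type='value' clips each gradient element to [-5, 5].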
sampler = BucketSampler(batch_size=32)
optimizer = Adam(model.parameters(), lr=lr)
clipper = GradientClipCallback(clip_value=5, clip_type='value')
callbacks = [clipper]
# if pretrain:
#     fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until)
#     callbacks.append(fixer)
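# loss=None makes fastNLP take the loss from the model's forward output; update_every=5
# accumulates gradients over 5 batches per optimizer step (effective batch size 160);
# metric_key='f' selects the best model by dev-set F-score, and validate_every=-1
# validates once per epoch.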
trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None,
                  batch_size=32, sampler=sampler, update_every=5,
                  n_epochs=3, print_every=5,
                  dev_data=data.datasets['dev'], metrics=RelayMetric(), metric_key='f',
                  validate_every=-1, save_path=None,
                  prefetch=True, use_tqdm=True, device=device,
                  callbacks=callbacks,
                  check_code_level=0)
trainer.train()