You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train_bilstm_crf.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import sys
  2. sys.path.append('../../..')
  3. from fastNLP.io.pipe.cws import CWSPipe
  4. from reproduction.sequence_labelling.cws.model.bilstm_crf_cws import BiLSTMCRF
  5. from fastNLP import Trainer, cache_results
  6. from fastNLP.embeddings import StaticEmbedding
  7. from fastNLP import EvaluateCallback, BucketSampler, SpanFPreRecMetric, GradientClipCallback
  8. from torch.optim import Adagrad
  9. ###########hyper
  10. dataname = 'pku'
  11. hidden_size = 400
  12. num_layers = 1
  13. lr = 0.05
  14. ###########hyper
  15. @cache_results('{}.pkl'.format(dataname), _refresh=False)
  16. def get_data():
  17. data_bundle = CWSPipe(dataset_name=dataname, bigrams=True, trigrams=False).process_from_file()
  18. char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.33, word_dropout=0.01,
  19. model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt')
  20. bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.33,min_freq=3, word_dropout=0.01,
  21. model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt')
  22. return data_bundle, char_embed, bigram_embed
  23. data_bundle, char_embed, bigram_embed = get_data()
  24. print(data_bundle)
  25. model = BiLSTMCRF(char_embed, hidden_size, num_layers, target_vocab=data_bundle.get_vocab('target'), bigram_embed=bigram_embed,
  26. trigram_embed=None, dropout=0.3)
  27. model.cuda()
  28. callbacks = []
  29. callbacks.append(EvaluateCallback(data_bundle.get_dataset('test')))
  30. callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))
  31. optimizer = Adagrad(model.parameters(), lr=lr)
  32. metrics = []
  33. metric1 = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'), encoding_type='bmes')
  34. metrics.append(metric1)
  35. trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer=optimizer, loss=None,
  36. batch_size=128, sampler=BucketSampler(), update_every=1,
  37. num_workers=1, n_epochs=10, print_every=5,
  38. dev_data=data_bundle.get_dataset('dev'),
  39. metrics=metrics,
  40. metric_key=None,
  41. validate_every=-1, save_path=None, use_tqdm=True, device=0,
  42. callbacks=callbacks, check_code_level=0, dev_batch_size=128)
  43. trainer.train()