You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_bert.py 1.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
"""
Use BERT for English named entity recognition.
"""
  4. import sys
  5. sys.path.append('../../../')
  6. from reproduction.sequence_labelling.ner.model.bert_crf import BertCRF
  7. from fastNLP.embeddings import BertEmbedding
  8. from fastNLP import Trainer, Const
  9. from fastNLP import BucketSampler, SpanFPreRecMetric, GradientClipCallback
  10. from fastNLP.core.callback import WarmupCallback
  11. from fastNLP.core.optimizer import AdamW
  12. from fastNLP.io import Conll2003NERPipe
  13. from fastNLP import cache_results, EvaluateCallback
  14. encoding_type = 'bioes'
  15. @cache_results('caches/conll2003.pkl', _refresh=False)
  16. def load_data():
  17. # 替换路径
  18. paths = 'data/conll2003'
  19. data = Conll2003NERPipe(encoding_type=encoding_type).process_from_file(paths)
  20. return data
  21. data = load_data()
  22. print(data)
  23. embed = BertEmbedding(data.get_vocab(Const.INPUT), model_dir_or_name='en-base-cased',
  24. pool_method='max', requires_grad=True, layers='11', include_cls_sep=False, dropout=0.5,
  25. word_dropout=0.01)
  26. callbacks = [
  27. GradientClipCallback(clip_type='norm', clip_value=1),
  28. WarmupCallback(warmup=0.1, schedule='linear'),
  29. EvaluateCallback(data.get_dataset('test'))
  30. ]
  31. model = BertCRF(embed, tag_vocab=data.get_vocab('target'), encoding_type=encoding_type)
  32. optimizer = AdamW(model.parameters(), lr=2e-5)
  33. trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
  34. device=0, dev_data=data.datasets['dev'], batch_size=6,
  35. metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
  36. loss=None, callbacks=callbacks, num_workers=2, n_epochs=5,
  37. check_code_level=0, update_every=3, test_use_tqdm=False)
  38. trainer.train()