You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

train_ontonote.py 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import sys
  2. sys.path.append('../../..')
  3. from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding, StackEmbedding
  4. from reproduction.sequence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF
  5. from fastNLP import Trainer
  6. from fastNLP import SpanFPreRecMetric
  7. from fastNLP import Const
  8. from torch.optim import SGD
  9. from torch.optim.lr_scheduler import LambdaLR
  10. from fastNLP import GradientClipCallback
  11. from fastNLP import BucketSampler
  12. from fastNLP.core.callback import EvaluateCallback, LRScheduler
  13. from fastNLP import cache_results
  14. from fastNLP.io.pipe.conll import OntoNotesNERPipe
  15. #######hyper
  16. normalize = False
  17. lr = 0.01
  18. dropout = 0.5
  19. batch_size = 32
  20. data_name = 'ontonote'
  21. #######hyper
  22. encoding_type = 'bioes'
  23. @cache_results('caches/ontonotes.pkl', _refresh=True)
  24. def cache():
  25. data = OntoNotesNERPipe(encoding_type=encoding_type).process_from_file('../../../../others/data/v4/english')
  26. char_embed = CNNCharEmbedding(vocab=data.vocabs['words'], embed_size=30, char_emb_size=30, filter_nums=[30],
  27. kernel_sizes=[3], dropout=dropout)
  28. word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
  29. model_dir_or_name='en-glove-6b-100d',
  30. requires_grad=True,
  31. normalize=normalize,
  32. word_dropout=0.01,
  33. dropout=dropout,
  34. lower=True,
  35. min_freq=1)
  36. return data, char_embed, word_embed
  37. data, char_embed, word_embed = cache()
  38. print(data)
  39. embed = StackEmbedding([word_embed, char_embed])
  40. model = CNNBiLSTMCRF(embed, hidden_size=1200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
  41. encoding_type=encoding_type, dropout=dropout)
  42. callbacks = [
  43. GradientClipCallback(clip_value=5, clip_type='value'),
  44. EvaluateCallback(data.datasets['test'])
  45. ]
  46. optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
  47. scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch)))
  48. callbacks.append(scheduler)
  49. trainer = Trainer(train_data=data.get_dataset('train'), model=model, optimizer=optimizer, sampler=BucketSampler(num_buckets=100),
  50. device=0, dev_data=data.get_dataset('dev'), batch_size=batch_size,
  51. metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
  52. callbacks=callbacks, num_workers=1, n_epochs=100, dev_batch_size=256)
  53. trainer.train()