You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; can include dashes ('-'); and can be up to 35 characters long.

train_idcnn.py 4.2 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. from fastNLP.io import OntoNotesNERPipe
  2. from fastNLP.core.callback import LRScheduler
  3. from fastNLP import GradientClipCallback
  4. from torch.optim.lr_scheduler import LambdaLR
  5. from torch.optim import Adam
  6. from fastNLP import Const
  7. from fastNLP import BucketSampler
  8. from fastNLP import SpanFPreRecMetric
  9. from fastNLP import Trainer, Tester
  10. from fastNLP.core.metrics import MetricBase
  11. from reproduction.sequence_labelling.ner.model.dilated_cnn import IDCNN
  12. from fastNLP.core.utils import Option
  13. from fastNLP.embeddings import StaticEmbedding
  14. from fastNLP.core.utils import cache_results
  15. import torch.cuda
  16. import os
# Tagging scheme shared by the dataset pipe and the span metric below.
encoding_type = 'bioes'
  18. def get_path(path):
  19. return os.path.join(os.environ['HOME'], path)
# Training hyper-parameters bundled in a fastNLP ``Option`` for dotted access.
ops = Option(
    batch_size=128,    # samples per batch (also used by the bucket sampler)
    num_epochs=100,
    lr=3e-4,           # initial Adam learning rate
    repeats=3,         # IDCNN block repeats (see model construction below)
    num_filters=400,   # convolution channels
    num_layers=3,      # dilated conv layers per block
    use_crf=False,     # decode with softmax instead of a CRF layer
    gradient_clip=5,   # clip value for GradientClipCallback
)
  30. @cache_results('ontonotes-case-cache')
  31. def load_data():
  32. print('loading data')
  33. data = OntoNotesNERPipe(encoding_type=encoding_type).process_from_file(
  34. paths = get_path('workdir/datasets/ontonotes-v4'))
  35. print('loading embedding')
  36. word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
  37. model_dir_or_name='en-glove-840b-300',
  38. requires_grad=True)
  39. return data, [word_embed]
# Fetch the (possibly cached) data bundle and embeddings; print sanity checks.
data, embeds = load_data()
print(data)
print(data.datasets['train'][0])
print(list(data.vocabs.keys()))
# for ds in data.datasets.values():
# ds.rename_field('cap_words', 'chars')
# ds.set_input('chars')
word_embed = embeds[0]
# Rescale the pretrained embedding weights to unit standard deviation, in place.
word_embed.embedding.weight.data /= word_embed.embedding.weight.data.std()
# char_embed = CNNCharEmbedding(data.vocabs['cap_words'])
# Character-level embedding is disabled; IDCNN below receives char_embed=None.
char_embed = None
# for ds in data.datasets:
# ds.rename_field('')
print(data.vocabs[Const.TARGET].word2idx)
# Build the dilated-CNN tagger; the class count comes from the target vocab.
model = IDCNN(init_embed=word_embed,
              char_embed=char_embed,
              num_cls=len(data.vocabs[Const.TARGET]),
              repeats=ops.repeats,
              num_layers=ops.num_layers,
              num_filters=ops.num_filters,
              kernel_size=3,
              use_crf=ops.use_crf, use_projection=True,
              block_loss=True,
              input_dropout=0.5, hidden_dropout=0.2, inner_dropout=0.2)
print(model)
# Clip each gradient element to +/- ops.gradient_clip during training.
callbacks = [GradientClipCallback(clip_value=ops.gradient_clip, clip_type='value'),]
metrics = []
# Span-level precision/recall/F1 over BIOES-decoded predictions.
metrics.append(
    SpanFPreRecMetric(
        tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type,
        pred=Const.OUTPUT, target=Const.TARGET, seq_len=Const.INPUT_LEN,
    )
)
  73. class LossMetric(MetricBase):
  74. def __init__(self, loss=None):
  75. super(LossMetric, self).__init__()
  76. self._init_param_map(loss=loss)
  77. self.total_loss = 0.0
  78. self.steps = 0
  79. def evaluate(self, loss):
  80. self.total_loss += float(loss)
  81. self.steps += 1
  82. def get_metric(self, reset=True):
  83. result = {'loss': self.total_loss / (self.steps + 1e-12)}
  84. if reset:
  85. self.total_loss = 0.0
  86. self.steps = 0
  87. return result
# Also report the averaged evaluation loss alongside span F1.
metrics.append(
    LossMetric(loss=Const.LOSS)
)
optimizer = Adam(model.parameters(), lr=ops.lr, weight_decay=0)
# Inverse-time learning-rate decay: lr_epoch = lr / (1 + 0.05 * epoch).
scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch)))
callbacks.append(scheduler)
# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 15)))
# optimizer = SWATS(model.parameters(), verbose=True)
# optimizer = Adam(model.parameters(), lr=0.005)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE(review): check_code_level=-1 appears to disable fastNLP's pre-run
# code check — confirm against the fastNLP Trainer docs.
trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer,
                  sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
                  device=device, dev_data=data.datasets['dev'], batch_size=ops.batch_size,
                  metrics=metrics,
                  check_code_level=-1,
                  callbacks=callbacks, num_workers=2, n_epochs=ops.num_epochs)
trainer.train()
# Persist the entire model object (not just the state dict).
torch.save(model, 'idcnn.pt')
# Final evaluation on the held-out test split with the same metrics.
tester = Tester(
    data=data.datasets['test'],
    model=model,
    metrics=metrics,
    batch_size=ops.batch_size,
    num_workers=2,
    device=device
)
tester.test()