
train_idcnn.py 5.1 kB

from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader
from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader
from fastNLP.core.callback import FitlogCallback, LRScheduler
from fastNLP import GradientClipCallback
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.optim import SGD, Adam
from fastNLP import Const
from fastNLP import RandomSampler, BucketSampler
from fastNLP import SpanFPreRecMetric
from fastNLP import Trainer, Tester
from fastNLP.core.metrics import MetricBase
from reproduction.seqence_labelling.ner.model.dilated_cnn import IDCNN
from fastNLP.core.utils import Option
from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding
from fastNLP.core.utils import cache_results
from fastNLP.core.vocabulary import VocabularyOption
import fitlog
import sys
import torch.cuda
import os

os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

encoding_type = 'bioes'
def get_path(path):
    return os.path.join(os.environ['HOME'], path)


# Training and model hyperparameters.
ops = Option(
    batch_size=128,
    num_epochs=100,
    lr=3e-4,
    repeats=3,
    num_layers=3,
    num_filters=400,
    use_crf=False,
    gradient_clip=5,
)
# Load OntoNotes, build vocabularies and the GloVe word embedding; the result is
# cached on disk by @cache_results so later runs skip preprocessing.
@cache_results('ontonotes-case-cache')
def load_data():
    print('loading data')
    data = OntoNoteNERDataLoader(encoding_type=encoding_type).process(
        paths=get_path('workdir/datasets/ontonotes-v4'),
        lower=False,
        word_vocab_opt=VocabularyOption(min_freq=0),
    )
    # data = Conll2003DataLoader(task='ner', encoding_type=encoding_type).process(
    #     paths=get_path('workdir/datasets/conll03'),
    #     lower=False, word_vocab_opt=VocabularyOption(min_freq=0)
    # )
    # char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
    #                               kernel_sizes=[3])
    print('loading embedding')
    word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
                                 model_dir_or_name='en-glove-840b-300',
                                 requires_grad=True)
    return data, [word_embed]
data, embeds = load_data()
print(data)
print(data.datasets['train'][0])
print(list(data.vocabs.keys()))

# for ds in data.datasets.values():
#     ds.rename_field('cap_words', 'chars')
#     ds.set_input('chars')

# Rescale the pretrained embedding weights to unit standard deviation.
word_embed = embeds[0]
word_embed.embedding.weight.data /= word_embed.embedding.weight.data.std()

# char_embed = CNNCharEmbedding(data.vocabs['cap_words'])
char_embed = None
# for ds in data.datasets:
#     ds.rename_field('')

print(data.vocabs[Const.TARGET].word2idx)
model = IDCNN(init_embed=word_embed,
              char_embed=char_embed,
              num_cls=len(data.vocabs[Const.TARGET]),
              repeats=ops.repeats,
              num_layers=ops.num_layers,
              num_filters=ops.num_filters,
              kernel_size=3,
              use_crf=ops.use_crf, use_projection=True,
              block_loss=True,
              input_dropout=0.5, hidden_dropout=0.2, inner_dropout=0.2)
print(model)
callbacks = [GradientClipCallback(clip_value=ops.gradient_clip, clip_type='value'), ]

metrics = []
metrics.append(
    SpanFPreRecMetric(
        tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type,
        pred=Const.OUTPUT, target=Const.TARGET, seq_len=Const.INPUT_LEN,
    )
)
# Simple metric that reports the average loss over the evaluation set.
class LossMetric(MetricBase):
    def __init__(self, loss=None):
        super(LossMetric, self).__init__()
        self._init_param_map(loss=loss)
        self.total_loss = 0.0
        self.steps = 0

    def evaluate(self, loss):
        self.total_loss += float(loss)
        self.steps += 1

    def get_metric(self, reset=True):
        result = {'loss': self.total_loss / (self.steps + 1e-12)}
        if reset:
            self.total_loss = 0.0
            self.steps = 0
        return result
metrics.append(
    LossMetric(loss=Const.LOSS)
)

optimizer = Adam(model.parameters(), lr=ops.lr, weight_decay=0)
scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch)))
callbacks.append(scheduler)
# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 15)))
# optimizer = SWATS(model.parameters(), verbose=True)
# optimizer = Adam(model.parameters(), lr=0.005)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer,
                  sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
                  device=device, dev_data=data.datasets['dev'], batch_size=ops.batch_size,
                  metrics=metrics,
                  check_code_level=-1,
                  callbacks=callbacks, num_workers=2, n_epochs=ops.num_epochs)
trainer.train()
torch.save(model, 'idcnn.pt')
tester = Tester(
    data=data.datasets['test'],
    model=model,
    metrics=metrics,
    batch_size=ops.batch_size,
    num_workers=2,
    device=device
)
tester.test()
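
The script saves the trained model with torch.save(model, 'idcnn.pt'). Below is a minimal sketch, not part of the original file, of how that checkpoint could be reloaded and evaluated again later. It assumes the script above is importable in the same environment, so load_data, encoding_type and the fastNLP Tester / SpanFPreRecMetric API used above are available.

import torch
from fastNLP import Tester, Const, SpanFPreRecMetric

data, _ = load_data()           # served from the @cache_results cache
model = torch.load('idcnn.pt')  # full-model checkpoint written by the script
metric = SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type,
                           pred=Const.OUTPUT, target=Const.TARGET, seq_len=Const.INPUT_LEN)
tester = Tester(data=data.datasets['test'], model=model, metrics=metric,
                batch_size=128,
                device='cuda:0' if torch.cuda.is_available() else 'cpu')
print(tester.test())            # span precision / recall / F1 on the test split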