
train_bert.py

import sys
sys.path.append('../../')
from reproduction.text_classification.data.IMDBLoader import IMDBLoader
from fastNLP.embeddings import BertEmbedding
from reproduction.text_classification.model.lstm import BiLSTMSentiment
from fastNLP import Trainer
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import cache_results
from fastNLP import Tester

# Cache the returned result; the next run will skip preprocessing automatically.
@cache_results('imdb.pkl')
def get_data():
    data_bundle = IMDBLoader().process('imdb/')
    return data_bundle

data_bundle = get_data()
print(data_bundle)

# Drop examples longer than BERT's 512-token limit. Since English words are
# split into word pieces, truncate at 400 to leave some margin.
data_bundle.datasets['train'].drop(lambda x: len(x['words']) > 400)
data_bundle.datasets['dev'].drop(lambda x: len(x['words']) > 400)
data_bundle.datasets['test'].drop(lambda x: len(x['words']) > 400)

bert_embed = BertEmbedding(data_bundle.vocabs['words'], requires_grad=False,
                           model_dir_or_name="en-base-uncased")
model = BiLSTMSentiment(bert_embed, len(data_bundle.vocabs['target']))

Trainer(data_bundle.datasets['train'], model, optimizer=None, loss=CrossEntropyLoss(), device=0,
        batch_size=10, dev_data=data_bundle.datasets['dev'], metrics=AccuracyMetric()).train()

# Evaluate on the test set.
Tester(data_bundle.datasets['test'], model, batch_size=32, metrics=AccuracyMetric()).test()
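
A brief usage note, as a sketch assuming fastNLP's standard cache_results behavior: the decorator also accepts call-time overrides such as _refresh, so preprocessing can be re-run without deleting imdb.pkl by hand.

# Force re-preprocessing despite the existing cache
# (assumes fastNLP's cache_results call-time _refresh override).
data_bundle = get_data(_refresh=True)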