You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 668 B

12345678910111213141516171819202122232425
  1. from fastNLP.core.loss import Loss
  2. from fastNLP.core.preprocess import Preprocessor
  3. from fastNLP.core.trainer import Trainer
  4. from fastNLP.loader.dataset_loader import LMDataSetLoader
  5. from fastNLP.models.char_language_model import CharLM
  6. PICKLE = "./save/"
  def train():
      """Build a ``CharLM`` model and train it on the loaded corpus.

      Steps, in order:
        1. Load the raw data through ``LMDataSetLoader``.
        2. Run it through a ``Preprocessor`` (sequence labels, shared
           vocabulary), passing ``PICKLE`` as the ``pickle_path``.
        3. Construct ``CharLM`` using the vocabulary sizes reported by the
           preprocessor.
        4. Train with a ``Trainer`` configured for the ``"language_model"``
           task and a ``"cross_entropy"`` loss.

      Returns nothing; the work is done through the calls above.
      """
      raw_corpus = LMDataSetLoader().load()

      preprocessor = Preprocessor(label_is_seq=True, share_vocab=True)
      dataset = preprocessor.run(raw_corpus, pickle_path=PICKLE)

      # The two vocabulary sizes are read from the preprocessor after run()
      # has processed the data, same ordering as before.
      language_model = CharLM(
          50,
          50,
          preprocessor.vocab_size,
          preprocessor.char_vocab_size,
      )

      criterion = Loss("cross_entropy")
      lm_trainer = Trainer(task="language_model", loss=criterion)
      lm_trainer.train(language_model, dataset)
  if __name__ == "__main__":
      # Script entry point: start training only when this file is executed
      # directly, not when it is imported as a module.
      train()