
test_trainer.py 518 B

from collections import namedtuple

import numpy as np

from action.trainer import Trainer
from model.base_model import ToyModel


def test_trainer():
    # Training config with the three fields this test passes to Trainer.
    Config = namedtuple("config", ["epochs", "validate", "save_when_better"])
    train_config = Config(epochs=5, validate=True, save_when_better=True)
    trainer = Trainer(train_config)

    net = ToyModel()
    # 20 x 6 arrays of random values serve as training and dev data.
    data = np.random.rand(20, 6)
    dev_data = np.random.rand(20, 6)
    trainer.train(net, data, dev_data)


if __name__ == "__main__":
    test_trainer()
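
The test above only shows how Trainer is called, not what it does internally. As a rough illustration of the kind of loop such a class could hide, here is a minimal sketch driven by the same three config fields. SketchTrainer, SketchModel, and every method on them are hypothetical names invented for this example; they are not the real action.trainer.Trainer or ToyModel, and the meaning assigned to validate and save_when_better is an assumption based on the field names alone.

```python
from collections import namedtuple

# Same three fields the test passes to Trainer.
Config = namedtuple("Config", ["epochs", "validate", "save_when_better"])


class SketchModel:
    """Hypothetical stand-in for ToyModel: fits one scalar to the data mean."""

    def __init__(self):
        self.weight = 0.0

    def loss(self, data):
        # Mean squared distance between the weight and every value.
        values = [x for row in data for x in row]
        return sum((x - self.weight) ** 2 for x in values) / len(values)

    def step(self, data, lr=0.5):
        # One gradient-descent step on the loss above.
        values = [x for row in data for x in row]
        grad = sum(2 * (self.weight - x) for x in values) / len(values)
        self.weight -= lr * grad


class SketchTrainer:
    """Hypothetical trainer; NOT the real action.trainer.Trainer."""

    def __init__(self, config):
        self.config = config
        self.best_loss = float("inf")
        self.best_weight = None  # in-memory stand-in for a saved checkpoint
        self.history = []        # dev loss recorded after each epoch

    def train(self, model, data, dev_data=None):
        for _ in range(self.config.epochs):
            model.step(data)
            # Assumed meaning of `validate`: evaluate on dev data each epoch.
            if self.config.validate and dev_data is not None:
                dev_loss = model.loss(dev_data)
                self.history.append(dev_loss)
                # Assumed meaning of `save_when_better`: keep the best model.
                if self.config.save_when_better and dev_loss < self.best_loss:
                    self.best_loss = dev_loss
                    self.best_weight = model.weight
        return model


# Usage mirroring the shape of the test, with small hand-written data.
config = Config(epochs=5, validate=True, save_when_better=True)
trainer = SketchTrainer(config)
data = [[0.1, 0.9], [0.4, 0.6]]
dev_data = [[0.2, 0.8], [0.5, 0.5]]
model = trainer.train(SketchModel(), data, dev_data)
```

Keeping the validation and checkpoint decisions inside train is what lets the calling code stay as short as the test: the caller supplies a config, a model, and data, and nothing else.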

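A lightweight natural language processing (NLP) toolkit that aims to reduce the engineering code in user projects, such as data-processing loops, training loops, and multi-GPU execution.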