You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

valid.py 808 B

123456789101112131415161718192021222324
  1. import torch
  2. from reproduction.coreference_resolution.model.config import Config
  3. from reproduction.coreference_resolution.model.metric import CRMetric
  4. from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
  5. from fastNLP import Tester
  6. import argparse
  7. if __name__=='__main__':
  8. parser = argparse.ArgumentParser()
  9. parser.add_argument('--path')
  10. args = parser.parse_args()
  11. cr_loader = CRLoader()
  12. config = Config()
  13. data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path,
  14. 'test': config.test_path})
  15. metirc = CRMetric()
  16. model = torch.load(args.path)
  17. tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
  18. tester.test()
  19. print('test over')