You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

valid.py 743 B

6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324
  1. import torch
  2. from reproduction.coreference_resolution.model.config import Config
  3. from reproduction.coreference_resolution.model.metric import CRMetric
  4. from fastNLP.io.pipe.coreference import CoreferencePipe
  5. from fastNLP import Tester
  6. import argparse
  7. if __name__=='__main__':
  8. parser = argparse.ArgumentParser()
  9. parser.add_argument('--path')
  10. args = parser.parse_args()
  11. config = Config()
  12. bundle = CoreferencePipe(Config()).process_from_file(
  13. {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
  14. metirc = CRMetric()
  15. model = torch.load(args.path)
  16. tester = Tester(bundle.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
  17. tester.test()
  18. print('test over')