You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tester.py 5.3 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import torch
  2. from torch import nn
  3. from fastNLP.core.batch import Batch
  4. from fastNLP.core.dataset import DataSet
  5. from fastNLP.core.metrics import _prepare_metrics
  6. from fastNLP.core.sampler import SequentialSampler
  7. from fastNLP.core.utils import CheckError
  8. from fastNLP.core.utils import _build_args
  9. from fastNLP.core.utils import _check_loss_evaluate
  10. from fastNLP.core.utils import _move_dict_value_to_device
  11. from fastNLP.core.utils import get_func_signature
  class Tester(object):
      """A collection of model inference and evaluation of performance, used over validation/dev set and test set.

      :param DataSet data: a validation/development set
      :param torch.nn.Module model: a PyTorch model
      :param metrics: a metric object (MetricBase) or a list of metrics (List[MetricBase])
      :param int batch_size: batch size for validation
      :param bool use_cuda: whether to use CUDA in validation; silently ignored (the model stays where it is)
          when ``torch.cuda.is_available()`` is False.
      :param int verbose: if >= 1, the evaluation results are printed at the end of :meth:`test`.
      """

      def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=1):
          super(Tester, self).__init__()
          if not isinstance(data, DataSet):
              raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
          if not isinstance(model, nn.Module):
              raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")
          self.metrics = _prepare_metrics(metrics)
          self.data = data
          self.use_cuda = use_cuda
          self.batch_size = batch_size
          self.verbose = verbose
          if torch.cuda.is_available() and self.use_cuda:
              self._model = model.cuda()
          else:
              self._model = model
          # Batches are moved to the device of the model's first parameter. A model with no
          # parameters would make a bare next() raise StopIteration, so fall back to the CPU.
          first_param = next(self._model.parameters(), None)
          self._model_device = first_param.device if first_param is not None else torch.device('cpu')
          # Prefer a dedicated `predict` method for inference; otherwise use `forward`.
          if hasattr(self._model, 'predict'):
              self._predict_func = self._model.predict
              if not callable(self._predict_func):
                  _model_name = model.__class__.__name__
                  raise TypeError(f"`{_model_name}.predict` must be callable to be used "
                                  f"for evaluation, not `{type(self._predict_func)}`.")
          else:
              self._predict_func = self._model.forward

      def test(self):
          """Start test or validation.

          The model is put into eval mode for the duration of the run and restored to the
          mode it had before the call afterwards, even if an error is raised.

          :return eval_results: a dictionary whose keys are the class names of the metrics used and whose
              values are the dicts returned by each metric's ``get_metric()``.
          :raises TypeError: if the predict function or a metric's ``get_metric`` does not return a ``dict``.
          """
          network = self._model
          # Remember the caller's mode so it can be restored instead of blindly forcing train mode.
          was_training = network.training
          self._mode(network, is_test=True)
          data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
          eval_results = {}
          # Bound up front so the CheckError handler never hits an UnboundLocalError that would
          # mask the real problem.
          pred_dict, batch_y = None, None
          try:
              with torch.no_grad():
                  for batch_x, batch_y in data_iterator:
                      _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                      pred_dict = self._data_forward(self._predict_func, batch_x)
                      if not isinstance(pred_dict, dict):
                          raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
                                          f"must be `dict`, got {type(pred_dict)}.")
                      # Metrics accumulate statistics batch by batch.
                      for metric in self.metrics:
                          metric(pred_dict, batch_y)
                  for metric in self.metrics:
                      eval_result = metric.get_metric()
                      if not isinstance(eval_result, dict):
                          raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be "
                                          f"`dict`, got {type(eval_result)}")
                      metric_name = metric.__class__.__name__
                      eval_results[metric_name] = eval_result
          except CheckError as e:
              prev_func_signature = get_func_signature(self._predict_func)
              _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                   check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                   dataset=self.data, check_level=0)
          finally:
              # Always leave eval mode, even when an exception escapes.
              self._mode(network, is_test=not was_training)
          if self.verbose >= 1:
              print("[tester] \n{}".format(self._format_eval_results(eval_results)))
          return eval_results

      def _mode(self, model, is_test=False):
          """Train mode or Test mode. This is for PyTorch currently.

          :param model: a PyTorch model
          :param is_test: bool, if True call ``model.eval()``, otherwise ``model.train()``.
          """
          if is_test:
              model.eval()
          else:
              model.train()

      def _data_forward(self, func, x):
          """A forward pass of the model.

          :param func: the prediction callable (``predict`` or ``forward``).
          :param x: dict of batch inputs; ``_build_args`` selects the entries ``func`` accepts.
          :return: whatever ``func`` returns (``test`` requires it to be a dict).
          """
          return func(**_build_args(func, **x))

      def _format_eval_results(self, results):
          """Override this method to support more print formats.

          :param results: dict, (str: dict) is (metric name: {result name: value})
          :return: one line per metric, ``"<name>: k1=v1, k2=v2"``; an empty string for no results.
          """
          lines = []
          for metric_name, metric_result in results.items():
              pairs = ", ".join(f"{key!s}={value!s}" for key, value in metric_result.items())
              lines.append(metric_name + ': ' + pairs)
          return "\n".join(lines)