You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tester.py 6.5 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import _pickle
  2. import numpy as np
  3. import torch
  4. from fastNLP.core.action import Action
  5. from fastNLP.core.action import RandomSampler, Batchifier
  6. from fastNLP.modules import utils
  class BaseTester(object):
      """Base class for testers that run a network over the pickled dev set.

      Subclasses must implement ``data_forward``, ``evaluate``, ``metrics``,
      ``show_matrices`` and ``make_batch``.
      """

      def __init__(self, test_args):
          """
          :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]".
              Keys read here: "validate_in_training", "save_output", "save_loss",
              "batch_size", "pickle_path", "use_cuda".
          """
          super(BaseTester, self).__init__()
          self.validate_in_training = test_args["validate_in_training"]
          self.save_dev_data = None  # lazily filled cache, see prepare_input
          self.save_output = test_args["save_output"]
          self.output = None
          self.save_loss = test_args["save_loss"]
          self.mean_loss = None
          self.batch_size = test_args["batch_size"]
          self.pickle_path = test_args["pickle_path"]
          self.iterator = None
          self.use_cuda = test_args["use_cuda"]
          self.model = None
          self.eval_history = []  # per-batch results of evaluate(); filled only when save_loss is set
          self.batch_output = []  # per-batch results of data_forward(); filled only when save_output is set

      def test(self, network):
          """Run ``network`` over the dev data and record per-batch results.

          :param network: the model to evaluate. It is moved with ``.cuda()``
              when CUDA is available and ``use_cuda`` is set.
          """
          if torch.cuda.is_available() and self.use_cuda:
              self.model = network.cuda()
          else:
              self.model = network

          # turn on the testing mode; clean up the history
          self.mode(self.model, test=True)
          self.eval_history.clear()
          self.batch_output.clear()

          dev_data = self.prepare_input(self.pickle_path)
          # NOTE(review): drop_last=True means trailing samples that do not fill
          # a whole batch are not evaluated.
          iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))

          # Evaluation needs no gradients. Disabling autograd avoids building
          # computation graphs that the saved histories would otherwise keep alive.
          with torch.no_grad():
              for batch_x, batch_y in self.make_batch(iterator, dev_data):
                  prediction = self.data_forward(self.model, batch_x)
                  eval_results = self.evaluate(prediction, batch_y)
                  if self.save_output:
                      self.batch_output.append(prediction)
                  if self.save_loss:
                      self.eval_history.append(eval_results)

      def prepare_input(self, data_path):
          """
          Save the dev data once it is loaded. Can return directly next time.
          :param data_path: str, prefix of the dev pickle; "data_dev.pkl" is appended directly
              (no separator is inserted, so it should end with a path separator).
          :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
          """
          if self.save_dev_data is None:
              # NOTE: unpickling can execute arbitrary code; only load trusted files.
              with open(data_path + "data_dev.pkl", "rb") as f:
                  self.save_dev_data = _pickle.load(f)
          return self.save_dev_data

      def mode(self, model, test):
          """Switch ``model`` into test (or training) mode via ``Action.mode``."""
          Action.mode(model, test)

      def data_forward(self, network, x):
          """Run one batch through ``network``; implemented by subclasses."""
          raise NotImplementedError

      def evaluate(self, predict, truth):
          """Compute per-batch evaluation results; implemented by subclasses."""
          raise NotImplementedError

      def metrics(self):
          """Aggregate ``eval_history`` into final metrics; implemented by subclasses.

          A plain method (not a property) because subclasses override it as a
          method and callers invoke it as ``self.metrics()``.
          """
          raise NotImplementedError

      def show_matrices(self):
          """
          This is called by Trainer to print evaluation on dev set.
          :return print_str: str
          """
          raise NotImplementedError

      def make_batch(self, iterator, data):
          """Yield (batch_x, batch_y) pairs from ``iterator``; implemented by subclasses."""
          raise NotImplementedError
  class SeqLabelTester(BaseTester):
      """
      Tester for sequence labeling.
      """

      def __init__(self, test_args):
          """
          :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
          """
          super(SeqLabelTester, self).__init__(test_args)
          self.max_len = None
          self.mask = None  # mask built by the latest data_forward call; reused by evaluate
          self.seq_len = None
          self.batch_result = None

      def data_forward(self, network, inputs):
          """Build the padding mask for a batch and run it through ``network``.

          :param inputs: tuple returned by make_batch; inputs[0] is the batch ``x``
              (at least 2-D: batch_size x max_len), inputs[1] the sequence lengths.
          :return: the raw network output for the batch.
          :raises RuntimeError: if ``inputs`` is not a tuple (lengths were not produced).
          """
          if not isinstance(inputs, tuple):
              raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
          # unpack the returned value from make_batch
          x, seq_len = inputs[0], inputs[1]
          batch_size, max_len = x.size(0), x.size(1)
          mask = utils.seq_mask(seq_len, max_len)
          mask = mask.byte().view(batch_size, max_len)
          if torch.cuda.is_available() and self.use_cuda:
              mask = mask.cuda()
          # kept on self so evaluate() can pass the same mask to the model's loss/prediction
          self.mask = mask
          self.seq_len = seq_len
          return network(x)

      def evaluate(self, predict, truth):
          """Compute the per-batch loss and token accuracy.

          :param predict: network output from data_forward.
          :param truth: gold labels; flattened for comparison.
          :return: [loss, accuracy] as detached tensors.
          """
          batch_size = predict.size(0)
          loss = self.model.loss(predict, truth, self.mask) / batch_size
          prediction = self.model.prediction(predict, self.mask)
          results = torch.Tensor(prediction).view(-1,)
          # make sure "results" is in the same device as "truth"
          results = results.to(truth)
          # Cast matches to float before dividing: dividing an integer tensor
          # by an int floor-divides in older PyTorch, truncating accuracy to 0.
          # NOTE(review): padded positions are not excluded here -- confirm
          # whether prediction/truth contain padding.
          correct = (results == truth.view((-1,))).float()
          accuracy = correct.sum() / results.shape[0]
          return [loss.data, accuracy.data]

      def metrics(self):
          """Average the recorded per-batch loss and accuracy.

          :return: (mean loss, mean accuracy)
          """
          # float() reads 0-d tensors on any device; np.mean cannot convert CUDA tensors.
          batch_loss = np.mean([float(x[0]) for x in self.eval_history])
          batch_accuracy = np.mean([float(x[1]) for x in self.eval_history])
          return batch_loss, batch_accuracy

      def show_matrices(self):
          """
          This is called by Trainer to print evaluation on dev set.
          :return print_str: str
          """
          loss, accuracy = self.metrics()
          return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)

      def make_batch(self, iterator, data):
          """Delegate to Action.make_batch, requesting sequence lengths (output_length=True)."""
          return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True)
  class ClassificationTester(BaseTester):
      """Tester for classification."""

      def __init__(self, test_args):
          """
          :param test_args: a dict-like object that has __getitem__ method, \
          can be accessed by "test_args["key_str"]"
          """
          # BaseTester.__init__ already sets pickle_path, save_dev_data, output,
          # mean_loss and iterator to the same values; no need to repeat them.
          super(ClassificationTester, self).__init__(test_args)

      def make_batch(self, iterator, data, max_len=None):
          """Delegate to Action.make_batch, optionally capping the sequence length."""
          return Action.make_batch(iterator, data, use_cuda=self.use_cuda, max_len=max_len)

      def data_forward(self, network, x):
          """Forward through network."""
          logits = network(x)
          return logits

      def evaluate(self, y_logit, y_true):
          """Return [class probabilities (softmax over the last dim), y_true] for one batch."""
          # detach so the stored history does not keep autograd graphs alive
          y_prob = torch.nn.functional.softmax(y_logit.detach(), dim=-1)
          return [y_prob, y_true]

      def metrics(self):
          """Compute accuracy over all recorded batches.

          :return: (y_true as numpy array, y_prob as numpy array, accuracy as float)
          :raises ValueError: if no batch results were recorded (e.g. save_loss is off).
          """
          if not self.eval_history:
              raise ValueError("[fastnlp] no evaluation results recorded; run test() with save_loss enabled first.")
          y_prob, y_true = zip(*self.eval_history)
          y_prob = torch.cat(y_prob, dim=0)
          y_pred = torch.argmax(y_prob, dim=-1)
          y_true = torch.cat(y_true, dim=0)
          acc = float(torch.sum(y_pred == y_true)) / len(y_true)
          return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

      def show_matrices(self):
          """
          This is called by Trainer to print evaluation on dev set.
          :return print_str: str
          """
          _, _, acc = self.metrics()
          return "dev accuracy={:.2f}".format(acc)

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。