You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_metrics.py 19 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. import unittest
  2. import numpy as np
  3. import torch
  4. from fastNLP.core.metrics import AccuracyMetric
  5. from fastNLP.core.metrics import BMESF1PreRecMetric
  6. from fastNLP.core.metrics import _pred_topk, _accuracy_topk
  7. class TestAccuracyMetric(unittest.TestCase):
  8. def test_AccuracyMetric1(self):
  9. # (1) only input, targets passed
  10. pred_dict = {"pred": torch.zeros(4, 3)}
  11. target_dict = {'target': torch.zeros(4)}
  12. metric = AccuracyMetric()
  13. metric(pred_dict=pred_dict, target_dict=target_dict)
  14. print(metric.get_metric())
  15. def test_AccuracyMetric2(self):
  16. # (2) with corrupted size
  17. try:
  18. pred_dict = {"pred": torch.zeros(4, 3, 2)}
  19. target_dict = {'target': torch.zeros(4)}
  20. metric = AccuracyMetric()
  21. metric(pred_dict=pred_dict, target_dict=target_dict, )
  22. print(metric.get_metric())
  23. except Exception as e:
  24. print(e)
  25. return
  26. print("No exception catches.")
  27. def test_AccuracyMetric3(self):
  28. # (3) the second batch is corrupted size
  29. try:
  30. metric = AccuracyMetric()
  31. pred_dict = {"pred": torch.zeros(4, 3, 2)}
  32. target_dict = {'target': torch.zeros(4, 3)}
  33. metric(pred_dict=pred_dict, target_dict=target_dict)
  34. pred_dict = {"pred": torch.zeros(4, 3, 2)}
  35. target_dict = {'target': torch.zeros(4)}
  36. metric(pred_dict=pred_dict, target_dict=target_dict)
  37. print(metric.get_metric())
  38. except Exception as e:
  39. print(e)
  40. return
  41. self.assertTrue(True, False), "No exception catches."
  42. def test_AccuaryMetric4(self):
  43. # (5) check reset
  44. metric = AccuracyMetric()
  45. pred_dict = {"pred": torch.randn(4, 3, 2)}
  46. target_dict = {'target': torch.ones(4, 3)}
  47. metric(pred_dict=pred_dict, target_dict=target_dict)
  48. ans = torch.argmax(pred_dict["pred"], dim=2).to(target_dict["target"]) == target_dict["target"]
  49. res = metric.get_metric()
  50. self.assertTrue(isinstance(res, dict))
  51. self.assertTrue("acc" in res)
  52. self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3)
  53. def test_AccuaryMetric5(self):
  54. # (5) check reset
  55. metric = AccuracyMetric()
  56. pred_dict = {"pred": torch.randn(4, 3, 2)}
  57. target_dict = {'target': torch.zeros(4, 3)}
  58. metric(pred_dict=pred_dict, target_dict=target_dict)
  59. res = metric.get_metric(reset=False)
  60. ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean()
  61. self.assertAlmostEqual(res["acc"], float(ans), places=4)
  62. def test_AccuaryMetric6(self):
  63. # (6) check numpy array is not acceptable
  64. try:
  65. metric = AccuracyMetric()
  66. pred_dict = {"pred": np.zeros((4, 3, 2))}
  67. target_dict = {'target': np.zeros((4, 3))}
  68. metric(pred_dict=pred_dict, target_dict=target_dict)
  69. except Exception as e:
  70. print(e)
  71. return
  72. self.assertTrue(True, False), "No exception catches."
  73. def test_AccuaryMetric7(self):
  74. # (7) check map, match
  75. metric = AccuracyMetric(pred='predictions', target='targets')
  76. pred_dict = {"predictions": torch.randn(4, 3, 2)}
  77. target_dict = {'targets': torch.zeros(4, 3)}
  78. metric(pred_dict=pred_dict, target_dict=target_dict)
  79. res = metric.get_metric()
  80. ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean()
  81. self.assertAlmostEqual(res["acc"], float(ans), places=4)
  82. def test_AccuaryMetric8(self):
  83. try:
  84. metric = AccuracyMetric(pred='predictions', target='targets')
  85. pred_dict = {"prediction": torch.zeros(4, 3, 2)}
  86. target_dict = {'targets': torch.zeros(4, 3)}
  87. metric(pred_dict=pred_dict, target_dict=target_dict, )
  88. self.assertDictEqual(metric.get_metric(), {'acc': 1})
  89. except Exception as e:
  90. print(e)
  91. return
  92. self.assertTrue(True, False), "No exception catches."
  93. def test_AccuaryMetric9(self):
  94. # (9) check map, include unused
  95. try:
  96. metric = AccuracyMetric(pred='prediction', target='targets')
  97. pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
  98. target_dict = {'targets': torch.zeros(4, 3)}
  99. metric(pred_dict=pred_dict, target_dict=target_dict)
  100. self.assertDictEqual(metric.get_metric(), {'acc': 1})
  101. except Exception as e:
  102. print(e)
  103. return
  104. self.assertTrue(True, False), "No exception catches."
  105. def test_AccuaryMetric10(self):
  106. # (10) check _fast_metric
  107. try:
  108. metric = AccuracyMetric()
  109. pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)}
  110. target_dict = {'targets': torch.zeros(4, 3)}
  111. metric(pred_dict=pred_dict, target_dict=target_dict)
  112. self.assertDictEqual(metric.get_metric(), {'acc': 1})
  113. except Exception as e:
  114. print(e)
  115. return
  116. self.assertTrue(True, False), "No exception catches."
  117. class SpanF1PreRecMetric(unittest.TestCase):
  118. def test_case1(self):
  119. from fastNLP.core.metrics import _bmes_tag_to_spans
  120. from fastNLP.core.metrics import _bio_tag_to_spans
  121. bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8']
  122. bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8']
  123. expect_bmes_res = set()
  124. expect_bmes_res.update([('8', (0, 1)), ('2', (1, 2)), ('0', (2, 3)), ('9', (3, 4)), ('6', (4, 5)),
  125. ('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))])
  126. expect_bio_res = set()
  127. expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)),
  128. ('6', (4, 5)), ('7', (6, 7))])
  129. self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst)))
  130. self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst)))
  131. # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试
  132. # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans
  133. # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans
  134. # for i in range(1000):
  135. # strs = list(map(str, np.random.randint(100, size=1000)))
  136. # bmes = list('bmes'.upper())
  137. # bmes_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bmes, size=len(strs)))]
  138. # bio = list('bio'.upper())
  139. # bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))]
  140. # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs)))
  141. # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs)))
  142. def test_case2(self):
  143. # 测试不带label的
  144. from fastNLP.core.metrics import _bmes_tag_to_spans
  145. from fastNLP.core.metrics import _bio_tag_to_spans
  146. bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E']
  147. bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O']
  148. expect_bmes_res = set()
  149. expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))])
  150. expect_bio_res = set()
  151. expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))])
  152. self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst)))
  153. self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst)))
  154. # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试
  155. # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans
  156. # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans
  157. # for i in range(1000):
  158. # bmes = list('bmes'.upper())
  159. # bmes_strs = np.random.choice(bmes, size=1000)
  160. # bio = list('bio'.upper())
  161. # bio_strs = np.random.choice(bio, size=100)
  162. # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs)))
  163. # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs)))
  164. def tese_case3(self):
  165. from fastNLP.core.vocabulary import Vocabulary
  166. from collections import Counter
  167. from fastNLP.core.metrics import SpanFPreRecMetric
  168. # 与allennlp测试能否正确计算f metric
  169. #
  170. def generate_allen_tags(encoding_type, number_labels=4):
  171. vocab = {}
  172. for i in range(number_labels):
  173. label = str(i)
  174. for tag in encoding_type:
  175. if tag == 'O':
  176. if tag not in vocab:
  177. vocab['O'] = len(vocab) + 1
  178. continue
  179. vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count
  180. return vocab
  181. number_labels = 4
  182. # bio tag
  183. fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
  184. fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels))
  185. fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
  186. bio_sequence = torch.FloatTensor(
  187. [[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011,
  188. 0.0470, 0.0971],
  189. [-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523,
  190. 0.7987, -0.3970],
  191. [0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898,
  192. 0.6880, 1.4348],
  193. [-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793,
  194. -1.6876, -0.8917],
  195. [-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824,
  196. 1.4217, 0.2622]],
  197. [[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136,
  198. 1.3592, -0.8973],
  199. [0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887,
  200. -0.4025, -0.3417],
  201. [-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698,
  202. 0.2861, -0.3966],
  203. [-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275,
  204. 0.0213, 1.4777],
  205. [-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566,
  206. 1.3024, 0.2001]]]
  207. )
  208. bio_target = torch.LongTensor([[5., 0., 3., 3., 3.],
  209. [5., 6., 8., 6., 0.]])
  210. fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target})
  211. expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775,
  212. 'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0,
  213. 'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845,
  214. 'f': 0.12499999999994846}
  215. self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric())
  216. #bmes tag
  217. bmes_sequence = torch.FloatTensor(
  218. [[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352,
  219. -0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332,
  220. -0.3505, -0.6002],
  221. [0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641,
  222. 1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002,
  223. 0.4439, 0.2514],
  224. [-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696,
  225. -1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948, -1.4753,
  226. 0.5802, -0.0516],
  227. [-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121,
  228. 4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390,
  229. -0.9242, -0.0333],
  230. [0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393,
  231. 0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809,
  232. -0.3779, -0.3195]],
  233. [[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753,
  234. 0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957,
  235. -0.1103, 0.4417],
  236. [-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049,
  237. -0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631,
  238. -0.6386, 0.2797],
  239. [0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947,
  240. 0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522,
  241. 1.1237, -1.5385],
  242. [0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977,
  243. -0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376,
  244. -0.0591, -0.2461],
  245. [-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097,
  246. 2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142,
  247. -0.7344, -1.2046]]]
  248. )
  249. bmes_target = torch.LongTensor([[ 9., 6., 1., 9., 15.],
  250. [ 6., 15., 6., 15., 5.]])
  251. fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None)
  252. fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels))
  253. fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes')
  254. fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target})
  255. expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001,
  256. 'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775,
  257. 'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314,
  258. 'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504,
  259. 'pre': 0.499999999999995, 'rec': 0.499999999999995}
  260. self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res)
  261. # 已经和allennlp做过验证,但由于不能依赖allennlp,所以注释了以下代码
  262. # from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary
  263. # from allennlp.training.metrics import SpanBasedF1Measure
  264. # allen_bio_vocab = allen_Vocabulary({"tags": generate_allen_tags('BIO', number_labels)},
  265. # non_padded_namespaces=['tags'])
  266. # allen_bio_metric = SpanBasedF1Measure(allen_bio_vocab, 'tags')
  267. # bio_sequence = torch.randn(size=(2, 20, 2 * number_labels + 1))
  268. # bio_target = torch.randint(2 * number_labels + 1, size=(2, 20))
  269. # allen_bio_metric(bio_sequence, bio_target, torch.ones(2, 20))
  270. # fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
  271. # fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels))
  272. # fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
  273. #
  274. # def convert_allen_res_to_fastnlp_res(metric_result):
  275. # allen_result = {}
  276. # key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"}
  277. # for key, value in metric_result.items():
  278. # if key in key_map:
  279. # key = key_map[key]
  280. # else:
  281. # label = key.split('-')[-1]
  282. # if key.startswith('f1'):
  283. # key = 'f-{}'.format(label)
  284. # else:
  285. # key = '{}-{}'.format(key[:3], label)
  286. # allen_result[key] = value
  287. # return allen_result
  288. #
  289. # # print(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()))
  290. # # print(fastnlp_bio_metric.get_metric())
  291. # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()),
  292. # fastnlp_bio_metric.get_metric())
  293. #
  294. # allen_bmes_vocab = allen_Vocabulary({"tags": generate_allen_tags('BMES', number_labels)})
  295. # allen_bmes_metric = SpanBasedF1Measure(allen_bmes_vocab, 'tags', label_encoding='BMES')
  296. # fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None)
  297. # fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels))
  298. # fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes')
  299. # bmes_sequence = torch.randn(size=(2, 20, 4 * number_labels))
  300. # bmes_target = torch.randint(4 * number_labels, size=(2, 20))
  301. # allen_bmes_metric(bmes_sequence, bmes_target, torch.ones(2, 20))
  302. # fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target})
  303. #
  304. # # print(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()))
  305. # # print(fastnlp_bmes_metric.get_metric())
  306. # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()),
  307. # fastnlp_bmes_metric.get_metric())
  308. class TestBMESF1PreRecMetric(unittest.TestCase):
  309. def test_case1(self):
  310. seq_lens = torch.LongTensor([4, 2])
  311. pred = torch.randn(2, 4, 4)
  312. target = torch.LongTensor([[0, 1, 2, 3],
  313. [3, 3, 0, 0]])
  314. pred_dict = {'pred': pred}
  315. target_dict = {'target': target, 'seq_lens': seq_lens}
  316. metric = BMESF1PreRecMetric()
  317. metric(pred_dict, target_dict)
  318. metric.get_metric()
  319. def test_case2(self):
  320. # 测试相同两个seqence,应该给出{f1: 1, precision:1, recall:1}
  321. seq_lens = torch.LongTensor([4, 2])
  322. target = torch.LongTensor([[0, 1, 2, 3],
  323. [3, 3, 0, 0]])
  324. pred_dict = {'pred': target}
  325. target_dict = {'target': target, 'seq_lens': seq_lens}
  326. metric = BMESF1PreRecMetric()
  327. metric(pred_dict, target_dict)
  328. self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0})
  329. class TestUsefulFunctions(unittest.TestCase):
  330. # 测试metrics.py中一些看上去挺有用的函数
  331. def test_case_1(self):
  332. # multi-class
  333. _ = _accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
  334. _ = _pred_topk(np.random.randint(0, 3, size=(10, 1)))
  335. # 跑通即可