You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_metrics.py 4.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import sys, os
  2. sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path
  3. from fastNLP.core import metrics
  4. from sklearn import metrics as skmetrics
  5. import unittest
  6. import numpy as np
  7. from numpy import random
  8. def generate_fake_label(low, high, size):
  9. return random.randint(low, high, size), random.randint(low, high, size)
  10. class TestMetrics(unittest.TestCase):
  11. delta = 1e-5
  12. # test for binary, multiclass, multilabel
  13. data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)]
  14. fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types]
  15. def test_accuracy_score(self):
  16. for y_true, y_pred in self.fake_data:
  17. for normalize in [True, False]:
  18. for sample_weight in [None, random.rand(y_true.shape[0])]:
  19. ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
  20. test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
  21. self.assertAlmostEqual(test, ans, delta=self.delta)
  22. def test_recall_score(self):
  23. for y_true, y_pred in self.fake_data:
  24. # print(y_true.shape)
  25. labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None
  26. ans = skmetrics.recall_score(y_true, y_pred, labels=labels, average=None)
  27. test = metrics.recall_score(y_true, y_pred, labels=labels, average=None)
  28. ans = list(ans)
  29. if not isinstance(test, list):
  30. test = list(test)
  31. for a, b in zip(test, ans):
  32. # print('{}, {}'.format(a, b))
  33. self.assertAlmostEqual(a, b, delta=self.delta)
  34. # test binary
  35. y_true, y_pred = generate_fake_label(0, 2, 1000)
  36. ans = skmetrics.recall_score(y_true, y_pred)
  37. test = metrics.recall_score(y_true, y_pred)
  38. self.assertAlmostEqual(ans, test, delta=self.delta)
  39. def test_precision_score(self):
  40. for y_true, y_pred in self.fake_data:
  41. # print(y_true.shape)
  42. labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None
  43. ans = skmetrics.precision_score(y_true, y_pred, labels=labels, average=None)
  44. test = metrics.precision_score(y_true, y_pred, labels=labels, average=None)
  45. ans, test = list(ans), list(test)
  46. for a, b in zip(test, ans):
  47. # print('{}, {}'.format(a, b))
  48. self.assertAlmostEqual(a, b, delta=self.delta)
  49. # test binary
  50. y_true, y_pred = generate_fake_label(0, 2, 1000)
  51. ans = skmetrics.precision_score(y_true, y_pred)
  52. test = metrics.precision_score(y_true, y_pred)
  53. self.assertAlmostEqual(ans, test, delta=self.delta)
  54. def test_precision_score(self):
  55. for y_true, y_pred in self.fake_data:
  56. # print(y_true.shape)
  57. labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None
  58. ans = skmetrics.precision_score(y_true, y_pred, labels=labels, average=None)
  59. test = metrics.precision_score(y_true, y_pred, labels=labels, average=None)
  60. ans, test = list(ans), list(test)
  61. for a, b in zip(test, ans):
  62. # print('{}, {}'.format(a, b))
  63. self.assertAlmostEqual(a, b, delta=self.delta)
  64. # test binary
  65. y_true, y_pred = generate_fake_label(0, 2, 1000)
  66. ans = skmetrics.precision_score(y_true, y_pred)
  67. test = metrics.precision_score(y_true, y_pred)
  68. self.assertAlmostEqual(ans, test, delta=self.delta)
  69. def test_f1_score(self):
  70. for y_true, y_pred in self.fake_data:
  71. # print(y_true.shape)
  72. labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None
  73. ans = skmetrics.f1_score(y_true, y_pred, labels=labels, average=None)
  74. test = metrics.f1_score(y_true, y_pred, labels=labels, average=None)
  75. ans, test = list(ans), list(test)
  76. for a, b in zip(test, ans):
  77. # print('{}, {}'.format(a, b))
  78. self.assertAlmostEqual(a, b, delta=self.delta)
  79. # test binary
  80. y_true, y_pred = generate_fake_label(0, 2, 1000)
  81. ans = skmetrics.f1_score(y_true, y_pred)
  82. test = metrics.f1_score(y_true, y_pred)
  83. self.assertAlmostEqual(ans, test, delta=self.delta)
  84. if __name__ == '__main__':
  85. unittest.main()

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等