You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_tester.py 1.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import unittest
  2. import numpy as np
  3. from torch import nn
  4. import time
  5. from fastNLP import DataSet
  6. from fastNLP import Instance
  7. from fastNLP import AccuracyMetric
  8. from fastNLP import Tester
# Fixture file name and data directory. Neither name is referenced by any
# code visible in this module; presumably kept for other tests — confirm.
data_name = "pku_training.utf8"
pickle_path = "data_for_tests"
  11. def prepare_fake_dataset():
  12. mean = np.array([-3, -3])
  13. cov = np.array([[1, 0], [0, 1]])
  14. class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
  15. mean = np.array([3, 3])
  16. cov = np.array([[1, 0], [0, 1]])
  17. class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
  18. data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
  19. [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
  20. return data_set
  21. def prepare_fake_dataset2(*args, size=100):
  22. ys = np.random.randint(4, size=100, dtype=np.int64)
  23. data = {'y': ys}
  24. for arg in args:
  25. data[arg] = np.random.randn(size, 5)
  26. return DataSet(data=data)
  27. class TestTester(unittest.TestCase):
  28. def test_case_1(self):
  29. # 检查报错提示能否正确提醒用户
  30. dataset = prepare_fake_dataset2('x1', 'x_unused')
  31. dataset.rename_field('x_unused', 'x2')
  32. dataset.set_input('x1', 'x2')
  33. dataset.set_target('y', 'x1')
  34. class Model(nn.Module):
  35. def __init__(self):
  36. super().__init__()
  37. self.fc = nn.Linear(5, 4)
  38. def forward(self, x1, x2):
  39. x1 = self.fc(x1)
  40. x2 = self.fc(x2)
  41. x = x1 + x2
  42. time.sleep(0.1)
  43. # loss = F.cross_entropy(x, y)
  44. return {'preds': x}
  45. model = Model()
  46. with self.assertRaises(NameError):
  47. tester = Tester(
  48. data=dataset,
  49. model=model,
  50. metrics=AccuracyMetric())
  51. tester.test()