You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 1.3 kB

1234567891011121314151617181920212223242526272829303132
  1. import unittest
  2. from fastNLP.core.metrics.utils import func_post_proc
  3. class Metric:
  4. def accumulate(self, x, y):
  5. return x, y
  6. def compute(self, x, y):
  7. return x, y
  8. class TestMetricUtil(unittest.TestCase):
  9. def test_func_post_proc(self):
  10. metric = Metric()
  11. metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='accumulate')
  12. self.assertDictEqual({'x': 1, 'y': 2}, metric.accumulate(x=1, y=2))
  13. func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='accumulate')
  14. self.assertDictEqual({'1': 1, '2': 2}, metric.accumulate(x=1, y=2))
  15. metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update')
  16. self.assertDictEqual({'x': 1, 'y': 2}, metric.update(x=1, y=2))
  17. func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='update')
  18. self.assertDictEqual({'1': 1, '2': 2}, metric.update(x=1, y=2))
  19. def test_check_accumulate_post_special_local_variable(self):
  20. metric = Metric()
  21. self.assertFalse(hasattr(metric, '__wrapped_by_fn__'))
  22. metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update')
  23. self.assertTrue(hasattr(metric, '__wrapped_by_fn__'))