You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 6.6 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from functools import partial
  2. import pytest
  3. from fastNLP.core.utils.utils import auto_param_call, _check_valid_parameters_number, _get_fun_msg
  4. from fastNLP.core.metrics import Metric
  5. class TestAutoParamCall:
  6. def test_basic(self):
  7. def fn(x):
  8. return x
  9. x = {'x': 3, 'y': 4}
  10. r = auto_param_call(fn, x)
  11. assert r==3
  12. xs = []
  13. for i in range(10):
  14. xs.append({f'x{i}': i})
  15. def fn(x0, x1, x2, x3):
  16. return x0 + x1 + x2 + x3
  17. r = auto_param_call(fn, *xs)
  18. assert r == 0 + 1+ 2+ 3
  19. def fn(chongfu1, chongfu2, buChongFu):
  20. pass
  21. with pytest.raises(BaseException) as exc_info:
  22. auto_param_call(fn, {'chongfu1': 3, "chongfu2":4, 'buChongFu':2},
  23. {'chongfu1': 1, 'chongfu2':2, 'buChongFu':2})
  24. assert 'The following key present in several inputs' in exc_info.value.args[0]
  25. assert 'chongfu1' in exc_info.value.args[0] and 'chongfu2' in exc_info.value.args[0]
  26. # 没用到不报错
  27. def fn(chongfu1, buChongFu):
  28. pass
  29. auto_param_call(fn, {'chongfu1': 1, "chongfu2":4, 'buChongFu':2},
  30. {'chongfu1': 1, 'chongfu2':2, 'buChongFu':2})
  31. # 可以定制signature_fn
  32. def fn1(**kwargs):
  33. kwargs.pop('x')
  34. kwargs.pop('y')
  35. assert len(kwargs)==0
  36. def fn(x, y):
  37. pass
  38. x = {'x': 3, 'y': 4}
  39. r = auto_param_call(fn1, x, signature_fn=fn)
  40. # 没提供的时候报错
  41. def fn(meiti1, meiti2, tigong):
  42. pass
  43. with pytest.raises(BaseException) as exc_info:
  44. auto_param_call(fn, {'tigong':1})
  45. assert 'meiti1' in exc_info.value.args[0] and 'meiti2' in exc_info.value.args[0]
  46. # 默认值替换
  47. def fn(x, y=100):
  48. return x + y
  49. r = auto_param_call(fn, {'x': 10, 'y': 20})
  50. assert r==30
  51. assert auto_param_call(fn, {'x': 10, 'z': 20})==110
  52. # 测试mapping的使用
  53. def fn(x, y=100):
  54. return x + y
  55. r = auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping={'x1': 'x', 'y1': 'y', 'meiyong': 'meiyong'})
  56. assert r==30
  57. # 测试不需要任何参数
  58. def fn():
  59. return 1
  60. assert 1 == auto_param_call(fn, {'x':1})
  61. # 测试调用类的方法没问题
  62. assert 2==auto_param_call(self.call_this, {'x':1 ,'y':1})
  63. assert 2==auto_param_call(self.call_this, {'x':1,'y':1, 'z':1},mapping={'z': 'self'})
  64. def test_msg(self):
  65. with pytest.raises(BaseException) as exc_info:
  66. auto_param_call(self.call_this, {'x':1})
  67. assert 'TestAutoParamCall.call_this' in exc_info.value.args[0]
  68. with pytest.raises(BaseException) as exc_info:
  69. auto_param_call(call_this_for_auto_param_call, {'x':1})
  70. assert __file__ in exc_info.value.args[0]
  71. assert 'call_this_for_auto_param_call' in exc_info.value.args[0]
  72. with pytest.raises(BaseException) as exc_info:
  73. auto_param_call(self.call_this_two, {'x':1})
  74. assert __file__ in exc_info.value.args[0]
  75. with pytest.raises(BaseException) as exc_info:
  76. auto_param_call(call_this_for_auto_param_call, {'x':1}, signature_fn=self.call_this)
  77. assert 'TestAutoParamCall.call_this' in exc_info.value.args[0] # 应该是signature的信息
  78. def call_this(self, x, y):
  79. return x + y
  80. def call_this_two(self, x, y, z=pytest, **kwargs):
  81. return x + y
  82. def test_metric_auto_param_call(self):
  83. metric = AutoParamCallMetric()
  84. with pytest.raises(BaseException):
  85. auto_param_call(metric.update, {'y':1}, signature_fn=metric.update.__wrapped__)
  86. class AutoParamCallMetric(Metric):
  87. def update(self, x):
  88. pass
  89. def call_this_for_auto_param_call(x, y):
  90. return x + y
  91. class TestCheckNumberOfParameters:
  92. def test_validate_every(self):
  93. def validate_every(trainer):
  94. pass
  95. _check_valid_parameters_number(validate_every, expected_params=['trainer'])
  96. # 无默认值,多了报错
  97. def validate_every(trainer, other):
  98. pass
  99. with pytest.raises(RuntimeError) as exc_info:
  100. _check_valid_parameters_number(validate_every, expected_params=['trainer'])
  101. assert "2 parameters" in exc_info.value.args[0]
  102. print(exc_info.value.args[0])
  103. # 有默认值ok
  104. def validate_every(trainer, other=1):
  105. pass
  106. _check_valid_parameters_number(validate_every, expected_params=['trainer'])
  107. # 参数多了
  108. def validate_every(trainer):
  109. pass
  110. with pytest.raises(RuntimeError) as exc_info:
  111. _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other'])
  112. assert "accepts 1 parameters" in exc_info.value.args[0]
  113. print(exc_info.value.args[0])
  114. # 使用partial
  115. def validate_every(trainer, other):
  116. pass
  117. _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer'])
  118. _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other'])
  119. with pytest.raises(RuntimeError) as exc_info:
  120. _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more'])
  121. assert 'accepts 2 parameters' in exc_info.value.args[0]
  122. print(exc_info.value.args[0])
  123. # 如果存在 *args 或 *kwargs 不报错多的
  124. def validate_every(trainer, *args):
  125. pass
  126. _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other', 'more'])
  127. def validate_every(trainer, **kwargs):
  128. pass
  129. _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more'])
  130. # class 的方法删掉self
  131. class InnerClass:
  132. def demo(self, x):
  133. pass
  134. def no_param(self):
  135. pass
  136. def param_kwargs(self, **kwargs):
  137. pass
  138. inner = InnerClass()
  139. with pytest.raises(RuntimeError) as exc_info:
  140. _check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more'])
  141. assert 'accepts 1 parameters' in exc_info.value.args[0]
  142. _check_valid_parameters_number(inner.demo, expected_params=['trainer'])
  143. def test_get_fun_msg():
  144. # 测试运行
  145. def demo(x):
  146. pass
  147. print(_get_fun_msg(_get_fun_msg))