You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_task2.py 1.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from uctc.framework import basis
  2. import numpy as np
  3. import math
  4. binary_arguments = [
  5. (1.0, 2.0),
  6. (2.0, 1.0),
  7. (-1.0, 1.0),
  8. (2.0, -2.0),
  9. (1.0, 3.0),
  10. (3.0, -1.0),
  11. (3.0, 3.0),
  12. (-4.0, -5.0),
  13. (5.0, 4.0),
  14. (4.0, 4.0),
  15. (5.0, 5.0)
  16. ]
  17. singular_arguments = [
  18. 1.0, -3.2, 4.3, 5.5, -6.7, 4.8, 3.33, 2.22, 1.11
  19. ]
  20. def is_close(x, y):
  21. return abs(x - y) < 1e-5
  22. def sigmoid(x):
  23. if x >= 0:
  24. return 1 / (1 + math.exp(-x))
  25. else:
  26. return math.exp(x) / (1 + math.exp(x))
  27. def iterate_binary_arguments(func, std_func):
  28. for argument in binary_arguments:
  29. if not is_close(func(*argument), std_func(*argument)):
  30. print(f"\033[1;31mError: {func.__name__}({argument}) = {func(*argument)} != {std_func.__name__}({argument}) = {std_func(*argument)}\033[0m")
  31. exit(0)
  32. print(f"\033[1;34mPassed: {func.__name__} passed all tests\033[0m")
  33. return True
  34. def iterate_singular_arguments(func, std_func):
  35. for argument in singular_arguments:
  36. if not is_close(func(argument), std_func(argument)):
  37. print(f"\033[1;31mError: {func.__name__}({argument}) = {func(argument)} != {std_func.__name__}({argument}) = {std_func(argument)}\033[0m")
  38. exit(0)
  39. print(f"\033[1;34mPassed: {func.__name__} passed all tests\033[0m")
  40. return True
  41. # Test task 1
  42. iterate_binary_arguments(basis.is_close, lambda x, y: 1.0*int(is_close(x, y)))
  43. iterate_singular_arguments(basis.sigmoid, lambda x: sigmoid(x))
  44. iterate_singular_arguments(basis.relu, lambda x: x if x > 0.0 else 0.0)
  45. iterate_singular_arguments(basis.inv, lambda x: 1.0/x)
  46. iterate_binary_arguments(basis.inv_back, lambda x, d: -d/(x*x))
  47. iterate_binary_arguments(basis.relu_back, lambda x, d: d * 1.0 if x > 0.0 else 0.0)
  48. print(f"\033[1;32m[PASSED] Task 2 finished!\033[0m")

计算机大作业

Contributors (1)