You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

02_linear_test.py 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import uctc.nn as nn
  2. import std_model as stdnn
  3. import numpy as np
  4. class LinearTestModel:
  5. def __init__(self, input_features, output_features):
  6. self.w1 = nn.Parameter([input_features, output_features])
  7. self.b1 = nn.Parameter([1, output_features])
  8. def forward(self, x):
  9. l1 = nn.Linear(x, self.w1)
  10. l2 = nn.AddBias(l1, self.b1)
  11. return l2
  12. def get_loss(self, x, y):
  13. return nn.SquareLoss(self.forward(x), y)
  14. def backward(self, x, y):
  15. loss = self.get_loss(x, y)
  16. g_w1, g_b1 = nn.gradients(loss, [self.w1, self.b1])
  17. return g_w1.data(), g_b1.data()
  18. class StdLinerTestModel:
  19. def __init__(self, input_features, output_features, tmodel: LinearTestModel):
  20. self.w1 = stdnn.Parameter(input_features, output_features)
  21. self.b1 = stdnn.Parameter(1, output_features)
  22. self.w1.data = np.array(tmodel.w1.data()).reshape(input_features, output_features)
  23. self.b1.data = np.array(tmodel.b1.data()).reshape(1, output_features)
  24. def forward(self, x):
  25. l1 = stdnn.Linear(x, self.w1)
  26. l2 = stdnn.AddBias(l1, self.b1)
  27. return l2
  28. def get_loss(self, x, y):
  29. return stdnn.SquareLoss(self.forward(x), y)
  30. def backward(self, x, y):
  31. loss = self.get_loss(x, y)
  32. g_w1, g_b1 = stdnn.gradients(loss, [self.w1, self.b1])
  33. return g_w1.data.flatten().tolist(), g_b1.data.flatten().tolist()
  34. input_features = 16
  35. output_features = 32
  36. batch_size = 4
  37. x = np.random.randn(batch_size, input_features).astype(np.float32)
  38. y = np.random.randn(batch_size, output_features).astype(np.float32)
  39. model = LinearTestModel(input_features, output_features)
  40. test_x = nn.Constant(x)
  41. predict_y = model.forward(test_x).data()
  42. test_y = nn.Constant(y)
  43. loss = model.get_loss(test_x, test_y).data()
  44. g_w1, g_b1 = model.backward(test_x, test_y)
  45. stdmodel = StdLinerTestModel(input_features, output_features, model)
  46. std_test_x = stdnn.Constant(x)
  47. std_predict_y = stdmodel.forward(std_test_x)
  48. std_test_y = stdnn.Constant(y)
  49. std_loss = stdmodel.get_loss(std_test_x, std_test_y)
  50. std_g_w1, std_g_b1 = stdmodel.backward(std_test_x, std_test_y)
  51. # check forward
  52. for x, y in zip(predict_y, std_predict_y.data.tolist()[0]):
  53. if (abs(x-y) > 1e-4):
  54. assert 0, "Forward data mismatch!"
  55. # check loss
  56. if abs(loss[0] - std_loss.data) > 1e-4:
  57. assert 0, "Loss mismatch!"
  58. # check backward
  59. for i, (x, y) in enumerate(zip(g_w1, std_g_w1)):
  60. if (abs(x-y) > 1e-4):
  61. assert 0, f"Gradient w1 mismatch at position {i}, g_w1 is {x} while std g_w1 is {y}"
  62. for i, (x, y) in enumerate(zip(g_b1, std_g_b1)):
  63. if (abs(x-y) > 1e-4):
  64. assert 0, f"Gradient b1 mismatch at position {i}, g_b1 is {x} while std g_b1 is {y}"
  65. print("Test passed")

计算机大作业