You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

04_2layers_test.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import uctc.nn as nn
  2. import std_model as stdnn
  3. import numpy as np
  4. np.random.seed(42)
  5. class LinearTestModel:
  6. def __init__(self, input_features, hidden_features, output_features):
  7. self.w1 = nn.Parameter([input_features, hidden_features])
  8. self.b1 = nn.Parameter([1, hidden_features])
  9. self.w2 = nn.Parameter([hidden_features, output_features])
  10. self.b2 = nn.Parameter([1, output_features])
  11. def forward(self, x):
  12. l1 = nn.Linear(x, self.w1)
  13. l2 = nn.AddBias(l1, self.b1)
  14. l3 = nn.ReLU(l2)
  15. l4 = nn.Linear(l3, self.w2)
  16. l5 = nn.AddBias(l4, self.b2)
  17. return l5
  18. def get_loss(self, x, y):
  19. return nn.SquareLoss(self.forward(x), y)
  20. def backward(self, x, y):
  21. loss = self.get_loss(x, y)
  22. g_w1, g_b1, g_w2, g_b2 = nn.gradients(loss, [self.w1, self.b1, self.w2, self.b2])
  23. return g_w1.data(), g_b1.data(), g_w2.data(), g_b2.data()
  24. def update(self, x, y, lr):
  25. loss = self.get_loss(x, y)
  26. g_w1, g_b1, g_w2, g_b2 = nn.gradients(loss, [self.w1, self.b1, self.w2, self.b2])
  27. self.w1.update(g_w1, lr)
  28. self.b1.update(g_b1, lr)
  29. self.w2.update(g_w2, lr)
  30. self.b2.update(g_b2, lr)
  31. print(g_w1.data())
  32. print(g_b1.data())
  33. print(g_w2.data())
  34. print(g_b2.data())
  35. return self.w1.data(), self.b1.data(), self.w2.data(), self.b2.data()
  36. class StdLinerTestModel:
  37. def __init__(self, input_features, hidden_features, output_features, tmodel: LinearTestModel):
  38. self.w1 = stdnn.Parameter(input_features, hidden_features)
  39. self.b1 = stdnn.Parameter(1, hidden_features)
  40. self.w2 = stdnn.Parameter(hidden_features, output_features)
  41. self.b2 = stdnn.Parameter(1, output_features)
  42. self.w1.data = np.array(tmodel.w1.data()).reshape(input_features, hidden_features)
  43. self.b1.data = np.array(tmodel.b1.data()).reshape(1, hidden_features)
  44. self.w2.data = np.array(tmodel.w2.data()).reshape(hidden_features, output_features)
  45. self.b2.data = np.array(tmodel.b2.data()).reshape(1, output_features)
  46. def forward(self, x):
  47. l1 = stdnn.Linear(x, self.w1)
  48. l2 = stdnn.AddBias(l1, self.b1)
  49. l3 = stdnn.ReLU(l2)
  50. l4 = stdnn.Linear(l3, self.w2)
  51. l5 = stdnn.AddBias(l4, self.b2)
  52. return l5
  53. def get_loss(self, x, y):
  54. return stdnn.SquareLoss(self.forward(x), y)
  55. def backward(self, x, y):
  56. loss = self.get_loss(x, y)
  57. g_w1, g_b1, g_w2, g_b2 = stdnn.gradients(loss, [self.w1, self.b1, self.w2, self.b2])
  58. return g_w1.data.flatten().tolist(), g_b1.data.flatten().tolist(), g_w2.data.flatten().tolist(), g_b2.data.flatten().tolist()
  59. def update(self, x, y, lr):
  60. loss = self.get_loss(x, y)
  61. g_w1, g_b1, g_w2, g_b2 = stdnn.gradients(loss, [self.w1, self.b1, self.w2, self.b2])
  62. self.w1.update(g_w1, -lr)
  63. self.b1.update(g_b1, -lr)
  64. self.w2.update(g_w2, -lr)
  65. self.b2.update(g_b2, -lr)
  66. return self.w1.data.flatten().tolist(), self.b1.data.flatten().tolist(), self.w2.data.flatten().tolist(), self.b2.data.flatten().tolist()
  67. input_features = 1
  68. hidden_features = 50
  69. output_features = 1
  70. batch_size = 10
  71. x = np.array([-5.146528720855713, 4.451905250549316, 0.4736069440841675, -0.09472138434648514, 4.8939385414123535, 5.209676265716553, -5.967447280883789, 2.9363629817962646, -5.525413990020752, 3.315248489379883]).reshape(batch_size, -1)
  72. y = np.array([0.9072322249412537, -0.9662654995918274, 0.45609915256500244, -0.09457980841398239, -0.9835651516914368, -0.8788799047470093, 0.3105180263519287, 0.2037920206785202, 0.6873041391372681, -0.17278438806533813]).reshape(batch_size, -1)
  73. model = LinearTestModel(input_features, hidden_features, output_features)
  74. stdmodel = StdLinerTestModel(input_features, hidden_features, output_features, model)
  75. test_x = nn.Constant(x)
  76. predict_y = model.forward(test_x).data()
  77. test_y = nn.Constant(y)
  78. loss = model.get_loss(test_x, test_y).data()
  79. g_w1, g_b1, g_w2, g_b2 = model.backward(test_x, test_y)
  80. new_w1, new_b1, new_w2, new_b2 = model.update(test_x, test_y, 0)
  81. std_test_x = stdnn.Constant(x)
  82. std_predict_y = stdmodel.forward(std_test_x)
  83. std_test_y = stdnn.Constant(y)
  84. std_loss = stdmodel.get_loss(std_test_x, std_test_y)
  85. std_g_w1, std_g_b1, std_g_w2, std_g_b2 = stdmodel.backward(std_test_x, std_test_y)
  86. std_new_w1, std_new_b1, std_new_w2, std_new_b2 = stdmodel.update(std_test_x, std_test_y, 0)
  87. # print(predict_y)
  88. # print()
  89. # print(std_predict_y.data.flatten().tolist())
  90. # check forward
  91. for x, y in zip(predict_y, std_predict_y.data.flatten().tolist()):
  92. if (abs(x-y) > 1e-4):
  93. assert 0, "Forward data mismatch!"
  94. # print(loss, std_loss.data)
  95. # check loss
  96. if abs(loss[0] - std_loss.data) > 1e-4:
  97. assert 0, "Loss mismatch!"
  98. # check backward
  99. for i, (x, y) in enumerate(zip(g_w1, std_g_w1)):
  100. if (abs(x-y) > 1e-4):
  101. assert 0, f"Gradient w1 mismatch at position {i}, g_w1 is {x} while std g_w1 is {y}"
  102. for i, (x, y) in enumerate(zip(g_b1, std_g_b1)):
  103. if (abs(x-y) > 1e-4):
  104. assert 0, f"Gradient b1 mismatch at position {i}, g_b1 is {x} while std g_b1 is {y}"
  105. for i, (x, y) in enumerate(zip(g_w2, std_g_w2)):
  106. if (abs(x-y) > 1e-4):
  107. assert 0, f"Gradient w2 mismatch at position {i}, g_w2 is {x} while std g_w2 is {y}"
  108. for i, (x, y) in enumerate(zip(g_b2, std_g_b2)):
  109. if (abs(x-y) > 1e-4):
  110. assert 0, f"Gradient b2 mismatch at position {i}, g_b2 is {x} while std g_b2 is {y}"
  111. # check update
  112. for i, (x, y) in enumerate(zip(new_b1, std_new_b1)):
  113. if (abs(x-y) > 1e-4):
  114. assert 0, f"Updated b1 mismatch at position {i}, new_b1 is {x} while std new_b1 is {y}"
  115. for i, (x, y) in enumerate(zip(new_w1, std_new_w1)):
  116. if (abs(x-y) > 1e-4):
  117. assert 0, f"Updated w1 mismatch at position {i}, new_w1 is {x} while std new_w1 is {y}"
  118. # for i, (x, y) in enumerate(zip(new_b2, std_new_b2)):
  119. # if (abs(x-y) > 1e-4):
  120. # assert 0, f"Updated b2 mismatch at position {i}, new_b2 is {x} while std new_b2 is {y}"
  121. # for i, (x, y) in enumerate(zip(new_w2, std_new_w2)):
  122. # if (abs(x-y) > 1e-4):
  123. # assert 0, f"Updated w2 mismatch at position {i}, new_w2 is {x} while std new_w2 is {y}"
  124. print("Test passed")

计算机大作业