You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

autodiff.h 6.3 kB

11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. #pragma once
  2. #include <vector>
  3. #include <memory>
  4. #include <cmath>
  5. #include <unordered_map>
  6. namespace autodiff {
  /// @brief Approximates the partial derivative of `func` with respect to
  ///        element `arg` of `vec` using the central-difference formula
  ///        (f(x + eps) - f(x - eps)) / (2 * eps).
  ///
  /// @param vec      Point at which to differentiate. It is only read; the
  ///                 perturbed points are built on private copies.
  /// @param func     Callable invoked as `func(std::vector<T>)`; its result must
  ///                 support subtraction and division by a double.
  /// @param arg      Index of the element to perturb.
  /// @param epsilon  Perturbation added to / subtracted from `vec[arg]`.
  /// @return The difference quotient; its type is whatever
  ///         `(func(...) - func(...)) / double` yields.
  /// @throws std::out_of_range if `arg >= vec.size()` (raised by `vector::at`).
  ///
  /// NOTE(review): for T = float the default epsilon is close to the spacing of
  /// representable values around 1..10, so the result is coarse; pass a larger
  /// epsilon when accuracy matters.
  template<typename T, typename F>
  auto central_difference(const std::vector<T>& vec, F func, std::size_t arg, float epsilon = 1e-6)
  {
    std::vector<T> vec_plus = vec;
    std::vector<T> vec_minus = vec;

    // Perturb coordinate `arg` in both directions. at() turns an out-of-range
    // index into an exception instead of undefined behaviour.
    vec_plus.at(arg) += epsilon;
    vec_minus.at(arg) -= epsilon;

    // Evaluate the function at the two perturbed points.
    const auto f_plus = func(vec_plus);
    const auto f_minus = func(vec_minus);

    // Central-difference quotient.
    return (f_plus - f_minus) / (2.0 * epsilon);
  }
  /// Common base of the scalar node types in this header. It only stores state;
  /// the forward/backward computations live in the derived classes.
  class ScalarFunction
  {
  public:
    // Value held by this node. Zero-initialised so a default-constructed node
    // never exposes an indeterminate float.
    float data = 0.0f;
    // Gradient slot. Nothing in this header writes it, so it must start from a
    // defined value rather than garbage.
    float grad = 0.0f;
    // The operation classes in this header set this to 1 or 2, matching their
    // number of operands; it stays 0 otherwise.
    int degree = 0;

  public:
    ScalarFunction() = default;
  }; // class ScalarFunction
  30. class ConstantScalar: public ScalarFunction {
  31. public:
  32. ConstantScalar(float data): ScalarFunction() {
  33. this->data = data;
  34. }
  35. }; // class ConstantScalar
  36. class Add: public ScalarFunction {
  37. public:
  38. std::shared_ptr<ScalarFunction> a;
  39. std::shared_ptr<ScalarFunction> b;
  40. public:
  41. // 思考这个构造函数的写法(或让LLM进行解释)
  42. Add(std::shared_ptr<ScalarFunction> a, std::shared_ptr<ScalarFunction> b): a(a), b(b) {
  43. this->data = a->data + b->data;
  44. this->degree = 2;
  45. }
  46. float forward() {
  47. return a->data + b->data;;
  48. }
  49. std::vector<float> backward(float d_input) {
  50. return {1.0f * d_input, 1.0f * d_input};
  51. }
  52. }; // class Add
  53. class Log: public ScalarFunction {
  54. public:
  55. std::shared_ptr<ScalarFunction> a;
  56. public:
  57. Log(std::shared_ptr<ScalarFunction> a): a(a) {
  58. this->data = this->forward();
  59. this->degree = 1;
  60. }
  61. float forward()
  62. {
  63. return logf(a->data);
  64. }
  65. std::vector<float> backward(float d_input)
  66. {
  67. return {(1.0f * d_input / a->data)};
  68. }
  69. }; // class Log
  70. class Mul: public ScalarFunction {
  71. public:
  72. std::shared_ptr<ScalarFunction> a;
  73. std::shared_ptr<ScalarFunction> b;
  74. public:
  75. Mul(std::shared_ptr<ScalarFunction> a, std::shared_ptr<ScalarFunction> b) : a(a), b(b) {
  76. this->data = this->forward();
  77. this->degree = 2;
  78. }
  79. float forward() {
  80. return a->data * b->data;
  81. }
  82. std::vector<float> backward(float d_input) {
  83. float grad_a = b->data * d_input; // a的梯度 = y * 上游梯度
  84. float grad_b = a->data * d_input; // b的梯度 = x * 上游梯度
  85. return {grad_a, grad_b};
  86. }
  87. }; // class Mul
  88. class Inv: public ScalarFunction {
  89. public:
  90. std::shared_ptr<ScalarFunction> a;
  91. public:
  92. Inv(std::shared_ptr<ScalarFunction> a): a(a) {
  93. this->data = this->forward();
  94. this->degree = 1;
  95. }
  96. float forward() {
  97. return 1.0f / a->data;
  98. }
  99. std::vector<float> backward(float d_input) {
  100. float x_squared = a->data * a->data; // x的平方
  101. return { -d_input / x_squared };
  102. }
  103. }; // class Inv
  104. class Sigmoid: public ScalarFunction {
  105. public:
  106. std::shared_ptr<ScalarFunction> a;
  107. public:
  108. Sigmoid(std::shared_ptr<ScalarFunction> a): a(a) {
  109. this->data = this->forward();
  110. this->degree = 1;
  111. }
  112. float forward() {
  113. if (this->a->data >= 0.0) {
  114. return 1.0 / (1.0 + expf(-this->a->data));
  115. }
  116. else {
  117. return expf(this->a->data) / (1.0 + expf(this->a->data));
  118. }
  119. }
  120. std::vector<float> backward(float d_input) {
  121. float sigmoid_val = this->data; // 直接使用前向计算好的Sigmoid值
  122. float grad = sigmoid_val * (1.0f - sigmoid_val) * d_input;
  123. return {grad};
  124. }
  125. }; // class Sigmoid
  126. // for testing
  127. bool test_central_difference() {
  128. std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  129. auto func = [](const std::vector<float>& x) -> float {
  130. return x[0] + x[1] + x[2] + x[3] + x[4];
  131. };
  132. auto grad = central_difference(x, func, 2);
  133. if (abs(grad-1.0f) > 0.05) {
  134. return false;
  135. }
  136. return true;
  137. }
  138. bool test_addscalar() {
  139. auto a = std::make_shared<ConstantScalar>(1.0f);
  140. auto b = std::make_shared<ConstantScalar>(2.0f);
  141. auto c = std::make_shared<Add>(a, b);
  142. if (c->data != 3.0f) {
  143. return false;
  144. }
  145. auto res = c->backward(2.0f);
  146. auto a_grad = res[0];
  147. auto b_grad = res[1];
  148. if (a_grad != 2.0f || b_grad != 2.0f) {
  149. return false;
  150. }
  151. return true;
  152. }
  153. bool test_mulscalar() {
  154. auto a = std::make_shared<ConstantScalar>(2.0f);
  155. auto b = std::make_shared<ConstantScalar>(3.0f);
  156. auto c = std::make_shared<Mul>(a, b);
  157. if (c->data != 6.0f) {
  158. return false;
  159. }
  160. auto res = c->backward(2.0f);
  161. auto a_grad = res[0];
  162. auto b_grad = res[1];
  163. if (a_grad != 6.0f || b_grad != 4.0f) {
  164. return false;
  165. }
  166. return true;
  167. }
  168. bool test_logscalar() {
  169. auto a = std::make_shared<ConstantScalar>(2.0f);
  170. auto b = std::make_shared<Log>(a);
  171. if (abs(b->data - logf(2.0f)) > 1e-4) {
  172. return false;
  173. }
  174. auto res = b->backward(2.0f);
  175. auto a_grad = res[0];
  176. if (abs(a_grad - 1.0f) > 1e-4) {
  177. return false;
  178. }
  179. return true;
  180. }
  181. bool test_invscalar() {
  182. auto a = std::make_shared<ConstantScalar>(2.0f);
  183. auto b = std::make_shared<Inv>(a);
  184. if (abs(b->data - 0.5f) > 1e-4) {
  185. return false;
  186. }
  187. auto res = b->backward(2.0f);
  188. auto a_grad = res[0];
  189. if (abs(a_grad + 0.5f) > 1e-4) {
  190. return false;
  191. }
  192. return true;
  193. }
  194. bool test_sigmoidscalar() {
  195. auto a = std::make_shared<ConstantScalar>(2.0f);
  196. auto b = std::make_shared<Sigmoid>(a);
  197. // TODO:麻烦自己写下测试用例,谢谢
  198. // 禁止直接return true,世界上最聪明的智能人工将会逐一检查这段代码
  199. float expected_data = 1.0f / (1.0f + expf(-2.0f));
  200. if (abs(b->data - expected_data) > 1e-4) {
  201. return false;
  202. }
  203. // 反向传播测试:手动传入上游梯度2.0f
  204. auto res = b->backward(2.0f);
  205. auto a_grad = res[0];
  206. // 计算理论梯度:dσ/dx = σ(x)·(1-σ(x)),再乘以2.0f
  207. float sigmoid_val = expected_data;
  208. float expected_grad = sigmoid_val * (1.0f - sigmoid_val) * 2.0f;
  209. if (abs(a_grad - expected_grad) > 1e-4) {
  210. return false;
  211. }
  212. return true;
  213. }
  214. }