You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

autodiff.h 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. #pragma once
  2. #include <vector>
  3. #include <memory>
  4. #include <cmath>
  5. #include <unordered_map>
  6. namespace autodiff {
  7. template<typename T, typename F>
  8. auto central_difference(std::vector<T>& vec, F func, std::size_t arg, float epsilon = 1e-6) -> decltype(func(vec))
  9. {
  10. std::vector<T> vec1=vec;
  11. std::vector<T> vec2=vec;
  12. vec1[arg]+=epsilon;
  13. vec2[arg]-=epsilon;
  14. return (func(vec1)-func(vec2))/(2.0*epsilon);
  15. }
  /// Common base of every scalar node in this header.
  ///
  /// Holds the node's value, a gradient slot and an operand count. All members
  /// are value-initialized so that a freshly constructed node never exposes
  /// indeterminate floats.
  class ScalarFunction {
  public:
      float data = 0.0f;  ///< Value carried by this node.
      float grad = 0.0f;  ///< Gradient slot; nothing in this class writes to it.
      int degree = 0;     ///< Operand count; 0 unless a derived class overrides it.
  public:
      ScalarFunction() = default;
      // Derived nodes are handled through std::shared_ptr<ScalarFunction>;
      // a virtual destructor keeps destruction through a base pointer well-defined.
      virtual ~ScalarFunction() = default;
  }; // class ScalarFunction
  24. class ConstantScalar: public ScalarFunction {
  25. public:
  26. ConstantScalar(float data): ScalarFunction() {
  27. this->data = data;
  28. }
  29. }; // class ConstantScalar
  30. class Add: public ScalarFunction {
  31. public:
  32. std::shared_ptr<ScalarFunction> a;
  33. std::shared_ptr<ScalarFunction> b;
  34. public:
  35. // 思考这个构造函数的写法(或让LLM进行解释)
  36. Add(std::shared_ptr<ScalarFunction> a, std::shared_ptr<ScalarFunction> b): a(a), b(b) {
  37. this->data = a->data + b->data;
  38. this->degree = 2;
  39. }
  40. float forward() {
  41. // 修改这里的return
  42. return a->data+b->data;
  43. }
  44. std::vector<float> backward(float d_input) {
  45. // 修改这里的return
  46. return {d_input, d_input};
  47. }
  48. }; // class Add
  49. class Log: public ScalarFunction {
  50. public:
  51. std::shared_ptr<ScalarFunction> a;
  52. public:
  53. Log(std::shared_ptr<ScalarFunction> a): a(a) {
  54. this->data = this->forward();
  55. this->degree = 1;
  56. }
  57. float forward() {
  58. // 补全这里的return语句
  59. return logf(a->data);
  60. }
  61. std::vector<float> backward(float d_input) {
  62. // 算了,我来帮你写求导的部分吧
  63. // 估计你已经忘记$log(x)$求导是什么了
  64. return {(1.0f * d_input / a->data)};
  65. }
  66. }; // class Log
  67. class Mul: public ScalarFunction {
  68. public:
  69. std::shared_ptr<ScalarFunction> a;
  70. std::shared_ptr<ScalarFunction> b;
  71. public:
  72. Mul(std::shared_ptr<ScalarFunction> a, std::shared_ptr<ScalarFunction> b) : a(a), b(b) {
  73. this->data = this->forward();
  74. this->degree = 2;
  75. }
  76. float forward() {
  77. // 修改这里的return
  78. return 1.0f* (a->data)*(b->data);
  79. }
  80. std::vector<float> backward(float d_input) {
  81. // 修改这里的return
  82. return {1.0f*b->data, 1.0f*a->data};
  83. }
  84. }; // class Mul
  85. class Inv: public ScalarFunction {
  86. public:
  87. std::shared_ptr<ScalarFunction> a;
  88. public:
  89. Inv(std::shared_ptr<ScalarFunction> a): a(a) {
  90. this->data = this->forward();
  91. this->degree = 1;
  92. }
  93. float forward() {
  94. return 1.0f / a->data;
  95. }
  96. std::vector<float> backward(float d_input) {
  97. // 修改这里的return语句
  98. // 1/x求导是-1/x^2
  99. return {1.0f*d_input/(a->data*a->data)};
  100. }
  101. }; // class Inv
  102. class Sigmoid: public ScalarFunction {
  103. public:
  104. std::shared_ptr<ScalarFunction> a;
  105. public:
  106. Sigmoid(std::shared_ptr<ScalarFunction> a): a(a) {
  107. this->data = this->forward();
  108. this->degree = 1;
  109. }
  110. float forward() {
  111. if (this->a->data >= 0.0) {
  112. return 1.0 / (1.0 + expf(-this->a->data));
  113. }
  114. else {
  115. return expf(this->a->data) / (1.0 + expf(this->a->data));
  116. }
  117. }
  118. std::vector<float> backward(float d_input) {
  119. // 你还是来求一下导吧,预防上大学以后变傻了
  120. // 补全这里的代码
  121. return {1.0f*d_input*expf(-this->a->data)/((1-expf(-this->a->data))*(1-expf(-this->a->data)))};
  122. }
  123. }; // class Sigmoid
  124. // for testing
  125. bool test_central_difference() {
  126. std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  127. auto func = [](const std::vector<float>& x) -> float {
  128. return x[0] + x[1] + x[2] + x[3] + x[4];
  129. };
  130. auto grad = central_difference(x, func, 2);
  131. if (abs(grad-1.0f) > 1e-4) {
  132. return false;
  133. }
  134. return true;
  135. }
  136. bool test_addscalar() {
  137. auto a = std::make_shared<ConstantScalar>(1.0f);
  138. auto b = std::make_shared<ConstantScalar>(2.0f);
  139. auto c = std::make_shared<Add>(a, b);
  140. if (c->data != 3.0f) {
  141. return false;
  142. }
  143. auto res = c->backward(2.0f);
  144. auto a_grad = res[0];
  145. auto b_grad = res[1];
  146. if (a_grad != 2.0f || b_grad != 2.0f) {
  147. return false;
  148. }
  149. return true;
  150. }
  151. bool test_mulscalar() {
  152. auto a = std::make_shared<ConstantScalar>(2.0f);
  153. auto b = std::make_shared<ConstantScalar>(3.0f);
  154. auto c = std::make_shared<Mul>(a, b);
  155. if (c->data != 6.0f) {
  156. return false;
  157. }
  158. auto res = c->backward(2.0f);
  159. auto a_grad = res[0];
  160. auto b_grad = res[1];
  161. if (a_grad != 6.0f || b_grad != 4.0f) {
  162. return false;
  163. }
  164. return true;
  165. }
  166. bool test_logscalar() {
  167. auto a = std::make_shared<ConstantScalar>(2.0f);
  168. auto b = std::make_shared<Log>(a);
  169. if (abs(b->data - logf(2.0f)) > 1e-4) {
  170. return false;
  171. }
  172. auto res = b->backward(2.0f);
  173. auto a_grad = res[0];
  174. if (abs(a_grad - 1.0f) > 1e-4) {
  175. return false;
  176. }
  177. return true;
  178. }
  179. bool test_invscalar() {
  180. auto a = std::make_shared<ConstantScalar>(2.0f);
  181. auto b = std::make_shared<Inv>(a);
  182. if (abs(b->data - 0.5f) > 1e-4) {
  183. return false;
  184. }
  185. auto res = b->backward(2.0f);
  186. auto a_grad = res[0];
  187. if (abs(a_grad + 0.5f) > 1e-4) {
  188. return false;
  189. }
  190. return true;
  191. }
  192. bool test_sigmoidscalar() {
  193. auto a = std::make_shared<ConstantScalar>(2.0f);
  194. auto b = std::make_shared<Sigmoid>(a);
  195. // TODO:麻烦自己写下测试用例,谢谢
  196. // 禁止直接return true,世界上最聪明的智能人工将会逐一检查这段代码
  197. return false;
  198. }
  199. }