You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

autodiff.h 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. #pragma once
  2. #include <vector>
  3. #include <memory>
  4. #include <cmath>
  5. #include <unordered_map>
  namespace autodiff {

  /// Estimates the partial derivative of `func` with respect to coordinate `arg`
  /// at the point `vec`, using the central difference quotient
  ///     (func(x + h*e_arg) - func(x - h*e_arg)) / ((x_arg + h) - (x_arg - h)).
  ///
  /// @param vec     Point at which to differentiate; it is not modified.
  /// @param func    Callable invoked with a `std::vector<T>` lvalue; its result
  ///                must support subtraction and division by a `T`.
  /// @param arg     Index of the coordinate to perturb.
  /// @param epsilon Nominal half-width h of the finite-difference step.
  /// @return The finite-difference estimate of d(func)/d(vec[arg]).
  /// @throws std::out_of_range if `arg >= vec.size()`.
  /// @note If `epsilon` is too small to change `vec[arg]` at T's precision, the
  ///       step is zero and the result is inf or NaN.
  template<typename T, typename F>
  auto central_difference(const std::vector<T>& vec, F func, std::size_t arg, float epsilon = 1e-6) {
      std::vector<T> x_plus(vec);
      std::vector<T> x_minus(vec);
      // at(): an out-of-range `arg` throws instead of writing past the buffer.
      x_plus.at(arg) += epsilon;
      x_minus.at(arg) -= epsilon;
      // Divide by the step that was actually representable in T, not the nominal
      // 2 * epsilon. For float, x +/- 1e-6 is rounded to a nearby float, and the
      // nominal step would bias the estimate by several percent.
      const auto step = x_plus[arg] - x_minus[arg];
      return (func(x_plus) - func(x_minus)) / step;
  }

  /// Base node of the scalar computation graph.
  class ScalarFunction {
  public:
      float data = 0.0f;  ///< Value computed by this node (set by derived constructors).
      float grad = 0.0f;  ///< Gradient slot; not read or written within this header.
      int degree = 0;     ///< Number of inputs: 0 for constants, 1 unary, 2 binary.
  public:
      ScalarFunction() {}
      // Nodes are held and shared through std::shared_ptr<ScalarFunction>;
      // a virtual destructor keeps deletion through a base pointer well-defined.
      virtual ~ScalarFunction() = default;
  }; // class ScalarFunction

  /// Leaf node holding a fixed value.
  class ConstantScalar: public ScalarFunction {
  public:
      /// @param data The constant value stored in `data`.
      ConstantScalar(float data): ScalarFunction() {
          this->data = data;
      }
  }; // class ConstantScalar

  /// Binary node computing a + b.
  class Add: public ScalarFunction {
  public:
      std::shared_ptr<ScalarFunction> a;
      std::shared_ptr<ScalarFunction> b;
  public:
      // In the mem-initializer list, `a(a)` initializes the member `a` from the
      // parameter `a`. Inside the constructor body, the unqualified names `a`/`b`
      // refer to the (shadowing) parameters, so members are reached via this->.
      Add(std::shared_ptr<ScalarFunction> a, std::shared_ptr<ScalarFunction> b): a(a), b(b) {
          this->data = this->forward();
          this->degree = 2;
      }
      /// @return a + b, computed from the operands' current `data`.
      float forward() {
          return a->data + b->data;
      }
      /// d(a+b)/da = d(a+b)/db = 1, so the upstream gradient passes through.
      /// @param d_input Upstream gradient.
      /// @return {gradient w.r.t. a, gradient w.r.t. b}.
      std::vector<float> backward(float d_input) {
          return {d_input, d_input};
      }
  }; // class Add

  /// Unary node computing the natural logarithm ln(a).
  class Log: public ScalarFunction {
  public:
      std::shared_ptr<ScalarFunction> a;
  public:
      Log(std::shared_ptr<ScalarFunction> a): a(a) {
          this->data = this->forward();
          this->degree = 1;
      }
      /// @return ln(a); NaN for negative input and -inf for zero, as std::log.
      float forward() {
          return std::log(a->data);
      }
      /// d ln(x)/dx = 1/x.
      /// @return {d_input / a}.
      std::vector<float> backward(float d_input) {
          return {(1.0f * d_input / a->data)};
      }
  }; // class Log

  /// Binary node computing a * b.
  class Mul: public ScalarFunction {
  public:
      std::shared_ptr<ScalarFunction> a;
      std::shared_ptr<ScalarFunction> b;
  public:
      Mul(std::shared_ptr<ScalarFunction> a, std::shared_ptr<ScalarFunction> b) : a(a), b(b) {
          this->data = this->forward();
          this->degree = 2;
      }
      /// @return a * b, computed from the operands' current `data`.
      float forward() {
          return a->data * b->data;
      }
      /// d(ab)/da = b and d(ab)/db = a (product rule).
      /// @return {d_input * b, d_input * a}.
      std::vector<float> backward(float d_input) {
          return {d_input * b->data, d_input * a->data};
      }
  }; // class Mul

  /// Unary node computing the reciprocal 1 / a.
  class Inv: public ScalarFunction {
  public:
      std::shared_ptr<ScalarFunction> a;
  public:
      Inv(std::shared_ptr<ScalarFunction> a): a(a) {
          this->data = this->forward();
          this->degree = 1;
      }
      /// @return 1 / a (inf for a == 0, per IEEE float division).
      float forward() {
          return 1.0f / a->data;
      }
      /// d(1/x)/dx = -1/x^2.
      /// @return {-d_input / a^2}.
      std::vector<float> backward(float d_input) {
          const float x = a->data;
          return {-d_input / (x * x)};
      }
  }; // class Inv

  /// Unary node computing the logistic sigmoid 1 / (1 + e^{-a}).
  class Sigmoid: public ScalarFunction {
  public:
      std::shared_ptr<ScalarFunction> a;
  public:
      Sigmoid(std::shared_ptr<ScalarFunction> a): a(a) {
          this->data = this->forward();
          this->degree = 1;
      }
      /// @return sigmoid(a) in [0, 1].
      float forward() {
          const float x = this->a->data;
          // Two algebraically equal forms, chosen so that exp() is only ever
          // evaluated at a non-positive argument and therefore cannot overflow.
          if (x >= 0.0f) {
              return 1.0f / (1.0f + std::exp(-x));
          }
          const float e = std::exp(x);
          return e / (1.0f + e);
      }
      /// d sigmoid(x)/dx = sigmoid(x) * (1 - sigmoid(x)).
      /// @return {d_input * s * (1 - s)} with s = sigmoid(a).
      std::vector<float> backward(float d_input) {
          const float s = this->forward();
          return {d_input * s * (1.0f - s)};
      }
  }; // class Sigmoid

  // ---------------------------------------------------------------- testing
  // Test helpers are `inline` so this header can be included from several
  // translation units without multiple-definition link errors.

  /// Returns true when |actual - expected| <= tol. A NaN operand makes the
  /// comparison false, so NaN results fail instead of slipping through.
  inline bool approx_equal(double actual, double expected, double tol = 1e-4) {
      return std::fabs(actual - expected) <= tol;
  }

  inline bool test_central_difference() {
      std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
      auto func = [](const std::vector<float>& v) -> float {
          return v[0] + v[1] + v[2] + v[3] + v[4];
      };
      const auto grad = central_difference(x, func, 2);
      return approx_equal(grad, 1.0);
  }

  inline bool test_addscalar() {
      auto a = std::make_shared<ConstantScalar>(1.0f);
      auto b = std::make_shared<ConstantScalar>(2.0f);
      auto c = std::make_shared<Add>(a, b);
      if (c->data != 3.0f) {
          return false;
      }
      const auto res = c->backward(2.0f);
      if (res.size() != 2) {
          return false;
      }
      return res[0] == 2.0f && res[1] == 2.0f;
  }

  inline bool test_mulscalar() {
      auto a = std::make_shared<ConstantScalar>(2.0f);
      auto b = std::make_shared<ConstantScalar>(3.0f);
      auto c = std::make_shared<Mul>(a, b);
      if (c->data != 6.0f) {
          return false;
      }
      const auto res = c->backward(2.0f);
      if (res.size() != 2) {
          return false;
      }
      return res[0] == 6.0f && res[1] == 4.0f;
  }

  inline bool test_logscalar() {
      auto a = std::make_shared<ConstantScalar>(2.0f);
      auto b = std::make_shared<Log>(a);
      if (!approx_equal(b->data, std::log(2.0f))) {
          return false;
      }
      const auto res = b->backward(2.0f);
      return res.size() == 1 && approx_equal(res[0], 1.0);
  }

  inline bool test_invscalar() {
      auto a = std::make_shared<ConstantScalar>(2.0f);
      auto b = std::make_shared<Inv>(a);
      if (!approx_equal(b->data, 0.5)) {
          return false;
      }
      const auto res = b->backward(2.0f);
      return res.size() == 1 && approx_equal(res[0], -0.5);
  }

  inline bool test_sigmoidscalar() {
      // Cover both branches of Sigmoid::forward (x >= 0 and x < 0).
      for (float x : {2.0f, -2.0f}) {
          auto a = std::make_shared<ConstantScalar>(x);
          auto b = std::make_shared<Sigmoid>(a);
          // Reference values computed independently in double precision.
          const double expected = 1.0 / (1.0 + std::exp(-static_cast<double>(x)));
          if (!approx_equal(b->data, expected)) {
              return false;
          }
          const auto res = b->backward(2.0f);
          const double expected_grad = 2.0 * expected * (1.0 - expected);
          if (res.size() != 1 || !approx_equal(res[0], expected_grad)) {
              return false;
          }
      }
      return true;
  }

  }  // namespace autodiff

计算机大作业