From b6def1d51698bc2bde2b40052206e94479c866ce Mon Sep 17 00:00:00 2001 From: Precreator <1689487228@qq.com> Date: Thu, 26 Jun 2025 16:16:24 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A5=E5=85=A8=E4=BA=86autodiff.h=E7=9A=84?= =?UTF-8?q?=E4=B8=80=E7=B3=BB=E5=88=97=E4=BB=A3=E7=A0=81=EF=BC=8C=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E4=BA=86=E8=8B=A5=E5=B9=B2=E5=87=BD=E6=95=B0=E7=9A=84?= =?UTF-8?q?=E5=90=91=E5=89=8D=E4=BC=A0=E6=92=AD=E5=92=8C=E5=8F=8D=E5=90=91?= =?UTF-8?q?=E4=BC=A0=E6=92=AD=E7=9A=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cc/operators/autodiff.h | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/cc/operators/autodiff.h b/cc/operators/autodiff.h index a485db5..f250267 100644 --- a/cc/operators/autodiff.h +++ b/cc/operators/autodiff.h @@ -7,9 +7,13 @@ namespace autodiff { template -auto central_difference(std::vector& vec, F func, std::size_t arg, float epsilon = 1e-6) { - // 补全函数,并修改return语句 - return 0; +auto central_difference(std::vector& vec, F func, std::size_t arg, float epsilon = 1e-6) -> decltype(func(vec)) +{ + std::vector vec1=vec; + std::vector vec2=vec; + vec1[arg]+=epsilon; + vec2[arg]-=epsilon; + return (func(vec1)-func(vec2))/(2.0*epsilon); } class ScalarFunction { @@ -40,11 +44,11 @@ public: } float forward() { // 修改这里的return - return 0; + return a->data+b->data; } std::vector backward(float d_input) { // 修改这里的return - return {0, 0}; + return {d_input, d_input}; } }; // class Add @@ -58,7 +62,7 @@ public: } float forward() { // 补全这里的return语句 - return 0.0f; + return logf(a->data); } std::vector backward(float d_input) { // 算了,我来帮你写求导的部分吧 @@ -78,11 +82,11 @@ public: } float forward() { // 修改这里的return - return 0; + return 1.0f* (a->data)*(b->data); } std::vector backward(float d_input) { // 修改这里的return - return {0, 0}; + return {1.0f*b->data, 1.0f*a->data}; } }; // class Mul @@ -100,7 +104,7 @@ public: std::vector backward(float d_input) { // 修改这里的return语句 // 1/x求导是-1/x^2 - return {0.0f}; + return {1.0f*d_input/(a->data*a->data)}; } }; // class Inv @@ -123,7 +127,7 @@ public: std::vector backward(float d_input) { // 你还是来求一下导吧,预防上大学以后变傻了 // 补全这里的代码 - return {0.0f}; + return {1.0f*d_input*expf(-this->a->data)/((1-expf(-this->a->data))*(1-expf(-this->a->data)))}; } }; // class Sigmoid