diff --git a/cc/operators/autodiff.h b/cc/operators/autodiff.h index f250267..a8979bb 100644 --- a/cc/operators/autodiff.h +++ b/cc/operators/autodiff.h @@ -37,17 +37,14 @@ public: std::shared_ptr a; std::shared_ptr b; public: - // 思考这个构造函数的写法(或让LLM进行解释) Add(std::shared_ptr a, std::shared_ptr b): a(a), b(b) { this->data = a->data + b->data; this->degree = 2; } float forward() { - // 修改这里的return - return a->data+b->data; + return a->data + b->data; } std::vector backward(float d_input) { - // 修改这里的return return {d_input, d_input}; } }; // class Add @@ -61,13 +58,10 @@ public: this->degree = 1; } float forward() { - // 补全这里的return语句 return logf(a->data); } std::vector backward(float d_input) { - // 算了,我来帮你写求导的部分吧 - // 估计你已经忘记$log(x)$求导是什么了 - return {(1.0f * d_input / a->data)}; + return {d_input / a->data}; } }; // class Log @@ -81,12 +75,10 @@ public: this->degree = 2; } float forward() { - // 修改这里的return - return 1.0f* (a->data)*(b->data); + return a->data * b->data; } std::vector backward(float d_input) { - // 修改这里的return - return {1.0f*b->data, 1.0f*a->data}; + return {b->data * d_input, a->data * d_input}; } }; // class Mul @@ -102,9 +94,7 @@ public: return 1.0f / a->data; } std::vector backward(float d_input) { - // 修改这里的return语句 - // 1/x求导是-1/x^2 - return {1.0f*d_input/(a->data*a->data)}; + return {-d_input / (a->data * a->data)}; } }; // class Inv @@ -117,17 +107,17 @@ public: this->degree = 1; } float forward() { - if (this->a->data >= 0.0) { - return 1.0 / (1.0 + expf(-this->a->data)); - } - else { - return expf(this->a->data) / (1.0 + expf(this->a->data)); + float x = a->data; + if (x >= 0) { + return 1.0f / (1.0f + expf(-x)); + } else { + float exp_x = expf(x); + return exp_x / (1.0f + exp_x); } } std::vector backward(float d_input) { - // 你还是来求一下导吧,预防上大学以后变傻了 - // 补全这里的代码 - return {1.0f*d_input*expf(-this->a->data)/((1-expf(-this->a->data))*(1-expf(-this->a->data)))}; + float sig = this->data; + return {d_input * sig * (1.0f - sig)}; } }; // class Sigmoid @@ -207,9 +197,26 @@ bool test_invscalar() { bool test_sigmoidscalar() { auto a = std::make_shared(2.0f); auto b = std::make_shared(a); - // TODO:麻烦自己写下测试用例,谢谢 - // 禁止直接return true,世界上最聪明的智能人工将会逐一检查这段代码 - return false; + + // 计算预期的sigmoid值 + float expected_data = 1.0f / (1.0f + expf(-2.0f)); + + // 检查前向传播结果 + if (abs(b->data - expected_data) > 1e-4) { + return false; + } + + // 计算预期的导数 + float expected_grad = expected_data * (1.0f - expected_data); + auto res = b->backward(2.0f); + auto a_grad = res[0]; + + // 检查反向传播结果 + if (abs(a_grad - 2.0f * expected_grad) > 1e-4) { + return false; + } + + return true; } -} \ No newline at end of file +}