Browse Source

atuodiff

master
Precreator 11 months ago
parent
commit
47c6521ce1
1 changed files with 34 additions and 27 deletions
  1. +34
    -27
      cc/operators/autodiff.h

+ 34
- 27
cc/operators/autodiff.h View File

@@ -37,17 +37,14 @@ public:
std::shared_ptr<ScalarFunction> a;
std::shared_ptr<ScalarFunction> b;
public:
// 思考这个构造函数的写法(或让LLM进行解释)
Add(std::shared_ptr<ScalarFunction> a, std::shared_ptr<ScalarFunction> b): a(a), b(b) {
this->data = a->data + b->data;
this->degree = 2;
}
float forward() {
// 修改这里的return
return a->data+b->data;
return a->data + b->data;
}
std::vector<float> backward(float d_input) {
// 修改这里的return
return {d_input, d_input};
}
}; // class Add
@@ -61,13 +58,10 @@ public:
this->degree = 1;
}
float forward() {
// 补全这里的return语句
return logf(a->data);
}
std::vector<float> backward(float d_input) {
// 算了,我来帮你写求导的部分吧
// 估计你已经忘记$log(x)$求导是什么了
return {(1.0f * d_input / a->data)};
return {d_input / a->data};
}
}; // class Log

@@ -81,12 +75,10 @@ public:
this->degree = 2;
}
float forward() {
// 修改这里的return
return 1.0f* (a->data)*(b->data);
return a->data * b->data;
}
std::vector<float> backward(float d_input) {
// 修改这里的return
return {1.0f*b->data, 1.0f*a->data};
return {b->data * d_input, a->data * d_input};
}
}; // class Mul

@@ -102,9 +94,7 @@ public:
return 1.0f / a->data;
}
std::vector<float> backward(float d_input) {
// 修改这里的return语句
// 1/x求导是-1/x^2
return {1.0f*d_input/(a->data*a->data)};
return {-d_input / (a->data * a->data)};
}
}; // class Inv

@@ -117,17 +107,17 @@ public:
this->degree = 1;
}
float forward() {
if (this->a->data >= 0.0) {
return 1.0 / (1.0 + expf(-this->a->data));
}
else {
return expf(this->a->data) / (1.0 + expf(this->a->data));
float x = a->data;
if (x >= 0) {
return 1.0f / (1.0f + expf(-x));
} else {
float exp_x = expf(x);
return exp_x / (1.0f + exp_x);
}
}
std::vector<float> backward(float d_input) {
// 你还是来求一下导吧,预防上大学以后变傻了
// 补全这里的代码
return {1.0f*d_input*expf(-this->a->data)/((1-expf(-this->a->data))*(1-expf(-this->a->data)))};
float sig = this->data;
return {d_input * sig * (1.0f - sig)};
}
}; // class Sigmoid

@@ -207,9 +197,26 @@ bool test_invscalar() {
bool test_sigmoidscalar() {
auto a = std::make_shared<ConstantScalar>(2.0f);
auto b = std::make_shared<Sigmoid>(a);
// TODO:麻烦自己写下测试用例,谢谢
// 禁止直接return true,世界上最聪明的智能人工将会逐一检查这段代码
return false;
// 计算预期的sigmoid值
float expected_data = 1.0f / (1.0f + expf(-2.0f));
// 检查前向传播结果
if (abs(b->data - expected_data) > 1e-4) {
return false;
}
// 计算预期的导数
float expected_grad = expected_data * (1.0f - expected_data);
auto res = b->backward(2.0f);
auto a_grad = res[0];
// 检查反向传播结果
if (abs(a_grad - 2.0f * expected_grad) > 1e-4) {
return false;
}
return true;
}

}
}

Loading…
Cancel
Save