Browse Source

补全了autodiff.h的一系列代码,实现了若干函数的向前传播和反向传播的功能

pull/7/head
Precreator 11 months ago
parent
commit
b6def1d516
1 changed files with 14 additions and 10 deletions
  1. +14
    -10
      cc/operators/autodiff.h

+ 14
- 10
cc/operators/autodiff.h View File

@@ -7,9 +7,13 @@
namespace autodiff { namespace autodiff {


template<typename T, typename F> template<typename T, typename F>
auto central_difference(std::vector<T>& vec, F func, std::size_t arg, float epsilon = 1e-6) {
// 补全函数,并修改return语句
return 0;
auto central_difference(std::vector<T>& vec, F func, std::size_t arg, float epsilon = 1e-6) -> decltype(func(vec))
{
std::vector<T> vec1=vec;
std::vector<T> vec2=vec;
vec1[arg]+=epsilon;
vec2[arg]-=epsilon;
return (func(vec1)-func(vec2))/(2.0*epsilon);
} }


class ScalarFunction { class ScalarFunction {
@@ -40,11 +44,11 @@ public:
} }
float forward() { float forward() {
// 修改这里的return // 修改这里的return
return 0;
return a->data+b->data;
} }
std::vector<float> backward(float d_input) { std::vector<float> backward(float d_input) {
// 修改这里的return // 修改这里的return
return {0, 0};
return {d_input, d_input};
} }
}; // class Add }; // class Add


@@ -58,7 +62,7 @@ public:
} }
float forward() { float forward() {
// 补全这里的return语句 // 补全这里的return语句
return 0.0f;
return logf(a->data);
} }
std::vector<float> backward(float d_input) { std::vector<float> backward(float d_input) {
// 算了,我来帮你写求导的部分吧 // 算了,我来帮你写求导的部分吧
@@ -78,11 +82,11 @@ public:
} }
float forward() { float forward() {
// 修改这里的return // 修改这里的return
return 0;
return 1.0f* (a->data)*(b->data);
} }
std::vector<float> backward(float d_input) { std::vector<float> backward(float d_input) {
// 修改这里的return // 修改这里的return
return {0, 0};
return {1.0f*b->data, 1.0f*a->data};
} }
}; // class Mul }; // class Mul


@@ -100,7 +104,7 @@ public:
std::vector<float> backward(float d_input) { std::vector<float> backward(float d_input) {
// 修改这里的return语句 // 修改这里的return语句
// 1/x求导是-1/x^2 // 1/x求导是-1/x^2
return {0.0f};
return {1.0f*d_input/(a->data*a->data)};
} }
}; // class Inv }; // class Inv


@@ -123,7 +127,7 @@ public:
std::vector<float> backward(float d_input) { std::vector<float> backward(float d_input) {
// 你还是来求一下导吧,预防上大学以后变傻了 // 你还是来求一下导吧,预防上大学以后变傻了
// 补全这里的代码 // 补全这里的代码
return {0.0f};
return {1.0f*d_input*expf(-this->a->data)/((1-expf(-this->a->data))*(1-expf(-this->a->data)))};
} }
}; // class Sigmoid }; // class Sigmoid




Loading…
Cancel
Save