Browse Source

完成了test15 16,基本完成了线性层的代码部分

pull/7/head
Precreator 11 months ago
parent
commit
feb8759764
1 changed files with 31 additions and 2 deletions
  1. +31
    -2
      cc/operators/nn.h

+ 31
- 2
cc/operators/nn.h View File

@@ -135,15 +135,43 @@ public:
auto features = this->objects[0]; auto features = this->objects[0];
auto bias = this->objects[1]; auto bias = this->objects[1];
auto outNode = std::make_shared<tensor::Tensor>(features->data->shape); auto outNode = std::make_shared<tensor::Tensor>(features->data->shape);
auto batch_size = features->data->shape[0];
auto num_features = features->data->shape[1];
for (size_t i = 0; i < batch_size; ++i) {
for (size_t j = 0; j < num_features; ++j) {
// 计算索引:batch_size行,num_features列的二维张量
size_t idx = i * num_features + j;
// 每个样本的特征向量加上偏置向量
outNode->data[idx] = features->data->data[idx] + bias->data->data[j];
}
}
// for循环写加法总会写吧🤔 // for循环写加法总会写吧🤔
// 补全这里的代码 // 补全这里的代码
return outNode; return outNode;
} }
std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) override { std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) override {
// assertion needed // assertion needed
auto g_bias = std::make_shared<tensor::Tensor>(this->objects[1]->data->shape);
auto g_bias = std::make_shared<tensor::Tensor>(this->objects[1]->data->shape);
// 从张量形状获取维度信息
auto batch_size = gradient->shape[0];
auto num_features = gradient->shape[1]; // 从shape中获取num_features
// 补全这里的代码 // 补全这里的代码
auto batch_size = gradient->shape[0];
auto num_features = gradient->shape[1];
// 初始化偏置梯度为零
for (size_t j = 0; j < num_features; ++j)
{
g_bias->data[j] = 0.0f;
}
// 计算偏置的梯度:对每个特征维度,将所有样本的梯度累加
for (size_t i = 0; i < batch_size; ++i) {
for (size_t j = 0; j < num_features; ++j) {
// 累加每个样本对该特征维度的梯度贡献
g_bias->data[j] += gradient->data[i * num_features + j];
}
}
return {gradient, g_bias}; return {gradient, g_bias};
} }
std::vector<float> get_data() { std::vector<float> get_data() {
@@ -155,6 +183,7 @@ public:
class Linear: public FunctionNode { class Linear: public FunctionNode {
public: public:
Linear(std::shared_ptr<Node> a, std::shared_ptr<Node> b) : FunctionNode(a, b) { Linear(std::shared_ptr<Node> a, std::shared_ptr<Node> b) : FunctionNode(a, b) {
this->data=this->forward();
// 这段代码就一行,参考下别的类是怎么写的呢? // 这段代码就一行,参考下别的类是怎么写的呢?
// 在这里补全 // 在这里补全
} }


Loading…
Cancel
Save