From feb8759764e70dbc9e1abb7312109ccd004adb16 Mon Sep 17 00:00:00 2001 From: Precreator <1689487228@qq.com> Date: Tue, 1 Jul 2025 11:22:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E4=BA=86test15=2016,?= =?UTF-8?q?=E5=9F=BA=E6=9C=AC=E5=AE=8C=E6=88=90=E4=BA=86=E7=BA=BF=E6=80=A7?= =?UTF-8?q?=E5=B1=82=E7=9A=84=E4=BB=A3=E7=A0=81=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cc/operators/nn.h | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/cc/operators/nn.h b/cc/operators/nn.h index e4824e8..ec07c1c 100644 --- a/cc/operators/nn.h +++ b/cc/operators/nn.h @@ -135,15 +135,43 @@ public: auto features = this->objects[0]; auto bias = this->objects[1]; auto outNode = std::make_shared(features->data->shape); + auto batch_size = features->data->shape[0]; + auto num_features = features->data->shape[1]; + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < num_features; ++j) { + // 计算索引:batch_size行,num_features列的二维张量 + size_t idx = i * num_features + j; + // 每个样本的特征向量加上偏置向量 + outNode->data[idx] = features->data->data[idx] + bias->data->data[j]; + } + } // for循环写加法总会写吧🤔 // 补全这里的代码 return outNode; } std::vector> backward(std::shared_ptr gradient) override { // assertion needed - auto g_bias = std::make_shared(this->objects[1]->data->shape); + auto g_bias = std::make_shared(this->objects[1]->data->shape); + // 从张量形状获取维度信息 + auto batch_size = gradient->shape[0]; + auto num_features = gradient->shape[1]; // 从shape中获取num_features // 补全这里的代码 - + auto batch_size = gradient->shape[0]; + auto num_features = gradient->shape[1]; + + // 初始化偏置梯度为零 + for (size_t j = 0; j < num_features; ++j) + { + g_bias->data[j] = 0.0f; + } + + // 计算偏置的梯度:对每个特征维度,将所有样本的梯度累加 + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < num_features; ++j) { + // 累加每个样本对该特征维度的梯度贡献 + g_bias->data[j] += gradient->data[i * num_features + j]; + } + } return {gradient, g_bias}; } std::vector get_data() { @@ -155,6 +183,7 @@ public: class Linear: public FunctionNode { public: Linear(std::shared_ptr a, std::shared_ptr b) : FunctionNode(a, b) { + this->data=this->forward(); // 这段代码就一行,参考下别的类是怎么写的呢? // 在这里补全 }