|
|
@@ -155,9 +155,6 @@ public: |
|
|
// 从张量形状获取维度信息 |
|
|
// 从张量形状获取维度信息 |
|
|
auto batch_size = gradient->shape[0]; |
|
|
auto batch_size = gradient->shape[0]; |
|
|
auto num_features = gradient->shape[1]; // 从shape中获取num_features |
|
|
auto num_features = gradient->shape[1]; // 从shape中获取num_features |
|
|
// 补全这里的代码 |
|
|
|
|
|
auto batch_size = gradient->shape[0]; |
|
|
|
|
|
auto num_features = gradient->shape[1]; |
|
|
|
|
|
|
|
|
|
|
|
// 初始化偏置梯度为零 |
|
|
// 初始化偏置梯度为零 |
|
|
for (size_t j = 0; j < num_features; ++j) |
|
|
for (size_t j = 0; j < num_features; ++j) |
|
|
|