From badf70d46d5aae2afae6b0d79f013423bb49007e Mon Sep 17 00:00:00 2001 From: Precreator <1689487228@qq.com> Date: Tue, 8 Jul 2025 15:42:45 +0800 Subject: [PATCH] debug 2 --- cc/operators/ops.cc | 3 ++- cc/operators/ops.h | 2 +- cc/tensor/tensor.cc | 2 +- cc/tensor/tensor.h | 20 ++++++++++++++++++-- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/cc/operators/ops.cc b/cc/operators/ops.cc index c256d95..c06014a 100644 --- a/cc/operators/ops.cc +++ b/cc/operators/ops.cc @@ -45,7 +45,7 @@ auto prodList(const std::vector& vec) -> float { auto addLists(const std::vector& vec1, const std::vector& vec2) -> std::vector { // 请修改这里的return语句 - return zipWith(vec1, vec2, add); + return zipWith(vec1, vec2, add); } auto negList(const std::vector& vec) -> std::vector { @@ -53,3 +53,4 @@ auto negList(const std::vector& vec) -> std::vector { return map(vec,neg); } } + diff --git a/cc/operators/ops.h b/cc/operators/ops.h index bdf0a89..e2ace63 100644 --- a/cc/operators/ops.h +++ b/cc/operators/ops.h @@ -63,7 +63,7 @@ auto zipWith(const std::vector& vec1, const std::vector& vec2, F func) // 我们已经在这里throw一个异常 throw std::invalid_argument("Vectors must have the same size"); } - result.reverse(vec1.size()); + result.reserve(vec1.size()); for(size_t i=0;i exp(const std::shared_ptr& tensor) { return result; } -} \ No newline at end of file +}//Tensor \ No newline at end of file diff --git a/cc/tensor/tensor.h b/cc/tensor/tensor.h index adaef3c..c4011e1 100644 --- a/cc/tensor/tensor.h +++ b/cc/tensor/tensor.h @@ -18,8 +18,9 @@ public: std::size_t size; public: - Tensor(const std::vector& shape, bool rand_init = false) { - this->size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + Tensor(const std::vector& shape, bool rand_init = false) + { + this->size = std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); this->data.resize(this->size); this->shape = shape; if (rand_init) { @@ -31,6 +32,21 @@ public: } } } + Tensor(const std::vector& shape, const std::vector& data) + { + // 计算总元素数(size) + this->size = std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); + + // 校验 data 长度是否与 shape 匹配 + if (data.size() != this->size) { + throw std::runtime_error("Tensor 构造失败:data 长度与 shape 不匹配"); + } + + // 初始化成员变量 + this->shape = shape; + this->data = data; // 直接复制传入的 data + } + std::shared_ptr transpose(); Tensor operator+(const Tensor& other) const {