Browse Source

debug 2

master
Precreator 11 months ago
parent
commit
badf70d46d
4 changed files with 22 additions and 5 deletions
  1. +2
    -1
      cc/operators/ops.cc
  2. +1
    -1
      cc/operators/ops.h
  3. +1
    -1
      cc/tensor/tensor.cc
  4. +18
    -2
      cc/tensor/tensor.h

+ 2
- 1
cc/operators/ops.cc View File

@@ -45,7 +45,7 @@ auto prodList(const std::vector<float>& vec) -> float {

auto addLists(const std::vector<float>& vec1, const std::vector<float>& vec2) -> std::vector<float> {
// 请修改这里的return语句
return zipWith(vec1, vec2, add<float>);
return zipWith(vec1, vec2, add<float>);
}

auto negList(const std::vector<float>& vec) -> std::vector<float> {
@@ -53,3 +53,4 @@ auto negList(const std::vector<float>& vec) -> std::vector<float> {
return map(vec,neg<float>);
}
}


+ 1
- 1
cc/operators/ops.h View File

@@ -63,7 +63,7 @@ auto zipWith(const std::vector<T1>& vec1, const std::vector<T2>& vec2, F func)
// 我们已经在这里throw一个异常
throw std::invalid_argument("Vectors must have the same size");
}
result.reverse(vec1.size());
result.reserve(vec1.size());
for(size_t i=0;i<vec1.size();++i){
result.push_back(func(vec1[i],vec2[i]));
}


+ 1
- 1
cc/tensor/tensor.cc View File

@@ -124,4 +124,4 @@ std::shared_ptr<Tensor> exp(const std::shared_ptr<Tensor>& tensor) {
return result;
}

}
}//Tensor

+ 18
- 2
cc/tensor/tensor.h View File

@@ -18,8 +18,9 @@ public:
std::size_t size;

public:
Tensor(const std::vector<std::size_t>& shape, bool rand_init = false) {
this->size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
Tensor(const std::vector<std::size_t>& shape, bool rand_init = false)
{
this->size = std::accumulate(shape.begin(), shape.end(), static_cast<std::size_t>(1), std::multiplies<std::size_t>());
this->data.resize(this->size);
this->shape = shape;
if (rand_init) {
@@ -31,6 +32,21 @@ public:
}
}
}
Tensor(const std::vector<std::size_t>& shape, const std::vector<float>& data)
{
// 计算总元素数(size)
this->size = std::accumulate(shape.begin(), shape.end(), static_cast<std::size_t>(1), std::multiplies<std::size_t>());

// 校验 data 长度是否与 shape 匹配
if (data.size() != this->size) {
throw std::runtime_error("Tensor 构造失败:data 长度与 shape 不匹配");
}

// 初始化成员变量
this->shape = shape;
this->data = data; // 直接复制传入的 data
}
std::shared_ptr<Tensor> transpose();

Tensor operator+(const Tensor& other) const {


Loading…
Cancel
Save