|
|
|
@@ -18,8 +18,9 @@ public: |
|
|
|
std::size_t size; |
|
|
|
|
|
|
|
public: |
|
|
|
Tensor(const std::vector<std::size_t>& shape, bool rand_init = false) { |
|
|
|
this->size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); |
|
|
|
Tensor(const std::vector<std::size_t>& shape, bool rand_init = false) |
|
|
|
{ |
|
|
|
this->size = std::accumulate(shape.begin(), shape.end(), static_cast<std::size_t>(1), std::multiplies<std::size_t>()); |
|
|
|
this->data.resize(this->size); |
|
|
|
this->shape = shape; |
|
|
|
if (rand_init) { |
|
|
|
@@ -31,6 +32,21 @@ public: |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
Tensor(const std::vector<std::size_t>& shape, const std::vector<float>& data) |
|
|
|
{ |
|
|
|
// 计算总元素数(size) |
|
|
|
this->size = std::accumulate(shape.begin(), shape.end(), static_cast<std::size_t>(1), std::multiplies<std::size_t>()); |
|
|
|
|
|
|
|
// 校验 data 长度是否与 shape 匹配 |
|
|
|
if (data.size() != this->size) { |
|
|
|
throw std::runtime_error("Tensor 构造失败:data 长度与 shape 不匹配"); |
|
|
|
} |
|
|
|
|
|
|
|
// 初始化成员变量 |
|
|
|
this->shape = shape; |
|
|
|
this->data = data; // 直接复制传入的 data |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<Tensor> transpose(); |
|
|
|
|
|
|
|
Tensor operator+(const Tensor& other) const { |
|
|
|
|