#pragma once #include #include #include #include #include #include #include namespace py = pybind11; namespace tensor { class Tensor { public: std::vector data; std::vector shape; std::size_t size; public: Tensor(const std::vector& shape, bool rand_init = false) { this->size = std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); this->data.resize(this->size); this->shape = shape; if (rand_init) { double limit = std::sqrt(3.0 / ((shape[0] + shape[1]) / 2.0)); std::mt19937 gen(42); std::uniform_real_distribution dis(-limit, limit); for (std::size_t i = 0; i < this->size; ++i) { this->data[i] = dis(gen); } } } Tensor(const std::vector& shape, const std::vector& data) { // 计算总元素数(size) this->size = std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); // 校验 data 长度是否与 shape 匹配 if (data.size() != this->size) { throw std::runtime_error("Tensor 构造失败:data 长度与 shape 不匹配"); } // 初始化成员变量 this->shape = shape; this->data = data; // 直接复制传入的 data } std::shared_ptr transpose(); Tensor operator+(const Tensor& other) const { if (this->shape != other.shape) { throw std::runtime_error("Shapes do not match"); } Tensor result(this->shape); for (std::size_t i = 0; i < this->size; ++i) { result.data[i] = this->data[i] + other.data[i]; } return result; } Tensor operator=(const Tensor& other) const { if (this->shape != other.shape) { throw std::runtime_error("Shapes do not match"); } Tensor result(this->shape); for (auto i = 0; i < this->size; i++) { result.data[i] = (this->data[i] == other.data[i]); } return result; } std::vector get_shape() const { return this->shape; } std::vector get_data() const { return this->data; } float get(const std::vector& indices) const { std::size_t index = 0; std::size_t stride = 1; for (int i = shape.size() - 1; i >= 0; i--) { index += indices[i] * stride; stride *= shape[i]; } return data[index]; } void set(const std::vector& indices, float value) { std::size_t index = 0; std::size_t stride = 1; for (int i = shape.size() - 1; i >= 0; i--) { index += indices[i] * stride; stride *= shape[i]; } data[index] = value; } ~Tensor() = default; }; // class Tensor std::shared_ptr pyarray_to_tensor(py::array_t array); std::shared_ptr argmax(const std::shared_ptr& tensor, int axis); std::shared_ptr mean(const std::shared_ptr& tensor); std::shared_ptr exp(const std::shared_ptr& tensor); } // namespace tensor