You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tensor.h 3.4 kB

11 months ago
11 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. #pragma once
  #include <cmath>
  #include <functional>
  #include <memory>
  #include <numeric>
  #include <random>
  #include <stdexcept>
  #include <vector>
  #include <pybind11/pybind11.h>
  #include <pybind11/numpy.h>
  9. namespace py = pybind11;
  10. namespace tensor {
  /// Dense float tensor stored as a flat vector in row-major order
  /// (the last index varies fastest in get/set).
  ///
  /// The class owns only standard containers, so it follows the Rule of Zero:
  /// the compiler-generated copy/move constructors and assignments are correct
  /// (and the moves are cheap and noexcept).
  class Tensor {
  public:
      std::vector<float> data;         ///< Flat element storage; length == size.
      std::vector<std::size_t> shape;  ///< Extent of each dimension.
      std::size_t size = 0;            ///< Product of all extents (1 for an empty shape).

  public:
      /// Creates a tensor of the given shape.
      ///
      /// @param shape     Extent of each dimension.
      /// @param rand_init When false, every element is 0. When true, elements are drawn
      ///                  from U(-limit, limit) with limit = sqrt(3 / mean_fan), using a
      ///                  fixed seed (42) so results are reproducible. mean_fan is
      ///                  (shape[0] + shape[1]) / 2 for rank >= 2, shape[0] for rank 1,
      ///                  and 1 for a scalar, so shape is never indexed out of bounds.
      Tensor(const std::vector<std::size_t>& shape, bool rand_init = false)
      {
          this->size = std::accumulate(shape.begin(), shape.end(), static_cast<std::size_t>(1),
                                       std::multiplies<std::size_t>());
          this->data.resize(this->size);
          this->shape = shape;
          // An empty tensor has nothing to fill, and its fan sum may be 0 (division by zero).
          if (rand_init && this->size > 0) {
              // With size > 0 every extent is >= 1, so fan_in + fan_out >= 2.
              const std::size_t fan_in = shape.empty() ? 1 : shape[0];
              const std::size_t fan_out = shape.size() >= 2 ? shape[1] : fan_in;
              const double limit = std::sqrt(3.0 / ((fan_in + fan_out) / 2.0));
              std::mt19937 gen(42);  // fixed seed: identical shapes get identical values
              std::uniform_real_distribution<float> dis(static_cast<float>(-limit),
                                                        static_cast<float>(limit));
              for (float& value : this->data) {
                  value = dis(gen);
              }
          }
      }

      /// Creates a tensor of the given shape from an existing flat buffer.
      ///
      /// @param shape Extent of each dimension.
      /// @param data  Row-major elements; copied into the tensor.
      /// @throws std::runtime_error if data.size() differs from the product of shape.
      Tensor(const std::vector<std::size_t>& shape, const std::vector<float>& data)
      {
          // Total number of elements implied by the shape.
          this->size = std::accumulate(shape.begin(), shape.end(), static_cast<std::size_t>(1),
                                       std::multiplies<std::size_t>());
          // Reject a buffer whose length does not match the shape.
          if (data.size() != this->size) {
              throw std::runtime_error("Tensor 构造失败:data 长度与 shape 不匹配");
          }
          this->shape = shape;
          this->data = data;  // copy the caller's buffer
      }

      /// Defined out of line (not visible here); presumably returns a transposed
      /// copy — confirm against the definition.
      std::shared_ptr<Tensor> transpose();

      /// Element-wise sum of two tensors of identical shape.
      /// @throws std::runtime_error if the shapes differ.
      Tensor operator+(const Tensor& other) const {
          if (this->shape != other.shape) {
              throw std::runtime_error("Shapes do not match");
          }
          Tensor result(this->shape);
          for (std::size_t i = 0; i < this->size; ++i) {
              result.data[i] = this->data[i] + other.data[i];
          }
          return result;
      }

      /// Element-wise equality: returns a tensor of the same shape that holds 1.0f
      /// where the elements are equal and 0.0f elsewhere.
      ///
      /// Declared as operator== rather than a const operator=, because a const
      /// operator= would replace copy assignment (turning `a = b` into a no-op)
      /// and suppress the implicit move operations.
      /// @throws std::runtime_error if the shapes differ.
      Tensor operator==(const Tensor& other) const {
          if (this->shape != other.shape) {
              throw std::runtime_error("Shapes do not match");
          }
          Tensor result(this->shape);
          for (std::size_t i = 0; i < this->size; ++i) {
              result.data[i] = (this->data[i] == other.data[i]) ? 1.0f : 0.0f;
          }
          return result;
      }

      /// Returns a copy of the shape.
      std::vector<std::size_t> get_shape() const {
          return this->shape;
      }

      /// Returns a copy of the flat element buffer.
      std::vector<float> get_data() const {
          return this->data;
      }

      /// Returns the element at the given multi-index.
      /// @throws std::out_of_range if indices.size() != rank or any index exceeds its extent.
      float get(const std::vector<std::size_t>& indices) const {
          return data[flat_index(indices)];
      }

      /// Overwrites the element at the given multi-index.
      /// @throws std::out_of_range if indices.size() != rank or any index exceeds its extent.
      void set(const std::vector<std::size_t>& indices, float value) {
          data[flat_index(indices)] = value;
      }

  private:
      /// Row-major flat offset of a validated multi-index (last index fastest).
      std::size_t flat_index(const std::vector<std::size_t>& indices) const {
          if (indices.size() != shape.size()) {
              throw std::out_of_range("Tensor index rank does not match tensor rank");
          }
          std::size_t index = 0;
          for (std::size_t dim = 0; dim < shape.size(); ++dim) {
              if (indices[dim] >= shape[dim]) {
                  throw std::out_of_range("Tensor index out of range");
              }
              // Horner form of sum(indices[k] * prod(shape[k+1..])).
              index = index * shape[dim] + indices[dim];
          }
          return index;
      }
  }; // class Tensor
  // Free functions on Tensor; their definitions are out of line (not visible here).

  // NOTE(review): presumably copies a NumPy float array into a new Tensor — confirm
  // against the definition (e.g. how non-contiguous arrays are handled).
  std::shared_ptr<Tensor> pyarray_to_tensor(py::array_t<float> array);
  // NOTE(review): presumably returns the argmax indices along `axis` — confirm the
  // result shape and how indices are encoded in a float Tensor.
  std::shared_ptr<Tensor> argmax(const std::shared_ptr<Tensor>& tensor, int axis);
  // NOTE(review): takes no axis, so presumably the mean over all elements — confirm.
  std::shared_ptr<Tensor> mean(const std::shared_ptr<Tensor>& tensor);
  // NOTE(review): presumably the element-wise exponential — confirm.
  std::shared_ptr<Tensor> exp(const std::shared_ptr<Tensor>& tensor);
  94. } // namespace tensor