You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.h 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  #pragma once
  #include <algorithm>
  #include <cmath>
  #include <cstddef>
  #include <functional>
  #include <memory>
  #include <numeric>
  #include <random>
  #include <stdexcept>
  #include <vector>

  #include <pybind11/pybind11.h>
  #include <pybind11/numpy.h>
  namespace py = pybind11;
  10. namespace tensor {
  11. class Tensor {
  12. public:
  13. std::vector<float> data;
  14. std::vector<std::size_t> shape;
  15. std::size_t size;
  16. public:
  17. Tensor(const std::vector<std::size_t>& shape, bool rand_init = false) {
  18. this->size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
  19. this->data.resize(this->size);
  20. this->shape = shape;
  21. if (rand_init) {
  22. double limit = std::sqrt(3.0 / ((shape[0] + shape[1]) / 2.0));
  23. std::mt19937 gen(42);
  24. std::uniform_real_distribution<float> dis(-limit, limit);
  25. for (std::size_t i = 0; i < this->size; ++i) {
  26. this->data[i] = dis(gen);
  27. }
  28. }
  29. }
  30. std::shared_ptr<Tensor> transpose();
  31. Tensor operator+(const Tensor& other) const {
  32. if (this->shape != other.shape) {
  33. throw std::runtime_error("Shapes do not match");
  34. }
  35. Tensor result(this->shape);
  36. for (std::size_t i = 0; i < this->size; ++i) {
  37. result.data[i] = this->data[i] + other.data[i];
  38. }
  39. return result;
  40. }
  41. Tensor operator=(const Tensor& other) const {
  42. if (this->shape != other.shape) {
  43. throw std::runtime_error("Shapes do not match");
  44. }
  45. Tensor result(this->shape);
  46. for (auto i = 0; i < this->size; i++) {
  47. result.data[i] = (this->data[i] == other.data[i]);
  48. }
  49. return result;
  50. }
  51. std::vector<std::size_t> get_shape() const {
  52. return this->shape;
  53. }
  54. std::vector<float> get_data() const {
  55. return this->data;
  56. }
  57. float get(const std::vector<std::size_t>& indices) const {
  58. std::size_t index = 0;
  59. std::size_t stride = 1;
  60. for (int i = shape.size() - 1; i >= 0; i--) {
  61. index += indices[i] * stride;
  62. stride *= shape[i];
  63. }
  64. return data[index];
  65. }
  66. void set(const std::vector<std::size_t>& indices, float value) {
  67. std::size_t index = 0;
  68. std::size_t stride = 1;
  69. for (int i = shape.size() - 1; i >= 0; i--) {
  70. index += indices[i] * stride;
  71. stride *= shape[i];
  72. }
  73. data[index] = value;
  74. }
  75. ~Tensor() = default;
  76. }; // class Tensor
  77. std::shared_ptr<Tensor> pyarray_to_tensor(py::array_t<float> array);
  78. std::shared_ptr<Tensor> argmax(const std::shared_ptr<Tensor>& tensor, int axis);
  79. std::shared_ptr<Tensor> mean(const std::shared_ptr<Tensor>& tensor);
  80. std::shared_ptr<Tensor> exp(const std::shared_ptr<Tensor>& tensor);
  81. } // namespace tensor

计算机大作业

Contributors (1)