You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.cc 4.3 kB

11 months ago
11 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
#include "tensor.h"

#include <stdexcept>
  2. namespace tensor {
  3. std::shared_ptr<Tensor> Tensor::transpose() {
  4. // 放心,下面的代码暂时不会被触发,我们假定所有的tensor都是2维的
  5. // if (shape.size() != 2) {
  6. // throw std::runtime_error("Transpose is only supported for 2D tensors.");
  7. // }
  8. // 这里能够获得矩阵的行数和列数,但是我们是使用一个一维的vector来存储数据的。该如何实现“转置”呢?
  9. std::size_t rows = shape[0];
  10. std::size_t cols = shape[1];
  11. std::vector<size_t> new_shape = {cols, rows};
  12. // 创建一个空的 Tensor 对象,不进行随机初始化
  13. auto transposed_tensor = std::make_shared<Tensor>(new_shape, false);
  14. // 计算转置后的数据
  15. for(std::size_t i = 0; i < cols; ++i) {
  16. for(std::size_t j = 0; j < rows; ++j) {
  17. // 计算转置后的索引
  18. std::size_t transposed_index = i * rows + j;
  19. // 计算原数据的索引
  20. std::size_t original_index = j * cols + i;
  21. transposed_tensor->data[transposed_index] = data[original_index];
  22. }
  23. }
  24. return transposed_tensor;
  25. }
  26. std::shared_ptr<Tensor> pyarray_to_tensor(py::array_t<float> array) {
  27. py::buffer_info info = array.request();
  28. float* dataPtr = static_cast<float*>(info.ptr);
  29. std::vector<std::size_t> shape = {};
  30. for (auto &it: info.shape) {
  31. shape.push_back(it);
  32. }
  33. auto tensor = std::make_shared<Tensor>(shape);
  34. std::vector<float> result(dataPtr, dataPtr + info.size);
  35. tensor->data = result;
  36. return tensor;
  37. }
  38. std::shared_ptr<Tensor> argmax(const std::shared_ptr<Tensor>& tensor, int axis) {
  39. // you only need to handle the two dimensional tensor, and the axis can be either 0 or 1
  40. // the tensor's shape is (batch_size, features)
  41. // if the axis is 0, it outputs a tensor (1, features)
  42. // if the axis is 1, it outputs a tensor (batch_size, 1)
  43. // compute the output's shape
  44. std::vector<std::size_t> output_shape = tensor->shape;
  45. output_shape.erase(output_shape.begin() + axis);
  46. auto result = std::make_shared<Tensor>(output_shape);
  47. // 这个问题似乎有点难,所以我们决定给你送点分。一个简单的办法是分axis为0还是为1来进行讨论,反正我们已经把问题简化为了,在一个二维的tensor里面,找到每一行或者每一列的最大值,并输出一个一维的tensor。
  48. // 补全这里的代码。
  49. size_t rows=tensor->shape[0];
  50. size_t cols=tensor->shape[1];
  51. if(axis==0)
  52. {
  53. output_shape={1,cols};
  54. }
  55. else if(axis==1)
  56. {
  57. output_shape={rows,1};
  58. }
  59. if(axis==0)//qiu lie xiang liang
  60. {
  61. for(size_t j=0;j<cols;j++)
  62. {
  63. float maxx=0;
  64. size_t maxx_id=0;//ji lu ID
  65. for(size_t i=0;i<rows;i++)
  66. {
  67. if(tensor->data[i*cols+j]>maxx)
  68. {
  69. maxx=tensor->data[i*cols+j];
  70. maxx_id=i;
  71. }
  72. }
  73. result->data[j] = static_cast<float>(maxx_id);
  74. }
  75. }
  76. else
  77. {
  78. for(size_t i=0;i<rows;i++)
  79. {
  80. float maxx=0;
  81. size_t maxx_id=0;
  82. for (size_t j=0;j<cols;j++)
  83. {
  84. if(maxx<tensor->data[i*cols+j])
  85. {
  86. maxx=tensor->data[i*cols+j];
  87. maxx_id=j;
  88. }
  89. }
  90. result->data[i] = static_cast<float>(maxx_id);
  91. }
  92. }
  93. return result;
  94. }
  95. std::shared_ptr<Tensor> mean(const std::shared_ptr<Tensor>& tensor) {
  96. std::vector<std::size_t> shape = {1};
  97. auto result = std::make_shared<Tensor>(shape);
  98. auto sum = 0.0f;
  99. for (auto &it: tensor->data) {
  100. sum += it;
  101. }
  102. sum /= tensor->size;
  103. result->data[0] = sum;
  104. return result;
  105. }
  106. std::shared_ptr<Tensor> exp(const std::shared_ptr<Tensor>& tensor) {
  107. auto result = std::make_shared<Tensor>(tensor->shape);
  108. for (auto i = 0; i < tensor->size; i++) {
  109. result->data[i] = expf(tensor->data[i]);
  110. }
  111. return result;
  112. }
  113. }