You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nn.h 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  #pragma once
  #include <algorithm>
  #include <cmath>
  #include <iostream>
  #include <memory>
  #include <unordered_map>
  #include <unordered_set>
  #include <vector>
  #include <pybind11/pybind11.h>
  #include <pybind11/numpy.h>
  #include "../tensor/tensor.h"
  #include "../math/arith.h"
  12. namespace py = pybind11;
  13. namespace nn {
  14. class Node {
  15. public:
  16. std::shared_ptr<tensor::Tensor> data;
  17. std::vector<std::shared_ptr<Node>> objects;
  18. std::vector<std::shared_ptr<tensor::Tensor>> gradient;
  19. public:
  20. Node() {}
  21. virtual std::shared_ptr<tensor::Tensor> forward() = 0;
  22. virtual std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) = 0;
  23. std::vector<std::shared_ptr<Node>> get_parents() {
  24. return this->objects;
  25. }
  26. std::vector<float> get_data() {
  27. return this->data->data;
  28. }
  29. std::shared_ptr<tensor::Tensor> get_tensor() {
  30. return this->data;
  31. }
  32. // virtual void update(std::shared_ptr<tensor::Tensor> grad, float lr) = 0;
  33. // virtual void zero_grad() = 0;
  34. virtual ~Node() {}
  35. };
  36. class DataNode: public Node {
  37. public:
  38. DataNode() {}
  39. }; // class DataNode
  40. class Parameter: public DataNode {
  41. public:
  42. // Parameter(const std::vector<std::size_t>& shape) {
  43. // this->data = std::make_shared<tensor::Tensor>(shape, true);
  44. // }
  45. Parameter(py::array_t<float> array) {
  46. py::buffer_info info = array.request();
  47. float* dataPtr = static_cast<float*>(info.ptr);
  48. std::vector<std::size_t> shape = {};
  49. for (auto &it: info.shape) {
  50. shape.push_back(it);
  51. }
  52. auto tensor = std::make_shared<tensor::Tensor>(shape);
  53. std::vector<float> result(dataPtr, dataPtr + info.size);
  54. tensor->data = result;
  55. this->data = tensor;
  56. }
  57. std::shared_ptr<tensor::Tensor> forward() {
  58. return this->data;
  59. };
  60. std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) {
  61. return {gradient};
  62. };
  63. void update(std::shared_ptr<tensor::Tensor> grad, double lr) {
  64. for (auto i = 0; i < this->data->size; i++) {
  65. this->data->data[i] -= lr * grad->data[i];
  66. }
  67. }
  68. }; // class Parameter
  69. class Constant: public DataNode {
  70. public:
  71. Constant(std::shared_ptr<tensor::Tensor> data) {
  72. this->data = data;
  73. }
  74. Constant(py::array_t<float> array) {
  75. this->data = tensor::pyarray_to_tensor(array);
  76. }
  77. std::shared_ptr<tensor::Tensor> forward() {
  78. return this->data;
  79. };
  80. std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) {
  81. return {gradient};
  82. };
  83. // void update(std::shared_ptr<tensor::Tensor> grad, float lr) {}
  84. }; // class Constant
  85. class FunctionNode: public Node {
  86. public:
  87. FunctionNode(std::shared_ptr<Node> a, std::shared_ptr<Node> b) {
  88. this->objects.emplace_back(a);
  89. this->objects.emplace_back(b);
  90. }
  91. FunctionNode(std::shared_ptr<Node> a) {
  92. this->objects.emplace_back(a);
  93. }
  94. std::shared_ptr<tensor::Tensor> forward() override {
  95. return nullptr;
  96. }
  97. }; //class FunctionNode
  98. class Add: public FunctionNode {
  99. public:
  100. Add(std::shared_ptr<Node> a, std::shared_ptr<Node> b) : FunctionNode(a, b) {
  101. this->data = this->forward();
  102. }
  103. std::shared_ptr<tensor::Tensor> forward() override {
  104. auto a = this->objects[0];
  105. auto b = this->objects[1];
  106. auto outNode = std::make_shared<tensor::Tensor>(a->data->shape);
  107. for (auto i = 0; i < a->data->size; i++) {
  108. outNode->data[i] = a->data->data[i] + b->data->data[i];
  109. }
  110. return outNode;
  111. }
  112. std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) override {
  113. // assertion needed
  114. return {gradient, gradient};
  115. }
  116. };
  117. class AddBias: public FunctionNode {
  118. public:
  119. AddBias(std::shared_ptr<Node> a, std::shared_ptr<Node> b) : FunctionNode(a, b) {
  120. this->data = this->forward();
  121. }
  122. std::shared_ptr<tensor::Tensor> forward() override {
  123. // features: a Node with shape (batch_size x num_features)
  124. // bias: a Node with shape (1 x num_features)
  125. auto features = this->objects[0];
  126. auto bias = this->objects[1];
  127. auto outNode = std::make_shared<tensor::Tensor>(features->data->shape);
  128. // for循环写加法总会写吧🤔
  129. // 补全这里的代码
  130. return outNode;
  131. }
  132. std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) override {
  133. // assertion needed
  134. auto g_bias = std::make_shared<tensor::Tensor>(this->objects[1]->data->shape);
  135. // 补全这里的代码
  136. return {gradient, g_bias};
  137. }
  138. std::vector<float> get_data() {
  139. return this->data->data;
  140. }
  141. }; // class AddBias
  142. class Linear: public FunctionNode {
  143. public:
  144. Linear(std::shared_ptr<Node> a, std::shared_ptr<Node> b) : FunctionNode(a, b) {
  145. // 这段代码就一行,参考下别的类是怎么写的呢?
  146. // 在这里补全
  147. }
  148. std::shared_ptr<tensor::Tensor> forward() override {
  149. // features: (batch_size x input_features)
  150. auto features = this->objects[0];
  151. // weights: (input_features x output_features)
  152. auto weights = this->objects[1];
  153. auto m = features->data->shape[0];
  154. auto k = features->data->shape[1];
  155. auto n = weights->data->shape[1];
  156. // std::cout << m << " " << n << " " << k << std::endl;
  157. // output: (batch_size x output_features)
  158. auto shape = {m, n};
  159. auto outNode = std::make_shared<tensor::Tensor>(shape);
  160. // 实际上你需要补全的是arith::mm函数,快去找找它在哪里
  161. // 其余部分不需要动
  162. arith::mm(features->data->data, weights->data->data, outNode->data, m, k, n);
  163. return outNode;
  164. }
  165. std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) override {
  166. auto features = this->objects[0];
  167. auto weights = this->objects[1];
  168. // gradient.shape[0] == features.shape[0]
  169. // gradient.shape[1] == weights.shape[1]
  170. auto grad_features_shape = {gradient->shape[0], weights->data->shape[0]};
  171. auto grad_features = std::make_shared<tensor::Tensor>(grad_features_shape);
  172. auto grad_weights_shape = {features->data->shape[1], gradient->shape[1]};
  173. auto grad_weights = std::make_shared<tensor::Tensor>(grad_weights_shape);
  174. // 这里要调用两次arith:mm,是分别把哪两个矩阵相乘呢?
  175. return {grad_features, grad_weights};
  176. }
  177. }; //class Linear
  178. class ReLU: public FunctionNode {
  179. public:
  180. ReLU(std::shared_ptr<Node> a) : FunctionNode(a) {
  181. // 补全这里
  182. }
  183. std::shared_ptr<tensor::Tensor> forward() override {
  184. // x: a Node with shape (batch_size x num_features)
  185. auto outNode = std::make_shared<tensor::Tensor>(this->objects[0]->data->shape);
  186. // 补全这里,调用arith::vector_scalar_max
  187. return outNode;
  188. }
  189. std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) override {
  190. auto grads = std::make_shared<tensor::Tensor>(this->objects[0]->data->shape);
  191. // 补全这里,一个for循环
  192. return {grads};
  193. }
  194. }; // class ReLU
  195. class Loss: public FunctionNode {
  196. public:
  197. bool used = false;
  198. public:
  199. Loss(std::shared_ptr<Node> a, std::shared_ptr<Node> b) : FunctionNode(a, b) {}
  200. };
  201. class SquareLoss: public Loss {
  202. public:
  203. SquareLoss(std::shared_ptr<Node> a, std::shared_ptr<Node> b): Loss(a, b) {
  204. // 补全这里的代码
  205. }
  206. std::shared_ptr<tensor::Tensor> forward() {
  207. // a: a Node with shape (batch_size x dim)
  208. // b: a Node with shape (batch_size x dim)
  209. // 这个简单,就是要注意返回的res需要是一个tensor就行
  210. // 修改下面的代码
  211. std::vector<size_t> res_shape = {1};
  212. auto res = std::make_shared<tensor::Tensor>(res_shape);
  213. return res;
  214. }
  215. std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) override {
  216. float g = gradient->data[0];
  217. auto a = this->objects[0];
  218. auto b = this->objects[1];
  219. auto grad_a = std::make_shared<tensor::Tensor>(a->data->shape);
  220. auto grad_b = std::make_shared<tensor::Tensor>(b->data->shape);
  221. // 补全下面的代码
  222. return {grad_a, grad_b};
  223. }
  224. }; // class SquareLoss
  225. std::shared_ptr<tensor::Tensor> log_softmax(std::shared_ptr<tensor::Tensor> logits);
  226. class SoftmaxLoss: public Loss {
  227. public:
  228. SoftmaxLoss(std::shared_ptr<Node> logits, std::shared_ptr<Node> labels): Loss(logits, labels) {
  229. this->data = this->forward();
  230. }
  231. std::shared_ptr<tensor::Tensor> forward() {
  232. // 我们已经帮你写好log_softmax
  233. auto log_probs = log_softmax(this->objects[0]->data);
  234. // 补全下面的代码,计算softmax loss
  235. std::vector<size_t> res_shape = {1};
  236. auto res = std::make_shared<tensor::Tensor>(res_shape);
  237. return res;
  238. }
  239. std::vector<std::shared_ptr<tensor::Tensor>> backward(std::shared_ptr<tensor::Tensor> gradient) override {
  240. auto log_probs = log_softmax(this->objects[0]->data);
  241. auto labels = this->objects[1]->data;
  242. auto batch_size = log_probs->shape[0];
  243. auto num_classes = log_probs->shape[1];
  244. auto grad_logits = std::make_shared<tensor::Tensor>(log_probs->shape);
  245. auto grad_labels = std::make_shared<tensor::Tensor>(labels->shape);
  246. // 补全下面的代码
  247. return {grad_logits, grad_labels};
  248. }
  249. }; // class SoftmaxLoss
  250. std::vector<std::shared_ptr<tensor::Tensor>> gradients(std::shared_ptr<Loss> loss, std::vector<std::shared_ptr<Node>> parameters);
  251. }

计算机大作业