You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nn.cc 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  #include "nn.h"

  #include <cmath>
  #include <cstddef>
  #include <functional>
  #include <memory>
  #include <unordered_map>
  #include <unordered_set>
  #include <vector>
  2. namespace nn {
  3. std::shared_ptr<tensor::Tensor> log_softmax(std::shared_ptr<tensor::Tensor> logits) {
  4. auto batch_size = logits->shape[0];
  5. auto num_classes = logits->shape[1];
  6. auto log_probs_shape = {batch_size, num_classes};
  7. auto log_probs = std::make_shared<tensor::Tensor>(log_probs_shape);
  8. for (auto i = 0; i < batch_size; i++) {
  9. auto max_logit = logits->data[i * num_classes];
  10. for (auto j = 1; j < num_classes; j++) {
  11. max_logit = max_logit > logits->data[i * num_classes + j] ? max_logit : logits->data[i * num_classes + j];
  12. }
  13. auto sum_exp = 0.0;
  14. for (auto j = 0; j < num_classes; j++) {
  15. log_probs->data[i * num_classes + j] = logits->data[i * num_classes + j] - max_logit;
  16. sum_exp += exp(log_probs->data[i * num_classes + j]);
  17. }
  18. // calculate log(softmax)
  19. auto log_sum_exp = log(sum_exp);
  20. for (auto j = 0; j < num_classes; j++) {
  21. log_probs->data[i * num_classes + j] -= log_sum_exp;
  22. }
  23. }
  24. return log_probs;
  25. }
  26. std::vector<std::shared_ptr<tensor::Tensor>> gradients(std::shared_ptr<Loss> loss, std::vector<std::shared_ptr<Node>> parameters) {
  27. loss->used = true;
  28. std::unordered_set<std::shared_ptr<Node>> nodes;
  29. std::vector<std::shared_ptr<Node>> tape;
  30. // 递归遍历图并构建计算图
  31. std::function<void(std::shared_ptr<Node>)> visit = [&](std::shared_ptr<Node> node) {
  32. if (nodes.find(node) == nodes.end()) {
  33. for (const auto& parent : node->get_parents()) {
  34. visit(parent);
  35. }
  36. nodes.insert(node);
  37. tape.push_back(node);
  38. }
  39. };
  40. visit(loss);
  41. for (const auto& param : parameters) {
  42. nodes.insert(param);
  43. }
  44. std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<tensor::Tensor>> grads;
  45. for (const auto& node : nodes) {
  46. grads[node] = std::make_shared<tensor::Tensor>(node->data->shape);
  47. }
  48. grads[loss] = std::make_shared<tensor::Tensor>(loss->data->shape);
  49. grads[loss]->data[0] = 1.0;
  50. for (auto it = tape.rbegin(); it != tape.rend(); it++) {
  51. // std::cout << "tape it: " << std::endl;
  52. auto node = *it;
  53. // if (node->data->shape[0] == 1) {
  54. // std::cout << "coming to squareloss" << std::endl;
  55. // }
  56. auto parent_grads = node->backward(grads[node]);
  57. auto parents = node->get_parents();
  58. for (size_t i = 0; i < parents.size(); i++) {
  59. // std::cout << "this grad shape: " << grads[parents[i]]->data.size() << std::endl;
  60. for (auto ind = 0; ind < parents[i]->data->size; ind++) {
  61. grads[parents[i]]->data[ind] += parent_grads[i]->data[ind];
  62. }
  63. }
  64. }
  65. std::vector<std::shared_ptr<tensor::Tensor>> result;
  66. for (const auto& param : parameters) {
  67. result.emplace_back(grads[param]);
  68. }
  69. // std::cout << "len(result): " << result.size() << std::endl;
  70. return result;
  71. }
  72. }

计算机大作业