You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

autodiff.cc 904 B

1234567891011121314151617181920212223242526272829303132
  1. #include "autodiff.h"
  2. namespace autodiff {
  3. std::vector<std::shared_ptr<ScalarFunction>> topoSort(const std::vector<std::shared_ptr<ScalarFunction>>& scalars) {
  4. std::vector<std::shared_ptr<ScalarFunction>> sorted;
  5. std::vector<std::shared_ptr<ScalarFunction>> frontier;
  6. std::unordered_map<std::shared_ptr<ScalarFunction>, int> degree;
  7. for (auto it: scalars) {
  8. if (it->degree == 0) {
  9. frontier.push_back(it);
  10. }
  11. else {
  12. degree.insert({it, it->degree});
  13. }
  14. }
  15. while (!frontier.empty()) {
  16. auto back = frontier.back();
  17. sorted.push_back(back);
  18. for (auto &it: degree) {
  19. if (it.second > 0 && it.first == back) {
  20. it.second--;
  21. if (it.second == 0) {
  22. frontier.push_back(it.first);
  23. }
  24. }
  25. }
  26. }
  27. return sorted;
  28. }
  29. }