You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gradient_descent.cpp 623 B

123456789101112131415161718192021222324252627
  1. #include <iostream>
  2. #include "../../root/include/vectmath.h"
  3. #include "../../root/include/node.h"
  4. Node function(std::vector<Node>& x){
  5. return pow(x[0], 2) + pow(x[1], 2); // x^2 + y^2
  6. }
  7. int main(int argc, char const *argv[]) {
  8. Graph* graph = Graph::getInstance();
  9. std::vector<Node> x = {50, 50};
  10. Node f;
  11. int epochs = 30;
  12. float learning_rate = 0.1;
  13. for(size_t i=0 ; i<epochs ; i++){
  14. f = function(x);
  15. x -= learning_rate*f.gradient(x);
  16. graph->new_recording();
  17. }
  18. std::cout << "f = " << f << std::endl;
  19. std::cout << "x = " << x << std::endl;
  20. return 0;
  21. }

Edge: 一个开源的科学计算引擎 (an open-source scientific computing engine)

Contributors (1)