You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

OptimizerApi.cs 2.1 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. using Tensorflow.Keras.ArgsDefinition;
  2. namespace Tensorflow.Keras.Optimizers
  3. {
  /// <summary>
  /// Exposes factory methods that each construct and return a new optimizer instance.
  /// </summary>
  public class OptimizerApi
  {
      /// <summary>
      /// Construct a new Adam optimizer.
      /// Adam optimization is a stochastic gradient descent method that is based on
      /// adaptive estimation of first-order and second-order moments.
      /// </summary>
      /// <param name="learning_rate">Learning rate forwarded to the <c>Adam</c> constructor. Defaults to 0.001.</param>
      /// <param name="beta_1">
      /// Forwarded to <c>Adam</c> as <c>beta_1</c>. Defaults to 0.9.
      /// Presumably the decay rate of the first-moment estimates (as in Keras) — confirm against <c>Adam</c>.
      /// </param>
      /// <param name="beta_2">
      /// Forwarded to <c>Adam</c> as <c>beta_2</c>. Defaults to 0.999.
      /// Presumably the decay rate of the second-moment estimates (as in Keras) — confirm against <c>Adam</c>.
      /// </param>
      /// <param name="epsilon">
      /// Forwarded to <c>Adam</c> as <c>epsilon</c>. Defaults to 1e-7.
      /// Presumably a small constant for numerical stability — confirm against <c>Adam</c>.
      /// </param>
      /// <param name="amsgrad">
      /// Forwarded to <c>Adam</c> as <c>amsgrad</c>. Defaults to <c>false</c>.
      /// Presumably toggles the AMSGrad variant — confirm against <c>Adam</c>.
      /// </param>
      /// <param name="name">Optimizer name forwarded to <c>Adam</c>. Defaults to "Adam".</param>
      /// <returns>A new <c>Adam</c> instance, typed as <c>OptimizerV2</c>.</returns>
      public OptimizerV2 Adam(float learning_rate = 0.001f,
          float beta_1 = 0.9f,
          float beta_2 = 0.999f,
          float epsilon = 1e-7f,
          bool amsgrad = false,
          string name = "Adam")
          => new Adam(learning_rate: learning_rate,
              beta_1: beta_1,
              beta_2: beta_2,
              epsilon: epsilon,
              amsgrad: amsgrad,
              name: name);

      /// <summary>
      /// Construct a new RMSprop optimizer.
      /// </summary>
      /// <remarks>
      /// All arguments are packed into a single <c>RMSpropArgs</c> object, which is
      /// passed to the <c>RMSprop</c> constructor.
      /// </remarks>
      /// <param name="learning_rate">Stored as <c>RMSpropArgs.LearningRate</c>. Defaults to 0.001.</param>
      /// <param name="rho">
      /// Stored as <c>RMSpropArgs.RHO</c>. Defaults to 0.9.
      /// Presumably the discounting factor for the squared-gradient history (as in Keras) — confirm against <c>RMSprop</c>.
      /// </param>
      /// <param name="momentum">Stored as <c>RMSpropArgs.Momentum</c>. Defaults to 0.0.</param>
      /// <param name="epsilon">Stored as <c>RMSpropArgs.Epsilon</c>. Defaults to 1e-7.</param>
      /// <param name="centered">Stored as <c>RMSpropArgs.Centered</c>. Defaults to <c>false</c>.</param>
      /// <param name="name">Stored as <c>RMSpropArgs.Name</c>. Defaults to "RMSprop".</param>
      /// <returns>A new <c>RMSprop</c> instance, typed as <c>OptimizerV2</c>.</returns>
      public OptimizerV2 RMSprop(float learning_rate = 0.001f,
          float rho = 0.9f,
          float momentum = 0.0f,
          float epsilon = 1e-7f,
          bool centered = false,
          string name = "RMSprop")
          => new RMSprop(new RMSpropArgs
          {
              LearningRate = learning_rate,
              RHO = rho,
              Momentum = momentum,
              Epsilon = epsilon,
              Centered = centered,
              Name = name
          });

      /// <summary>
      /// Construct a new SGD optimizer.
      /// </summary>
      /// <param name="learning_rate">
      /// Learning rate forwarded to the <c>SGD</c> constructor. Defaults to 0.01, the
      /// Keras default for SGD. With this default, <c>SGD()</c> can be called without
      /// arguments, just like <c>Adam()</c> and <c>RMSprop()</c>. Existing callers that
      /// pass an explicit value are unaffected.
      /// </param>
      /// <returns>
      /// A new <c>SGD</c> instance. Unlike the other factories, this returns the concrete
      /// <c>SGD</c> type rather than <c>OptimizerV2</c>. That is kept as-is so callers that
      /// rely on the concrete type keep compiling.
      /// </returns>
      public SGD SGD(float learning_rate = 0.01f)
          => new SGD(learning_rate);
  }
  57. }