diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index 3a790327..03b0a0e2 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -34,6 +34,9 @@ namespace Tensorflow public Optimizer GradientDescentOptimizer(float learning_rate) => new GradientDescentOptimizer(learning_rate); + public Optimizer GradientDescentOptimizer(Tensor learning_rate) + => new GradientDescentOptimizer(learning_rate); + public Optimizer AdamOptimizer(float learning_rate, string name = "Adam") => new AdamOptimizer(learning_rate, name: name);