diff --git a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs index 1a2821bb..d4682066 100644 --- a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs @@ -35,22 +35,29 @@ namespace Tensorflow.Train /// for changing these values across different invocations of optimizer /// functions. /// + private bool _useTensor; public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") : base(learning_rate, use_locking, name) { _lr = learning_rate; + _useTensor = false; } public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent") : base(learning_rate, use_locking, name) { _lr_t = learning_rate; + _useTensor = true; } public override void _prepare() { - var lr = _call_if_callable(_lr); - _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); + if(!_useTensor) + { + var lr = _call_if_callable(_lr); + _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); + } + } } }