From f0ff82a94c90861ddeb0c98b76954e5da40f3f2d Mon Sep 17 00:00:00 2001 From: Harshitha Parnandi Venkata Date: Fri, 18 Oct 2019 17:24:45 -0700 Subject: [PATCH] Fixed a bug that overwrites the learning rate when sent as a Tensor. --- .../Train/GradientDescentOptimizer.cs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs index 1a2821bb..d4682066 100644 --- a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs @@ -35,22 +35,29 @@ namespace Tensorflow.Train /// for changing these values across different invocations of optimizer /// functions. /// + private bool _useTensor; public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") : base(learning_rate, use_locking, name) { _lr = learning_rate; + _useTensor = false; } public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent") : base(learning_rate, use_locking, name) { _lr_t = learning_rate; + _useTensor = true; } public override void _prepare() { - var lr = _call_if_callable(_lr); - _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); + if(!_useTensor) + { + var lr = _call_if_callable(_lr); + _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); + } + } } }