
Fixed a bug that overwrote the learning rate when it was passed as a Tensor.

tags/v0.12
Harshitha Parnandi Venkata, 6 years ago
commit f0ff82a94c
1 changed file with 9 additions and 2 deletions:
  src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs (+9, -2)

@@ -35,22 +35,29 @@ namespace Tensorflow.Train
         /// for changing these values across different invocations of optimizer
         /// functions.
         /// </remarks>
+        private bool _useTensor;
         public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
             : base(learning_rate, use_locking, name)
         {
             _lr = learning_rate;
+            _useTensor = false;
         }
 
         public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent")
             : base(learning_rate, use_locking, name)
         {
             _lr_t = learning_rate;
+            _useTensor = true;
         }
 
+
         public override void _prepare()
         {
-            var lr = _call_if_callable(_lr);
-            _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
+            if(!_useTensor)
+            {
+                var lr = _call_if_callable(_lr);
+                _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
+            }
         }
     }
 }
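For context, a minimal sketch of the two constructor paths this commit distinguishes. The driver class, variable names, and values are hypothetical; only the GradientDescentOptimizer overloads and the ops.convert_to_tensor call visible in the diff above are assumed, and exact using directives vary by TensorFlow.NET version.

using Tensorflow;        // TensorFlow.NET; exact usings vary by version
using Tensorflow.Train;

// Hypothetical driver illustrating the fix.
class LearningRateDemo
{
    static void Main()
    {
        // Float overload: _useTensor is set to false, so _prepare()
        // still converts _lr into the _lr_t tensor as before.
        var optFromFloat = new GradientDescentOptimizer(0.01f);

        // Tensor overload: _useTensor is set to true, so _prepare()
        // now skips the conversion and keeps the caller's tensor in
        // _lr_t instead of overwriting it, which was the reported bug.
        Tensor lr = ops.convert_to_tensor(0.01f, name: "learning_rate");
        var optFromTensor = new GradientDescentOptimizer(lr);
    }
}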
