Browse Source

Fixed a bug that overrites the learning rate when sent as a Tensor.

tags/v0.12
Harshitha Parnandi Venkata 6 years ago
parent
commit
f0ff82a94c
1 changed files with 9 additions and 2 deletions
  1. +9
    -2
      src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs

+ 9
- 2
src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs View File

@@ -35,22 +35,29 @@ namespace Tensorflow.Train
/// for changing these values across different invocations of optimizer
/// functions.
/// </remarks>
private bool _useTensor;
public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
: base(learning_rate, use_locking, name)
{
_lr = learning_rate;
_useTensor = false;
}
public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent")
: base(learning_rate, use_locking, name)
{
_lr_t = learning_rate;
_useTensor = true;
}

public override void _prepare()
{
var lr = _call_if_callable(_lr);
_lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
if(!_useTensor)
{
var lr = _call_if_callable(_lr);
_lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
}
}
}
}

Loading…
Cancel
Save