|
|
|
@@ -35,22 +35,29 @@ namespace Tensorflow.Train |
|
|
|
/// for changing these values across different invocations of optimizer |
|
|
|
/// functions. |
|
|
|
/// </remarks> |
|
|
|
private bool _useTensor; |
|
|
|
public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") |
|
|
|
: base(learning_rate, use_locking, name) |
|
|
|
{ |
|
|
|
_lr = learning_rate; |
|
|
|
_useTensor = false; |
|
|
|
}
|
|
|
|
|
|
|
|
public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent")
|
|
|
|
: base(learning_rate, use_locking, name) |
|
|
|
{ |
|
|
|
_lr_t = learning_rate; |
|
|
|
_useTensor = true; |
|
|
|
} |
|
|
|
|
|
|
|
public override void _prepare() |
|
|
|
{ |
|
|
|
var lr = _call_if_callable(_lr); |
|
|
|
_lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); |
|
|
|
if(!_useTensor)
|
|
|
|
{
|
|
|
|
var lr = _call_if_callable(_lr);
|
|
|
|
_lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |