@@ -35,22 +35,29 @@ namespace Tensorflow.Train
         /// for changing these values across different invocations of optimizer
         /// functions.
         /// </remarks>
+        private bool _useTensor;
         public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
             : base(learning_rate, use_locking, name)
         {
             _lr = learning_rate;
+            _useTensor = false;
         }

         public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent")
             : base(learning_rate, use_locking, name)
         {
             _lr_t = learning_rate;
+            _useTensor = true;
         }

         public override void _prepare()
         {
-            var lr = _call_if_callable(_lr);
-            _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
+            if(!_useTensor)
+            {
+                var lr = _call_if_callable(_lr);
+                _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
+            }
         }
     }
 }
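
For reference, a minimal usage sketch of the two constructor paths the diff distinguishes. This is only an illustration, assuming the usual TensorFlow.NET binding import (using static Tensorflow.Binding) and its tf.constant helper; the rate values and variable names are not part of the change above.

// Minimal sketch of the two overloads; values and names here are illustrative only.
using Tensorflow;
using Tensorflow.Train;
using static Tensorflow.Binding;

// Float overload: _useTensor stays false, so _prepare() converts _lr into _lr_t.
var floatOpt = new GradientDescentOptimizer(0.01f);

// Tensor overload: _useTensor is true, so _prepare() leaves the caller-supplied _lr_t untouched.
Tensor lr = tf.constant(0.01f);
var tensorOpt = new GradientDescentOptimizer(lr);

The flag matters because, without it, _prepare() would overwrite a caller-supplied learning-rate tensor with a tensor converted from the unset float field.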