diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index a943308b..39915425 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -31,6 +31,9 @@ namespace Tensorflow public Optimizer AdamOptimizer(float learning_rate, string name = "Adam") => new AdamOptimizer(learning_rate, name: name); + public Optimizer AdamOptimizer(Tensor learning_rate, string name = "Adam") + => new AdamOptimizer(learning_rate, name: name); + public ExponentialMovingAverage ExponentialMovingAverage(float decay) => new ExponentialMovingAverage(decay); diff --git a/src/TensorFlowNET.Core/Train/moving_averages.cs b/src/TensorFlowNET.Core/Train/moving_averages.cs index d77367f3..de4e7f2e 100644 --- a/src/TensorFlowNET.Core/Train/moving_averages.cs +++ b/src/TensorFlowNET.Core/Train/moving_averages.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Train if (decay.dtype != variable.dtype.as_base_dtype()) decay = math_ops.cast(decay, variable.dtype.as_base_dtype()); - return decay; + return state_ops.assign_sub(variable, (variable - value) * decay, name: scope); }); } }