From 50d8d20b8e2846e687454b8361ce88b585321d62 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 7 Sep 2019 13:01:00 -0500 Subject: [PATCH] assign_moving_average --- src/TensorFlowNET.Core/APIs/tf.train.cs | 3 +++ src/TensorFlowNET.Core/Train/moving_averages.cs | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index a943308b..39915425 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -31,6 +31,9 @@ namespace Tensorflow public Optimizer AdamOptimizer(float learning_rate, string name = "Adam") => new AdamOptimizer(learning_rate, name: name); + public Optimizer AdamOptimizer(Tensor learning_rate, string name = "Adam") + => new AdamOptimizer(learning_rate, name: name); + public ExponentialMovingAverage ExponentialMovingAverage(float decay) => new ExponentialMovingAverage(decay); diff --git a/src/TensorFlowNET.Core/Train/moving_averages.cs b/src/TensorFlowNET.Core/Train/moving_averages.cs index d77367f3..de4e7f2e 100644 --- a/src/TensorFlowNET.Core/Train/moving_averages.cs +++ b/src/TensorFlowNET.Core/Train/moving_averages.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Train if (decay.dtype != variable.dtype.as_base_dtype()) decay = math_ops.cast(decay, variable.dtype.as_base_dtype()); - return decay; + return state_ops.assign_sub(variable, (variable - value) * decay, name: scope); }); } }