diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs index 1086828a..15557679 100644 --- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs @@ -97,6 +97,25 @@ namespace Tensorflow.Train } } + public override Operation _finish(Operation[] update_ops, string name_scope) + { + var operations = new List(); + operations.AddRange(update_ops); + + with(ops.control_dependencies(update_ops), delegate + { + var (beta1_power, beta2_power) = _get_beta_accumulators(); + ops.colocate_with(beta1_power); + var update_beta1 = beta1_power.assign(beta1_power * _beta1_t, use_locking: _use_locking); + var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking); + + operations.Add(update_beta1); + operations.Add(update_beta1); + }); + + return control_flow_ops.group(operations.ToArray(), name: name_scope); + } + private (RefVariable, RefVariable) _get_beta_accumulators() { ops.init_scope();