From 6af53e6310e44647a7efb11d74d0e6ab755549d8 Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Fri, 21 Jun 2019 00:15:18 -0500
Subject: [PATCH] override Adam _finish.

---
 src/TensorFlowNET.Core/Train/AdamOptimizer.cs | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
index 1086828a..15557679 100644
--- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
+++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
@@ -97,6 +97,25 @@ namespace Tensorflow.Train
             }
         }
 
+        public override Operation _finish(Operation[] update_ops, string name_scope)
+        {
+            var operations = new List<Operation>();
+            operations.AddRange(update_ops);
+
+            with(ops.control_dependencies(update_ops), delegate
+            {
+                var (beta1_power, beta2_power) = _get_beta_accumulators();
+                ops.colocate_with(beta1_power);
+                var update_beta1 = beta1_power.assign(beta1_power * _beta1_t, use_locking: _use_locking);
+                var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking);
+
+                operations.Add(update_beta1);
+                operations.Add(update_beta2);
+            });
+
+            return control_flow_ops.group(operations.ToArray(), name: name_scope);
+        }
+
         private (RefVariable, RefVariable) _get_beta_accumulators()
         {
             ops.init_scope();
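
A note on the change above: _finish is the hook the base optimizer calls after the per-variable update ops have been built, and for Adam it must also decay the beta1_power / beta2_power accumulators (which track beta1^t and beta2^t for bias correction) under a control dependency on those updates, so the decay runs only once all variables have been updated for the step. The sketch below is a minimal, framework-free C# illustration of one Adam step with the same ordering, assuming the standard Adam update rule; the AdamSketch class and its members are hypothetical names, not part of the TensorFlow.NET API.

    // Minimal Adam step for a flat float[] parameter vector.
    // Mirrors the ordering enforced by _finish: update the parameters first
    // using the current beta powers, then decay the power accumulators.
    using System;

    class AdamSketch
    {
        private readonly float _lr, _beta1, _beta2, _epsilon;
        private readonly float[] _m, _v;          // first/second moment estimates
        private float _beta1Power, _beta2Power;   // hold beta1^t and beta2^t

        public AdamSketch(int size, float lr = 0.001f, float beta1 = 0.9f,
                          float beta2 = 0.999f, float epsilon = 1e-8f)
        {
            _lr = lr; _beta1 = beta1; _beta2 = beta2; _epsilon = epsilon;
            _m = new float[size];
            _v = new float[size];
            _beta1Power = beta1;   // beta1^1 for the first step
            _beta2Power = beta2;   // beta2^1 for the first step
        }

        public void Step(float[] param, float[] grad)
        {
            // Fold the bias-correction terms into the step size,
            // as in the standard Adam formulation.
            var lrT = _lr * (float)Math.Sqrt(1f - _beta2Power) / (1f - _beta1Power);

            for (int i = 0; i < param.Length; i++)
            {
                _m[i] = _beta1 * _m[i] + (1f - _beta1) * grad[i];
                _v[i] = _beta2 * _v[i] + (1f - _beta2) * grad[i] * grad[i];
                param[i] -= lrT * _m[i] / ((float)Math.Sqrt(_v[i]) + _epsilon);
            }

            // The _finish counterpart: decay the power accumulators once,
            // after every parameter update for this step has been applied.
            _beta1Power *= _beta1;
            _beta2Power *= _beta2;
        }
    }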