| @@ -97,6 +97,25 @@ namespace Tensorflow.Train | |||||
| } | } | ||||
| } | } | ||||
| public override Operation _finish(Operation[] update_ops, string name_scope) | |||||
| { | |||||
| var operations = new List<ITensorOrOperation>(); | |||||
| operations.AddRange(update_ops); | |||||
| with(ops.control_dependencies(update_ops), delegate | |||||
| { | |||||
| var (beta1_power, beta2_power) = _get_beta_accumulators(); | |||||
| ops.colocate_with(beta1_power); | |||||
| var update_beta1 = beta1_power.assign(beta1_power * _beta1_t, use_locking: _use_locking); | |||||
| var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking); | |||||
| operations.Add(update_beta1); | |||||
| operations.Add(update_beta1); | |||||
| }); | |||||
| return control_flow_ops.group(operations.ToArray(), name: name_scope); | |||||
| } | |||||
| private (RefVariable, RefVariable) _get_beta_accumulators() | private (RefVariable, RefVariable) _get_beta_accumulators() | ||||
| { | { | ||||
| ops.init_scope(); | ops.init_scope(); | ||||