|
|
|
@@ -97,6 +97,25 @@ namespace Tensorflow.Train |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public override Operation _finish(Operation[] update_ops, string name_scope) |
|
|
|
{ |
|
|
|
var operations = new List<ITensorOrOperation>(); |
|
|
|
operations.AddRange(update_ops); |
|
|
|
|
|
|
|
with(ops.control_dependencies(update_ops), delegate |
|
|
|
{ |
|
|
|
var (beta1_power, beta2_power) = _get_beta_accumulators(); |
|
|
|
ops.colocate_with(beta1_power); |
|
|
|
var update_beta1 = beta1_power.assign(beta1_power * _beta1_t, use_locking: _use_locking); |
|
|
|
var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking); |
|
|
|
|
|
|
|
operations.Add(update_beta1); |
|
|
|
operations.Add(update_beta1); |
|
|
|
}); |
|
|
|
|
|
|
|
return control_flow_ops.group(operations.ToArray(), name: name_scope); |
|
|
|
} |
|
|
|
|
|
|
|
private (RefVariable, RefVariable) _get_beta_accumulators() |
|
|
|
{ |
|
|
|
ops.init_scope(); |
|
|
|
|