Browse Source

override Adam _finish.

tags/v0.9
Oceania2018 6 years ago
parent
commit
6af53e6310
1 changed files with 19 additions and 0 deletions
  1. +19
    -0
      src/TensorFlowNET.Core/Train/AdamOptimizer.cs

+ 19
- 0
src/TensorFlowNET.Core/Train/AdamOptimizer.cs View File

@@ -97,6 +97,25 @@ namespace Tensorflow.Train
}
}

public override Operation _finish(Operation[] update_ops, string name_scope)
{
var operations = new List<ITensorOrOperation>();
operations.AddRange(update_ops);

with(ops.control_dependencies(update_ops), delegate
{
var (beta1_power, beta2_power) = _get_beta_accumulators();
ops.colocate_with(beta1_power);
var update_beta1 = beta1_power.assign(beta1_power * _beta1_t, use_locking: _use_locking);
var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking);

operations.Add(update_beta1);
operations.Add(update_beta1);
});

return control_flow_ops.group(operations.ToArray(), name: name_scope);
}

private (RefVariable, RefVariable) _get_beta_accumulators()
{
ops.init_scope();


Loading…
Cancel
Save