using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using distribute_lib = Tensorflow.Distribute;
namespace Tensorflow
{
/// <summary>
/// Base class for optimizers.
/// This class defines the API to add Ops to train a model. You never use this
/// class directly, but instead instantiate one of its subclasses such as
/// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
/// </summary>
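/// <example>
/// Illustrative usage sketch (assumes a `GradientDescentOptimizer` subclass and a scalar `loss` tensor):
/// <code>
/// var opt = new GradientDescentOptimizer(learning_rate: 0.1f);
/// var train_op = opt.minimize(loss);
/// </code>
/// </example>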
public abstract class Optimizer : Python
{
// Values for gate_gradients.
public const int GATE_NONE = 0;
public const int GATE_OP = 1;
public const int GATE_GRAPH = 2;
public string Name { get; set; }
public float LearningRate { get; set; }
public Tensor LearningRateTensor { get; set; }
public bool _use_locking;
public Dictionary<string, object> _slots;
public Dictionary<string, object> _non_slot_dict;
public Dictionary<string, object> _deferred_slot_restorations;
public Optimizer(float learning_rate, bool use_locking, string name = null)
{
if (String.IsNullOrEmpty(name))
throw new ValueError("Must specify the optimizer name");
Name = name;
_use_locking = use_locking;
// Dictionary of slots.
_slots = new Dictionary<string, object>();
_non_slot_dict = new Dictionary<string, object>();
_deferred_slot_restorations = new Dictionary<string, object>();
}
/// <summary>
/// Add operations to minimize `loss` by updating `var_list`.
/// </summary>
/// <param name="loss">A `Tensor` containing the value to minimize.</param>
/// <param name="global_step">Optional `Tensor` to increment by one after the variables have been updated.</param>
/// <param name="gate_gradients">How to gate the computation of gradients.</param>
/// <param name="colocate_gradients_with_ops">If true, try colocating gradients with the corresponding op.</param>
/// <returns>
/// An Operation that updates the variables in `var_list`. If `global_step`
/// was not `None`, that operation also increments `global_step`.
/// </returns>
public Operation minimize(Tensor loss,
Tensor global_step = null,
GateGradientType gate_gradients = GateGradientType.GATE_OP,
bool colocate_gradients_with_ops = false)
{
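// Compute (gradient, variable) pairs; with no explicit var_list, gradients are
// taken with respect to all trainable variables that `loss` depends on.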
var grads_and_vars = compute_gradients(loss,
gate_gradients: gate_gradients,
colocate_gradients_with_ops: colocate_gradients_with_ops);
var vars_with_grad = grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray();
if (vars_with_grad.Length == 0)
throw new ValueError($"No gradients provided for any variable, check your graph for ops" +
$" that do not support gradients, between variables {string.Join(",", grads_and_vars.Select(x => x.Item2.name))} and loss {loss}.");
return apply_gradients(grads_and_vars, global_step: global_step);
}
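/// <summary>
/// Apply gradients to variables. This is the second half of `minimize()`;
/// it returns an `Operation` that applies the supplied gradients.
/// </summary>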
public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, Tensor global_step = null, string name = null)
{
// No DistributionStrategy case.
var converted_grads_and_vars = new List<Tuple<Tensor, RefVariable, _OptimizableVariable>>();
foreach (var (g, v) in grads_and_vars)
{
if(g != null)
{
// Convert the grad to Tensor or IndexedSlices if necessary.
var gR = ops.convert_to_tensor_or_indexed_slices(g);
var p = _get_processor(v);
converted_grads_and_vars.Add(new Tuple<Tensor, RefVariable, _OptimizableVariable>(gR, v, p));
}
}
var var_list = converted_grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray();
if (var_list.Length == 0)
throw new ValueError($"No gradients provided for any variable");
ops.init_scope();
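// Let concrete optimizers create the slot variables (e.g. momentum accumulators)
// they need for each variable that will be updated.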
_create_slots(var_list);
var update_ops = new List<Operation>();
return with(new ops.name_scope(name, Name), scope =>
{
name = scope;
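// _prepare gives subclasses a chance to convert hyperparameters such as the
// learning rate into tensors before the update ops are built.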
_prepare();
foreach(var (grad, var, processor) in converted_grads_and_vars)
{
if (grad == null)
continue;
var scope_name = var.op.name;
with(new ops.name_scope("update_" + scope_name), scope2 =>
{
update_ops.Add(processor.update_op(this, grad));
});
}
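// _finish groups the per-variable update ops into a single training operation.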
Operation apply_updates = null;
if (global_step == null)
{
apply_updates = _finish(update_ops.ToArray(), name);
}
else
{
// Incrementing `global_step` after the updates have been applied is not supported yet.
throw new NotImplementedException("apply_gradients with global_step");
}
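// In graph mode, record the training op in the TRAIN_OP collection so it can be looked up later.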
if (!tf.context.executing_eagerly())
{
var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>