|
|
|
@@ -14,6 +14,11 @@ namespace Tensorflow |
|
|
|
/// </summary> |
|
|
|
public abstract class Optimizer |
|
|
|
{ |
|
|
|
// Values for gate_gradients. |
|
|
|
public static int GATE_NONE = 0; |
|
|
|
public static int GATE_OP = 1; |
|
|
|
public static int GATE_GRAPH = 2; |
|
|
|
|
|
|
|
public string Name { get; set; } |
|
|
|
public double LearningRate { get; set; } |
|
|
|
public Tensor LearningRateTensor { get; set; } |
|
|
|
@@ -87,11 +92,15 @@ namespace Tensorflow |
|
|
|
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); |
|
|
|
var var_refs = processors.Select(x => x.target()).ToArray(); |
|
|
|
|
|
|
|
gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss, |
|
|
|
var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss, |
|
|
|
gate_gradients: (gate_gradients == GateGradientType.GATE_OP), |
|
|
|
aggregation_method: aggregation_method, |
|
|
|
colocate_gradients_with_ops: colocate_gradients_with_ops); |
|
|
|
|
|
|
|
//if ((int)gate_gradients == Optimizer.GATE_GRAPH) |
|
|
|
//grads = control_flow_ops.tuple(grads); |
|
|
|
|
|
|
|
|
|
|
|
return null; |
|
|
|
} |
|
|
|
} |
|
|
|
|