using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
    public class gradients_impl
    {
        public static void gradients(Tensor[] ys,
            Tensor[] xs,
            Tensor[] grad_ys = null,
            string name = "gradients",
            bool colocate_gradients_with_ops = false,
            bool gate_gradients = false)
        {
            _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients);
        }
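
        // Hypothetical usage sketch (comment only): given tensors built elsewhere,
        //   gradients_impl.gradients(new[] { loss }, new[] { w });
        // wires the backprop subgraph for dloss/dw into the default graph.
        // "loss" and "w" are illustrative names, not part of this file.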

        public static Tensor[] _GradientsHelper(Tensor[] ys,
            Tensor[] xs,
            Tensor[] grad_ys = null,
            string name = "gradients",
            bool colocate_gradients_with_ops = false,
            bool gate_gradients = false,
            int aggregation_method = 0,
            Tensor[] stop_gradients = null,
            Graph src_graph = null)
        {
            if (src_graph == null)
                src_graph = ops.get_default_graph();

            // If src_graph is a _FuncGraph (i.e. a function body), gather it and all
            // ancestor graphs. This is necessary for correctly handling captured values.
            var curr_graph = src_graph;

            if (grad_ys == null)
                grad_ys = new Tensor[ys.Length];
            if (stop_gradients == null)
                stop_gradients = new Tensor[0];
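
            // grad_ys gets one slot per y (nulls mean "use the default gradient";
            // _DefaultGradYs later fills each null with a ones tensor of y's shape),
            // and stop_gradients defaults to empty so the AddRange calls below are safe.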

            var all = new List<Tensor>();
            all.AddRange(ys);
            all.AddRange(xs);
            all.AddRange(stop_gradients);
            all.AddRange(grad_ys);

            Python.with<ops.name_scope>(new ops.name_scope(name, "gradients", values: all), scope =>
            {
                // Get a uid for this call to gradients that can be used to help
                // cluster ops for compilation.
                var gradient_uid = ops.get_default_graph().unique_name("uid");
                grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);

                /**
                 * The approach we take here is as follows: Create a list of all ops in the
                 * subgraph between the ys and xs. Visit these ops in reverse order of ids
                 * to ensure that when we visit an op the gradients w.r.t. its outputs have
                 * been collected. Then aggregate these gradients if needed, call the op's
                 * gradient function, and add the generated gradients to the gradients for
                 * its input.
                 **/

                // Initialize the pending count for ops in the connected subgraph from ys
                // to the xs.
                var to_ops = ys.Select(x => x.op).ToList();
                var from_ops = xs.Select(x => x.op).ToList();
                var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
                (var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
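
                // reachable_to_ops: the subset of to_ops actually reachable from from_ops;
                // pending_count: op name -> number of gradients still outstanding before the
                // op's own gradient function can run; loop_state: while-loop bookkeeping
                // (null when the graph contains no while loops).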

                // Iterate over the collected ops.
                /**
                 * grads: op => list of gradients received on each output endpoint of the
                 * op. The gradients for each endpoint are initially collected as a list.
                 * When it is time to call the op's gradient function, for each endpoint we
                 * aggregate the list of received gradients into an Add() Operation if there
                 * is more than one.
                 **/
                var grads = new Dictionary<string, Tensor[][]>();
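
                // Keyed by op.Name. For an op with two outputs this might look like
                //   grads["MatMul"] = [ [g0a, g0b], [g1] ]
                // (two gradients received for output 0, one for output 1); illustrative only.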
                for (int i = 0; i < ys.Length; i++)
                {
                    (var y, var grad_y) = Python.zip(ys, grad_ys, i);
                    _SetGrad(grads, y, grad_y);
                }

                // Initialize queue with to_ops.
                var queue = new Queue<Operation>();
                // Add the ops in 'to_ops' into the queue.
                var to_ops_set = new List<Operation>();
                foreach (var op in to_ops)
                {
                    // 'ready' handles the case where one output gradient relies on
                    // another output's gradient.
                    bool ready = !pending_count.ContainsKey(op.Name) || pending_count[op.Name] == 0;
                    if (ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op))
                    {
                        to_ops_set.Add(op);
                        queue.Enqueue(op);
                    }
                }

                var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
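
                // stop_ops: where backprop terminates -- ops listed in stop_gradients
                // plus ops with no remaining path back to the xs.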
                while (queue.Count > 0)
                {
                    // generate gradient subgraph for op.
                    var op = queue.Dequeue();
                    _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
                    //if (loop_state != null)
                    //    loop_state.EnterGradWhileContext(op, before: true);
                    var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);

                    var is_partitioned_call = _IsPartitionedCall(op);
                    var is_func_call = false;
                    var has_out_grads = true;
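                    // is_func_call and has_out_grads are placeholders for now; the Python
                    // original derives them from the op and from out_grads respectively.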
                    if (has_out_grads && !stop_ops.Contains(op))
                    {
                        if (is_func_call)
                        {

                        }
                        else
                        {
                            // A grad_fn must be defined, either as a function or as None
                            // for ops that do not have gradients.

                        }
                    }
                }
            });
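
            // TODO: once the loop above populates 'grads', return the gradient collected
            // for each x (the Python original returns [_GetGrad(grads, x) for x in xs]);
            // null is a work-in-progress stub.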
            return null;
        }

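        /// <summary>
        /// True for ops that invoke a function (PartitionedCall / StatefulPartitionedCall).
        /// </summary>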
        private static bool _IsPartitionedCall(Operation op)
        {
            return op.OpType == "PartitionedCall" || op.OpType == "StatefulPartitionedCall";
        }

        private static Tensor[] _AggregatedGrads(Dictionary<string, Tensor[][]> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
        {
            var out_grads = _GetGrads(grads, op);
            var aggregated = new Tensor[out_grads.Length];
            for (int i = 0; i < out_grads.Length; i++)
            {
                var out_grad = out_grads[i];
                if (loop_state != null)
                {

                }

                // Grads have to be Tensors or IndexedSlices.

                // Aggregate multiple gradients, and convert [] to None.
                if (out_grad != null && out_grad.Length > 0)
                {
                    if (out_grad.Length < 2)
                        aggregated[i] = out_grad[0];
                }
            }

            return aggregated;
        }
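
        // Aggregation note: an output that received a single gradient is used as-is;
        // summing multiple received gradients (AddN, selectable via aggregation_method
        // in the Python original) is still to be ported.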

        /// <summary>
        /// The set of ops that terminate the gradient computation.
        /// </summary>
        /// <param name="from_ops">list of Operations.</param>
        /// <param name="stop_gradient_ops">list of Operations never to backprop through.</param>
        /// <param name="pending_count">mapping from operation to number of backprop inputs.</param>
        /// <param name="xs">list of Tensors.</param>
        /// <returns>The set of operations.</returns>
        private static Operation[] _StopOps(List<Operation> from_ops, List<Operation> stop_gradient_ops, Dictionary<string, int> pending_count, Tensor[] xs)
        {
            var stop_ops = new List<Operation>();

            foreach (var op in from_ops)
            {
                bool is_stop_op = true;
                foreach (Tensor inp in _NonEagerInputs(op, xs))
                {
                    if (pending_count.ContainsKey(inp.op.Name) && pending_count[inp.op.Name] > 0)
                    {
                        is_stop_op = false;
                        break;
                    }
                }
                if (is_stop_op)
                    stop_ops.Add(op);
            }
            stop_ops.AddRange(stop_gradient_ops.Where(x => !stop_ops.Contains(x)));
            return stop_ops.ToArray();
        }

        private static Tensor[][] _GetGrads(Dictionary<string, Tensor[][]> grads, Operation op)
        {
            if (grads.ContainsKey(op.Name))
                return grads[op.Name];
            else
                return op.outputs.Select(x => new Tensor[0]).ToArray();
        }
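
        // For ops not seen yet this returns one empty gradient list per output,
        // mirroring the Python original's [[]] * len(op.outputs).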

        /// <summary>
        /// Sets gradient "grad" in "grads" for tensor "t".
        /// </summary>
        /// <param name="grads"></param>
        /// <param name="t"></param>
        /// <param name="grad"></param>
        private static void _SetGrad(Dictionary<string, Tensor[][]> grads, Tensor t, Tensor grad)
        {
            var op = t.op;
            Tensor[][] op_grads = null;
            if (!grads.ContainsKey(op.Name))
            {
                op_grads = op.outputs.Select(x => new Tensor[1]).ToArray();
                grads[op.Name] = op_grads;
            }
            else
            {
                op_grads = grads[op.Name];
            }
            var t_grads = op_grads[t.value_index];
            t_grads[0] = grad;
        }
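
        // t.value_index selects which of the op's outputs t is, so the gradient is
        // recorded against the correct output slot.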

        /// <summary>
        /// Fill in default values for grad_ys.
        /// </summary>
        /// <param name="grad_ys">List of gradients, can contain null.</param>
        /// <param name="ys">List of tensors.</param>
        /// <param name="colocate_gradients_with_ops"></param>
        /// <param name="gradient_uid"></param>
        private static Tensor[] _DefaultGradYs(Tensor[] grad_ys, Tensor[] ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__")
        {
            var new_grad_ys = new List<Tensor>();

            for (int i = 0; i < grad_ys.Length; i++)
            {
                var grad_y = grad_ys[i];
                var y = ys[i];

                _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops);

                if (grad_y == null)
                {
                    if (y.dtype.is_complex())
                        throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})");
                    var shape = array_ops.shape(y);
                    var constant = constant_op.constant(1.0, name: $"grad_ys_{i}");
                    var fill = gen_array_ops.fill(shape, constant);
                    new_grad_ys.Add(fill);
                }
                else
                {
                    new_grad_ys.Add(grad_y);
                }
            }

            return new_grad_ys.ToArray();
        }
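
        // A null grad_y defaults to a ones tensor of y's shape, i.e. dL/dy = 1, which is
        // what makes gradients(ys, xs) compute d(sum(ys))/dx. Non-null entries pass
        // through unchanged (the Python original also converts them to matching dtypes).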

        private static void _maybe_colocate_with(Operation op, string gradient_uid, bool colocate_gradients_with_ops)
        {
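            // Currently a no-op; the Python original is a context manager that
            // colocates newly created gradient ops with 'op' when
            // colocate_gradients_with_ops is true.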
        }

        /// <summary>
        /// Initializes the pending count for ops between two lists of Operations.
        /// </summary>
        /// <param name="to_ops">list of Operations.</param>
        /// <param name="from_ops">list of Operations.</param>
        /// <param name="colocate_gradients_with_ops"></param>
        /// <param name="func_graphs"></param>
        /// <param name="xs"></param>
        private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
        {
            // Mark reachable ops from from_ops.
            var reached_ops = new List<Operation>();
            _MarkReachedOps(from_ops, reached_ops, func_graphs);
            // X in reached_ops iff X is reachable from from_ops by a path of zero or more
            // backpropagatable tensors.

            var reachable_to_ops = to_ops.Where(x => reached_ops.Contains(x)).ToArray();

            var between_ops = new List<Operation>();
            var between_op_list = new List<Operation>();

            Queue<Operation> queue = new Queue<Operation>(to_ops);
            while (queue.Count > 0)
            {
                var op = queue.Dequeue();
                if (reached_ops.Contains(op))
                {
                    between_ops.Add(op);
                    between_op_list.Add(op);
                    // Clear the boolean so we won't add the inputs again.
                    reached_ops.Remove(op);
                    foreach (Tensor inp in _NonEagerInputs(op, xs))
                        queue.Enqueue(inp.op);
                }
            }
            // X in between_ops iff X is on a path of zero or more backpropagatable tensors
            // between from_ops and to_ops

            // 'loop_state' is None if there are no while loops.
            var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops);

            var pending_count = new Dictionary<string, int>();
            foreach (var op in between_op_list)
            {
                foreach (Tensor x in _NonEagerInputs(op, xs))
                {
                    if (between_ops.Contains(x.op))
                    {
                        if (pending_count.ContainsKey(x.op.Name))
                            pending_count[x.op.Name] += 1;
                        else
                            pending_count[x.op.Name] = 1;
                    }
                }
            }

            return (reachable_to_ops, pending_count, loop_state);
        }
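
        // Example: if op C consumes two outputs of A and one of B, and A, B, C all lie
        // between from_ops and to_ops, this pass leaves pending_count["A"] == 2 and
        // pending_count["B"] == 1 (names illustrative).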

        private static InputList _NonEagerInputs(Operation op, Tensor[] xs)
        {
            return op.inputs;
        }
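
        // Simplified port: returns all of op's inputs. The Python original filters out
        // captured EagerTensor inputs when op is inside a function graph.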

        /// <summary>
        /// Mark all ops reached from "from_ops"
        /// </summary>