
_UpdatePendingAndEnqueueReady

tags/v0.12
Oceania2018, 6 years ago
commit c8a61b21d5
1 changed file with 87 additions and 7 deletions

+87 -7 src/TensorFlowNET.Core/Gradients/gradients_util.cs

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Operations.ControlFlows;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -82,6 +83,7 @@ namespace Tensorflow
var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);

// Add the initial gradients for the ys.
foreach (var (y, grad_y) in zip(ys, grad_ys))
_SetGrad(grads, y, grad_y);

@@ -103,12 +105,25 @@ namespace Tensorflow
}
}

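// Loop exits that never receive a gradient still need handling here,
// otherwise the loop's pending counts cannot reach zero.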
if(loop_state != null)
{
var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set);
foreach(var y in loop_exits)
{
// Backprop a zero gradient through trainable unused exits,
// mirroring the handling of unused_exits further below.
if (IsTrainable(y))
{
    _SetGrad(grads, y, loop_state.ZerosLikeForExit(y));
    queue.Enqueue(y.op);
}
}
}

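// stop_ops marks where backprop terminates: explicit stop_gradient ops
// plus ops none of whose inputs participate in the gradient.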
var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
while (queue.Count > 0)
{
// generate gradient subgraph for op.
var op = queue.Dequeue();
_maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
//if (loop_state != null)
//loop_state.EnterGradWhileContext(op, before: true);
@@ -147,8 +162,8 @@ namespace Tensorflow
}
}

// if (loop_state)
//loop_state.EnterGradWhileContext(op, before: false);
if (loop_state != null)
loop_state.EnterGradWhileContext(op, before: false);

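// Only invoke the registered gradient function (or SymbolicGradient for
// function calls) when at least one output gradient is available.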
if ((is_func_call || grad_fn != null) && has_out_grads)
{
@@ -164,7 +179,7 @@ namespace Tensorflow
// will use SymbolicGradient get a zero gradient. Gradient
// functions should ignore the gradient for other outputs.
if (loop_state != null)
    out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
else
out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
}
@@ -275,7 +290,7 @@ namespace Tensorflow
/// <param name="colocate_gradients_with_ops"></param>
/// <param name="func_graphs"></param>
/// <param name="xs"></param>
private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
{
// Mark reachable ops from from_ops.
var reached_ops = new List<Operation>();
@@ -308,6 +323,7 @@ namespace Tensorflow
// 'loop_state' is null if there are no while loops.
var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops);

// Initialize pending count for between ops.
var pending_count = new Dictionary<string, int>();
foreach (var op in between_op_list)
{
@@ -550,7 +566,7 @@ namespace Tensorflow
Operation op,
Queue<Operation> queue,
Dictionary<string, int> pending_count,
object loop_state,
ControlFlowState loop_state,
Tensor[] xs)
{
foreach (var x in _NonEagerInputs(op, xs))
@@ -564,14 +580,49 @@ namespace Tensorflow

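// Special case for while loops: a loop-switch op is considered ready as
// soon as one output gradient arrives, without waiting for its remaining
// pending count to drain.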
if (loop_state != null && !ready)
{
ready = pending_count[x.op.name] > 0 && control_flow_util.IsLoopSwitch(x.op);
}

if (ready)
{
// if x is an exit without real gradient, defer processing them.
if (control_flow_util.IsLoopExit(x.op))
{
var grad_state = loop_state.GetGradState(x.op, before: false);
grad_state.deferred_exits.append(x);
grad_state.pending_exits_count -= 1;
// We now have all the exits so process them.
if (grad_state.pending_exits_count == 0)
{
var has_not_none_grad = false;
foreach(var y in grad_state.deferred_exits)
{
if (_HasAnyNotNoneGrads(grads, y.op))
{
has_not_none_grad = true;
queue.Enqueue(y.op);
}
else
grad_state.unused_exits.append(y);
}
if (has_not_none_grad)
{
// For an unused exit, if it has trainable outputs, backprop
// a zero gradient. Otherwise, just ignore it.
foreach (var y in grad_state.unused_exits)
{
if (IsTrainable(y))
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y));
queue.Enqueue(y.op);
}
}
else
{
// All exits are "unused" so use null as gradient.
foreach (var y in grad_state.unused_exits)
queue.Enqueue(y.op);
}
}
}
else
{
@@ -581,6 +632,32 @@ namespace Tensorflow
}
}

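/// <summary>
/// Return true if the tensor's dtype can carry a gradient:
/// floating point, complex, resource or variant.
/// </summary>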
private static bool IsTrainable(Tensor tensor)
{
var dtype = tensor.dtype.as_base_dtype();
return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.complex64, dtypes.complex128,
dtypes.resource, dtypes.variant}.Contains(dtype);
}

/// <summary>
/// Return true if op has real gradient.
/// </summary>
/// <param name="grads"></param>
/// <param name="op"></param>
/// <returns></returns>
private static bool _HasAnyNotNoneGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op)
{
var out_grads = _GetGrads(grads, op);
foreach(var out_grad in out_grads)
{
if (out_grad.Exists(g => g != null))
return true;
}
return false;
}


private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
{
scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
@@ -589,6 +666,9 @@ namespace Tensorflow

private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
{
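// While ops can legitimately produce a different number of gradients
// than inputs, so skip the count check for them.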
if (op.type == "While" || op.type == "StatelessWhile")
return;

if (grads.Count() != op.inputs._inputs.Count())
throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
$"inputs {op.inputs._inputs.Count()}");


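For context: this code path runs whenever gradients are built over a graph, including graphs with while loops. A minimal sketch of how it is exercised from user code, assuming the TF.NET 0.12-era tf.gradients API (the exact overload is an assumption, not part of this diff):

using Tensorflow;
using static Tensorflow.Binding;

// Build a small graph and differentiate it; tf.gradients drives the
// queue-based traversal modified in this commit.
var x = tf.constant(1.0f);
var y = tf.tanh(x);
var dydx = tf.gradients(y, new Tensor[] { x });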