|
|
|
@@ -123,10 +123,7 @@ namespace Tensorflow
                {
                    // generate gradient subgraph for op.
                    var op = queue.Dequeue();
                    if (op.name == "rnn/while/Exit")
                    {
                    }
|
|
|
                    _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
                    {
                        if (loop_state != null)
|
|
|
@@ -136,6 +133,7 @@ namespace Tensorflow
                            loop_state.ExitGradWhileContext(op, before: true);

                        Tensor[] in_grads = null;
                        Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                        var is_partitioned_call = _IsPartitionedCall(op);
                        var is_func_call = false;
                        var has_out_grads = out_grads.Exists(x => x != null);
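                        // Note: grad_fn is declared up here, next to in_grads, presumably so that it is
                        // still in scope for the (is_func_call || grad_fn != null) dispatch further down,
                        // after the gradient-function lookup block has been closed.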
|
|
|
@@ -143,8 +141,6 @@ namespace Tensorflow
                        {
                            // A grad_fn must be defined, either as a function or as None
                            // for ops that do not have gradients.

                            Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                            try
                            {
                                grad_fn = ops.get_gradient_function(op);
|
|
|
@@ -167,61 +163,57 @@ namespace Tensorflow
                                    throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                                }
                            }
                        }

                        if (loop_state != null)
                            loop_state.EnterGradWhileContext(op, before: false);

                        if ((is_func_call || grad_fn != null) && has_out_grads)
                        {
                            // NOTE: If _AggregatedGrads didn't compute a value for the i'th
                            // output, it means that the cost does not depend on output[i],
                            // therefore dC/doutput[i] is 0.
                            foreach (var (i, out_grad) in enumerate(out_grads))
                            {
                                if (out_grad == null &&
                                    (grad_fn == null || _IsTrainable(op.outputs[i])))
                                {
                                    // Only trainable outputs or outputs for a function call that
                                    // will use SymbolicGradient get a zero gradient. Gradient
                                    // functions should ignore the gradient for other outputs.
                                    if (loop_state != null)
                                        out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
                                    else
                                        out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
                                }
                            }

                            tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                            {
                                if (grad_fn != null)
                                {
                                    in_grads = _MaybeCompile(grad_scope,
                                        op,
                                        out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
                                        null,
                                        grad_fn);
                                }
                                else
                                {
                                    throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
                                }
                                _VerifyGeneratedGradients(in_grads, op);
                                if (gate_gradients && in_grads.Count(x => x != null) > 1)
                                {
                                    ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
                                    in_grads = control_flow_ops.tuple(in_grads);
                                }
                            });
                        }
                        else
                        {
                            // If no grad_fn is defined or none of out_grads is available,
                            // just propagate a list of None backwards.
                            in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                        }
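                        // Note: when gate_gradients is set, control_flow_ops.tuple groups the generated
                        // input gradients so that none of them can be consumed before all of them have
                        // been computed.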
|
|
|
|
|
|
|
@@ -370,7 +362,16 @@ namespace Tensorflow
                grads[op.name] = op_grads;
            }
            var t_grads = op_grads[t.value_index];
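            // Note: the old line below appended every incoming gradient; the replacement that follows
            // keeps a single gradient per output slot and overwrites it when a newer one arrives.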
|
|
|
            t_grads.Add(grad);
            if (t_grads.Count == 0)
                t_grads.Add(grad);
            else
                op_grads[t.value_index][0] = grad;

            /*if (control_flow_util.IsLoopSwitch(op) &&
                t_grads[0] == null)
                op_grads[t.value_index] = new List<Tensor> { grad };
            else
                t_grads.Add(grad);*/
        }

        private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
|
|
|
@@ -379,7 +380,8 @@ namespace Tensorflow
                yield return op.inputs[i];
        }

        private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
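        // Note: the single-line signature above is the old declaration; it is replaced by the
        // two-line version below, with loop_state typed as ControlFlowState instead of object.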
|
|
|
        private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid,
            ControlFlowState loop_state, int aggregation_method = 0)
        {
            var out_grads = _GetGrads(grads, op);
|
|
|
|
|
|
|
@@ -387,7 +389,10 @@ namespace Tensorflow
            {
                if (loop_state != null)
                {
                    if (out_grads.Count > 1 &&
                        out_grads[1].Count > 0 &&
                        control_flow_util.IsLoopSwitch(op))
                        continue;
                }

                // Aggregate multiple gradients, and convert [] to None.
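                // Note: for a loop Switch op that already has a gradient recorded on its second
                // output, the entry is skipped above instead of being aggregated.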
|
|
|
|