
_SwitchGrad

Oceania2018 committed 6 years ago · parent commit a8a515682e · tags/v0.12
1 changed file with 74 additions and 67 deletions

src/TensorFlowNET.Core/Gradients/control_flow_grad.cs  +74 −67

@@ -45,7 +45,19 @@ namespace Tensorflow.Gradients
             switch (op_ctxt)
             {
                 case WhileContext cwhile:
-                    throw new NotImplementedException("_SwitchGrad WhileContext");
+                    {
+                        var merge_grad = grad_ctxt.grad_state.switch_map.get(op);
+                        if (merge_grad != null)
+                            throw new NotImplementedException("_SwitchGrad merge_grad != null");
+                        else if (grads[0] != null)
+                        {
+                            merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0];
+                            grad_ctxt.grad_state.switch_map[op] = merge_grad;
+                            return new Tensor[] { merge_grad, null };
+                        }
+                        else
+                            return new Tensor[] { null, null };
+                    }
                 case CondContext ccond:
                     {
                         var zero_grad = grads[1 - op_ctxt.branch];
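
Note on the hunk above: the WhileContext branch now mirrors the reference Python _SwitchGrad. The first gradient that reaches a while loop's Switch is merged with itself under the name "b_switch" and cached in grad_state.switch_map, so that the backprop next_iteration can later be wired into that Merge; the CondContext branch below simply forwards the gradient of the branch that was taken. A toy model in plain C# (illustrative names only, none of this is TensorFlow.NET API) shows why the gradient of a Switch is a Merge:

    // Toy model only: forward Switch routes one value to one of two outputs depending on
    // a predicate, so backprop must Merge whichever branch gradient actually comes back.
    using System;

    class SwitchMergeToy
    {
        // pred == false -> value goes to falseOut; pred == true -> value goes to trueOut
        static (float? falseOut, float? trueOut) Switch(float data, bool pred)
        {
            if (pred) return (null, data);
            return (data, null);
        }

        // Merge forwards whichever input actually carries a value
        static float Merge(float? a, float? b) => a ?? b ?? 0f;

        static void Main()
        {
            var (f, t) = Switch(2f, pred: true);
            Console.WriteLine($"forward: falseOut={f}, trueOut={t}");

            // backward: only the taken branch sends a gradient; Merge folds it back into
            // the single gradient for Switch's data input (cf. grads[1 - op_ctxt.branch])
            float? dTrue = 3f, dFalse = null;
            Console.WriteLine(Merge(dTrue, dFalse));   // 3
        }
    }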
@@ -74,7 +86,7 @@
         /// <param name="inputs"></param>
         /// <param name="name"></param>
         /// <returns></returns>
-        internal static Tensor[] merge(Tensor[] inputs, string name = null)
+        internal static MergeOutput merge(Tensor[] inputs, string name = null)
         {
             return tf_with(ops.name_scope(name, "Merge", inputs), scope =>
             {
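
Note: merge now returns MergeOutput instead of Tensor[], keeping both outputs of the underlying Merge op (the forwarded tensor and the index of the input that produced it) while still supporting the merge(...)[0] indexing used in the _SwitchGrad hunk above. The MergeOutput type itself is not part of this diff; a plausible shape for such a wrapper, offered purely as a sketch and not as the definition in the repository, is:

    // Hypothetical sketch of a MergeOutput-style wrapper living alongside these gradient
    // helpers (Tensor comes from the Tensorflow namespace); the real type may differ.
    public class MergeOutput
    {
        public Tensor output { get; }       // the value forwarded by Merge
        public Tensor value_index { get; }  // which input produced it

        public MergeOutput(Tensor[] values)
        {
            output = values[0];
            value_index = values[1];
        }

        // keeps call sites like merge(...)[0] compiling unchanged
        public Tensor this[int index] => index == 0 ? output : value_index;
    }

A named type also makes call sites explicit about which of the two Merge outputs they consume, instead of relying on array positions.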
@@ -146,7 +158,7 @@
         }
 
         [RegisterGradient("RefMerge")]
-        public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
+        public static Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
         {
             return _MergeGrad(op, grads);
         }
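
Note: the gradient functions in this file are being normalized to the shape public static Tensor[] f(Operation op, Tensor[] grads). The commit itself does not explain the change to static; one plausible reason is that an attribute-driven registry can then discover the methods once via reflection and invoke them without constructing the class. The sketch below only illustrates that lookup pattern; it is not the registry used by TensorFlow.NET, and the Name property on RegisterGradient is an assumption:

    using System;
    using System.Reflection;
    using Tensorflow;   // Tensor, Operation, RegisterGradient (namespace assumed)

    delegate Tensor[] GradFunc(Operation op, Tensor[] grads);

    static class GradLookupSketch
    {
        // Find the static method tagged [RegisterGradient(opType)] on the given holder type.
        public static GradFunc Find(Type holder, string opType)
        {
            foreach (var method in holder.GetMethods(BindingFlags.Public | BindingFlags.Static))
                foreach (RegisterGradient attr in method.GetCustomAttributes(typeof(RegisterGradient), false))
                    if (attr.Name == opType)   // property name assumed for this sketch
                        return (GradFunc)method.CreateDelegate(typeof(GradFunc));
            return null;
        }
    }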
@@ -155,43 +167,32 @@
         /// Gradients for an exit op are calculated using an Enter op.
         /// </summary>
         [RegisterGradient("Exit")]
-        public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
+        public static Tensor[] _ExitGrad(Operation op, Tensor[] grads)
         {
-            throw new NotImplementedException("_ExitGrad");
-            // graph = ops.get_default_graph()
-            //# pylint: disable=protected-access
-            // op_ctxt = op._get_control_flow_context()
-            // grad_ctxt = graph._get_control_flow_context()
-            // # pylint: enable=protected-access
-            // if not grad_ctxt.back_prop:
-            // # The flag `back_prop` is set by users to suppress gradient
-            // # computation for this loop. If the attribute `back_prop` is false,
-            // # no gradient computation.
-            // return None
-
-            // if op_ctxt.grad_state:
-            // raise TypeError("Second-order gradient for while loops not supported.")
-
-            // if isinstance(grad, ops.Tensor) :
-            // grad_ctxt.AddName(grad.name)
-            // else:
-            // if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
-            // raise TypeError("Type %s not supported" % type(grad))
-            // grad_ctxt.AddName(grad.values.name)
-            // grad_ctxt.AddName(grad.indices.name)
-            // dense_shape = grad.dense_shape
-            // if dense_shape is not None:
-            // grad_ctxt.AddName(dense_shape.name)
-            // grad_ctxt.Enter()
-            // # pylint: disable=protected-access
-            // result = control_flow_ops._Enter(
-            // grad, grad_ctxt.name, is_constant=False,
-            // parallel_iterations=grad_ctxt.parallel_iterations,
-            // name="b_exit")
-            // # pylint: enable=protected-access
-            // grad_ctxt.loop_enters.append(result)
-            // grad_ctxt.Exit()
-            // return result
+            var grad = grads[0];
+            var graph = ops.get_default_graph();
+            var op_ctxt = op._get_control_flow_context();
+            var grad_ctxt = graph._get_control_flow_context() as WhileContext;
+            // The flag `back_prop` is set by users to suppress gradient
+            // computation for this loop. If the attribute `back_prop` is false,
+            // no gradient computation.
+            if (!grad_ctxt.back_prop)
+                return null;
+
+            if (op_ctxt.grad_state != null)
+                throw new TypeError("Second-order gradient for while loops not supported.");
+
+            grad_ctxt.AddName(grad.name);
+
+            grad_ctxt.Enter();
+            var result = control_flow_ops._Enter(
+                grad, grad_ctxt.name, is_constant: false,
+                parallel_iterations: grad_ctxt.parallel_iterations,
+                name: "b_exit");
+
+            grad_ctxt.loop_enters.append(result);
+            grad_ctxt.Exit();
+            return new[] { result };
         }
 
         /// <summary>
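
Note on _ExitGrad: the gradient arriving at the forward loop's Exit is wrapped in an _Enter named "b_exit" because reverse-mode differentiation of a loop needs a second loop that replays the iterations in reverse, and the incoming gradient has to enter that backward while context. Stripped of the graph machinery, the computation has this shape (plain C# with worked numbers; nothing here is from the commit):

    // Reverse-mode through a loop: the forward loop records its intermediate values and the
    // gradient of the loop output "re-enters" a backward loop that replays them in reverse.
    using System;
    using System.Collections.Generic;

    class ExitGradToy
    {
        static void Main()
        {
            float x0 = 1f, x = x0;
            var saved = new List<float>();

            // forward while loop: x <- x * x + 1, three iterations
            for (int i = 0; i < 3; i++) { saved.Add(x); x = x * x + 1f; }

            // gradient flowing into the forward Exit
            float dx = 1f;

            // backward loop (the gradient while context): chain rule in reverse order
            for (int i = saved.Count - 1; i >= 0; i--)
                dx *= 2f * saved[i];       // d(x*x + 1)/dx at the saved forward value

            Console.WriteLine(dx);         // 80 = (2*5) * (2*2) * (2*1)
        }
    }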
@@ -200,15 +201,15 @@
         /// Note that the backprop next_iteration is added in switch grad.
         /// </summary>
         [RegisterGradient("NextIteration")]
-        public Tensor[] _NextIterationGrad(object _, Tensor[] grad)
+        public static Tensor[] _NextIterationGrad(Operation op, Tensor[] grads)
         {
-            return grad;
+            return grads;
         }
 
         [RegisterGradient("RefNextIteration")]
-        public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
+        public static Tensor[] _RefNextIterationGrad(Operation op, Tensor[] grads)
         {
-            return grad;
+            return grads;
         }
 
         /// <summary>
@@ -218,33 +219,39 @@
         /// For loop invariants, we need to add an accumulator loop.
         /// </summary>
         [RegisterGradient("Enter")]
-        public Tensor[] _EnterGrad(Tensor op, Tensor[] grad)
+        public static Tensor[] _EnterGrad(Operation op, Tensor[] grads)
         {
-            throw new NotImplementedException("_EnterGrad");
-            // graph = ops.get_default_graph()
-            //# pylint: disable=protected-access
-            // grad_ctxt = graph._get_control_flow_context()
-            // # pylint: enable=protected-access
-            // if not grad_ctxt.back_prop:
-            // # Skip gradient computation, if the attribute `back_prop` is false.
-            // return grad
-            // if grad_ctxt.grad_state is None:
-            // # Pass the gradient through if we are not in a gradient while context.
-            // return grad
-            // if op.get_attr("is_constant"):
-            // # Add a gradient accumulator for each loop invariant.
-            // if isinstance(grad, ops.Tensor) :
-            // result = grad_ctxt.AddBackpropAccumulator(op, grad)
-            // elif isinstance(grad, ops.IndexedSlices) :
-            // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
-            // else:
-            // # TODO(yuanbyu, lukasr): Add support for SparseTensor.
-            // raise TypeError("Type %s not supported" % type(grad))
-            // else:
-            // result = exit(grad)
-            // grad_ctxt.loop_exits.append(result)
-            // grad_ctxt.ExitResult([result])
-            // return result
+            Tensor result = null;
+            var grad = grads[0];
+            var graph = ops.get_default_graph();
+            var grad_ctxt = graph._get_control_flow_context() as WhileContext;
+            if (!grad_ctxt.back_prop)
+                // Skip gradient computation, if the attribute `back_prop` is false.
+                return grads;
+            if (grad_ctxt.grad_state == null)
+                // Pass the gradient through if we are not in a gradient while context.
+                return grads;
+            if (op.get_attr<bool>("is_constant"))
+            {
+                throw new NotImplementedException("_EnterGrad is_constant");
+                // Add a gradient accumulator for each loop invariant.
+                // if isinstance(grad, ops.Tensor) :
+                // result = grad_ctxt.AddBackpropAccumulator(op, grad)
+                // elif isinstance(grad, ops.IndexedSlices) :
+                // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
+                // else:
+                // # TODO(yuanbyu, lukasr): Add support for SparseTensor.
+                // raise TypeError("Type %s not supported" % type(grad))
+            }
+
+            else
+            {
+                result = control_flow_ops.exit(grad);
+                grad_ctxt.loop_exits.append(result);
+                grad_ctxt.ExitResult(new[] { result });
+            }
+
+            return new Tensor[] { result };
         }
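
Note on _EnterGrad: the branch that still throws is the loop-invariant case. When an Enter is marked is_constant, the same tensor is consumed on every iteration, so its gradient is the sum of the per-iteration contributions; that is what the AddBackpropAccumulator call in the commented-out Python reference builds in graph form. A scalar illustration of the accumulation (plain C#, not library code):

    // Why loop invariants need an accumulator: if x feeds every iteration of acc <- acc * x,
    // then d(acc)/dx is accumulated across iterations rather than taken from the last one.
    using System;

    class EnterGradToy
    {
        static void Main()
        {
            float x = 2f;          // loop invariant (the is_constant Enter)
            float acc = 1f;        // loop variable
            float dAcc_dx = 0f;    // accumulated gradient w.r.t. the invariant

            for (int i = 0; i < 3; i++)
            {
                dAcc_dx = dAcc_dx * x + acc;   // product rule: add this iteration's contribution
                acc *= x;                      // forward step: acc <- acc * x
            }

            Console.WriteLine(acc);        // 8  = x^3
            Console.WriteLine(dAcc_dx);    // 12 = d(x^3)/dx = 3 * x^2
        }
    }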



