
ControlFlow MergeOutput

tags/v0.12
Oceania2018 6 years ago
commit e2190c94fc
5 changed files with 459 additions and 309 deletions
  1. +26  −2    src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  2. +157 −109  src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
  3. +90  −180  src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
  4. +36  −0    src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs
  5. +150 −18   src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

+26 −2  src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

@@ -20,6 +20,7 @@ using System.Linq;
 using Tensorflow.Operations.ControlFlows;
 using static Tensorflow.ControlFlowContextDef;
 using static Tensorflow.Binding;
+using util = Tensorflow.control_flow_util;


 namespace Tensorflow.Operations
 {
@@ -146,6 +147,14 @@ namespace Tensorflow.Operations
             graph._set_control_flow_context(last_context);
         }

+        public void ExitResult(Tensor[] result)
+        {
+            if (_outer_context != null)
+            {
+                throw new NotImplementedException("ExitResult");
+            }
+        }

         /// <summary>
         /// Add `op` to the current context.
         /// </summary>
@@ -172,6 +181,11 @@ namespace Tensorflow.Operations
             return null;
         }

+        public void AddName(string name)
+        {
+            _values.Add(name);
+        }

         /// <summary>
         /// Notifies a scope about an operator added to an inner scope.
         /// </summary>
@@ -246,9 +260,11 @@ namespace Tensorflow.Operations
                 }
                 else
                 {
-                    foreach(Tensor x in op.control_inputs)
+                    foreach(Operation x in op.control_inputs)
                     {
-                        throw new NotImplementedException("");
+                        var ctxt = util.GetOutputContext(x);
+                        if (ctxt != null && ctxt.GetWhileContext() == while_ctxt)
+                            internal_control_inputs.append(x);
                     }
                 }


@@ -288,6 +304,14 @@ namespace Tensorflow.Operations
throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}"); throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}");
} }


+        public virtual bool IsWhileContext()
+        {
+            throw new NotImplementedException("IsWhileContext");
+        }
+
+        public virtual bool IsCondContext()
+            => false;

         public object to_proto()
         {
             throw new NotImplementedException();


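For illustration, a minimal stand-alone sketch of the control-input filter introduced above; `Ctx` and `Op` here are hypothetical stand-ins for ControlFlowContext and Operation, not the library types. A control input is kept only when the context that produced it resolves to the same while context:

    using System.Collections.Generic;

    class Ctx { public Ctx While; }          // stand-in for ControlFlowContext
    class Op  { public Ctx OutputContext; }  // stand-in for Operation

    static class Sketch
    {
        // Mirrors: var ctxt = util.GetOutputContext(x);
        //          if (ctxt != null && ctxt.GetWhileContext() == while_ctxt) keep x;
        public static List<Op> InternalControlInputs(IEnumerable<Op> controlInputs, Ctx whileCtxt)
        {
            var kept = new List<Op>();
            foreach (var x in controlInputs)
            {
                var ctxt = x.OutputContext;
                if (ctxt != null && ctxt.While == whileCtxt)
                    kept.Add(x);
            }
            return kept;
        }
    }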
+157 −109  src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs

@@ -14,6 +14,12 @@
    limitations under the License.
 ******************************************************************************/
+using System;
+using System.Linq;
+using System.Collections.Generic;
+using util = Tensorflow.control_flow_util;
+using static Tensorflow.Binding;
 namespace Tensorflow.Operations.ControlFlows
 {
     /// <summary>
@@ -21,6 +27,7 @@ namespace Tensorflow.Operations.ControlFlows
     /// </summary>
     public class ControlFlowState
     {
+        Dictionary<ControlFlowContext, GradLoopState> _map;

         //class ControlFlowState(object):
         //  """Maintain the mapping from the loops to their grad states."""
@@ -40,51 +47,67 @@ namespace Tensorflow.Operations.ControlFlows
         //      return self._map.get(forward_ctxt)
         //    return None
-        //  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
-        //    """Process all the "unused" loop exits.
-        //    The "unused" exits of the loops are added to `unused_exits`. An exit is
-        //    unused if its pending_count is 0. If there is an exit with real gradient,
-        //    all these deferred exits will enter the backprop loop with zero gradient.
-        //    Otherwise, they will enter the backprop loop with None. As an example,
-        //    people often write:
-        //    ```python
-        //    v1, _ = tf.while_loop(p, b, [x1, x2])
-        //    result = gradients(v1, x1)
-        //    ```
-        //    The exit node for x2 is not included by the betweenness analysis. But we
-        //    need to backprop x2 if x2 is involved in computing v1.
-        //    Args:
-        //      pending_count: The number of backprop inputs for every op.
-        //      to_ops_set: The set of ops for ys in gradients(ys, xs)
-        //    Returns:
-        //      The set of unused loop exits that we know at this point we need
-        //      to backprop.
-        //    """
-        //    loop_exits = []
-        //    for grad_state in self._map.values():
-        //      for y in grad_state.forward_loop_exits:
-        //        if pending_count[y.op] == 0:
-        //          grad_state.pending_exits_count -= 1
-        //          if y.op not in to_ops_set:
-        //            grad_state.unused_exits.append(y)
-        //          if grad_state.pending_exits_count == 0:
-        //            loop_exits.extend(grad_state.unused_exits)
-        //      # Need to include Enters in backprop for higher-order gradients.
-        //      for y in grad_state.forward_context.loop_enters:
-        //        if pending_count[y.op] == 0:
-        //          pending_count[y.op] = 1
-        //    return loop_exits
-        //  def EnterGradWhileContext(self, op, before):
-        //    """Enter the WhileContext for gradient computation."""
-        //    grad_state = self.GetGradState(op, before)
-        //    if grad_state:
-        //      grad_state.grad_context.Enter()
+        public ControlFlowState()
+        {
+            _map = new Dictionary<ControlFlowContext, GradLoopState>();
+        }
+
+        /// <summary>
+        /// Return the grad state for this op if it's in a forward loop context.
+        /// </summary>
+        /// <param name="op"></param>
+        /// <param name="before"></param>
+        /// <returns></returns>
+        public GradLoopState GetGradState(Operation op, bool before)
+        {
+            ControlFlowContext forward_ctxt = null;
+            if (before && util.IsLoopExit(op))
+            {
+                forward_ctxt = op._get_control_flow_context();
+                forward_ctxt = forward_ctxt.outer_context;
+                if (forward_ctxt != null)
+                    forward_ctxt = forward_ctxt.GetWhileContext();
+            }
+            else
+                forward_ctxt = util.GetWhileContext(op);
+            if (forward_ctxt != null)
+                return _map.get(forward_ctxt);
+            return null;
+        }
+
+        public Tensor[] ProcessUnusedLoopExits(Dictionary<string, int> pending_count, List<Operation> to_ops_set)
+        {
+            var loop_exits = new List<Tensor>();
+            foreach (var grad_state in _map.Values)
+            {
+                foreach (var y in grad_state.forward_loop_exits)
+                {
+                    if (!pending_count.ContainsKey(y.op.name))
+                    {
+                        grad_state.pending_exits_count -= 1;
+                        if (!to_ops_set.Contains(y.op))
+                            grad_state.unused_exits.append(y);
+                        if (grad_state.pending_exits_count == 0)
+                            loop_exits.extend(grad_state.unused_exits);
+                    }
+                }
+                foreach (var y in grad_state.forward_context.loop_enters)
+                {
+                    if (!pending_count.ContainsKey(y.op.name))
+                        pending_count[y.op.name] = 1;
+                }
+            }
+            return loop_exits.ToArray();
+        }
+
+        public void EnterGradWhileContext(Operation op, bool before)
+        {
+            var grad_state = GetGradState(op, before);
+            if (grad_state != null)
+                grad_state.grad_context.Enter();
+        }
         //  def ExitGradWhileContext(self, op, before):
         //    """Exit the WhileContext for gradient computation."""
@@ -118,6 +141,32 @@ namespace Tensorflow.Operations.ControlFlows
         //        if loop_exit.op not in between_ops:
         //          between_ops.add(loop_exit.op)
         //          between_op_list.append(loop_exit.op)
+        public void AddWhileContext(Operation op, List<Operation> between_op_list, List<Operation> between_ops)
+        {
+            var forward_ctxt = op.GetWhileContext();
+            var grad_state = _map.ContainsKey(forward_ctxt) ? _map[forward_ctxt] : null;
+            if (grad_state == null)
+            {
+                GradLoopState outer_grad_state = null;
+                var outer_forward_ctxt = forward_ctxt.outer_context;
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
+                if (outer_forward_ctxt != null)
+                    outer_grad_state = _map[outer_forward_ctxt];
+                grad_state = new GradLoopState(forward_ctxt, outer_grad_state);
+                _map[forward_ctxt] = grad_state;
+
+                // We need to include all exits of a loop for backprop.
+                foreach (var loop_exit in grad_state.forward_loop_exits)
+                {
+                    if (!between_ops.Contains(loop_exit.op))
+                    {
+                        between_ops.add(loop_exit.op);
+                        between_op_list.append(loop_exit.op);
+                    }
+                }
+            }
+        }
         //  def ZerosLikeForExit(self, val):
         //    """Create zeros_like gradient for a loop exit.
@@ -174,70 +223,69 @@ namespace Tensorflow.Operations.ControlFlows
         //      result = array_ops.zeros_like(val, optimize=False)
         //    return result
-        //  def ZerosLike(self, op, index):
-        //    """Create zeros_like for the specified output of an op.
-        //    If op is in a while loop that is part of gradients(), this method
-        //    must be called in its grad loop context.
-        //    Args:
-        //      op: A tensorflow operation.
-        //      index: the index for a specific output of the op.
-        //    Returns:
-        //      A zero tensor of the same shape of op.outputs[index].
-        //    """
-        //    if util.IsLoopSwitch(op):
-        //      return None
-        //    if op.graph._building_function:  # pylint: disable=protected-access
-        //      # The optimization here is tricky to apply to functions
-        //      return array_ops.zeros_like(op.outputs[index])
-        //    dead_branch = util.IsSwitch(op)
-        //    forward_ctxt = _GetWhileContext(op)
-        //    grad_state = self._map.get(forward_ctxt)
-        //    if grad_state is None:
-        //      # op is not in a while loop that is part of gradients().
-        //      return ZerosLikeOutsideLoop(op, index)
-        //    op_ctxt = op._get_control_flow_context()
-        //    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
-        //    shape = val.get_shape()
-        //    if shape.is_fully_defined():
-        //      # If the shape is known statically, just create a zero tensor with
-        //      # the right shape in the grad loop context.
-        //      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
-        //      if dead_branch:
-        //        # op is a cond switch. Guard the zero tensor with a switch.
-        //        pred = grad_state.history_map.get(op_ctxt.pred.name)
-        //        branch = op_ctxt.branch
-        //        result = _SwitchRefOrTensor(result, pred)[1 - branch]
-        //    else:
-        //      # Unknown shape so keep a history of the shape at runtime.
-        //      if dead_branch:
-        //        # Need to add a special switch to guard the value.
-        //        pred = op_ctxt.pred
-        //        branch = op_ctxt.branch
-        //        op_ctxt.outer_context.Enter()
-        //        val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
-        //        zeros_shape = array_ops.shape_internal(val, optimize=False)
-        //        op_ctxt.outer_context.Exit()
-        //        val.op._set_control_flow_context(op_ctxt)
-        //        zeros_shape.op._set_control_flow_context(op_ctxt)
-        //      else:
-        //        op_ctxt.Enter()
-        //        zeros_shape = array_ops.shape_internal(val, optimize=False)
-        //        op_ctxt.Exit()
-        //      # Add forward accumulator for shape.
-        //      grad_state.grad_context.Exit()
-        //      history_zeros_shape = grad_state.AddForwardAccumulator(
-        //          zeros_shape, dead_branch=dead_branch)
-        //      grad_state.grad_context.Enter()
-        //      # Create a zero tensor with the right shape.
-        //      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
-        //                                                     zeros_shape, dead_branch)
-        //      result = array_ops.zeros(shape, val.dtype)
-        //    return result
+        public Tensor ZerosLike(Operation op, int index)
+        {
+            if (util.IsLoopSwitch(op))
+                return null;
+            if (op.graph.building_function)
+                return array_ops.zeros_like(op.outputs[index]);
+            var dead_branch = util.IsSwitch(op);
+            var forward_ctxt = util.GetWhileContext(op);
+            var grad_state = _map.get(forward_ctxt);
+            // op is not in a while loop that is part of gradients().
+            if (grad_state == null)
+                return ZerosLikeOutsideLoop(op, index);
+            throw new NotImplementedException("ZerosLike");
+        }
+
+        public Tensor ZerosLikeOutsideLoop(Operation op, int index)
+        {
+            var val = op.outputs[index];
+            if (!util.IsSwitch(op))
+            {
+                if (val.dtype == dtypes.resource)
+                    throw new NotImplementedException("ZerosLikeOutsideLoop");
+                /*return array_ops.zeros(
+                    gen_resource_variable_ops.variable_shape(val),
+                    dtype: default_gradient.get_zeros_dtype(val));*/
+                return array_ops.zeros_like(val, optimize: false);
+            }
+            else
+                throw new NotImplementedException("ZerosLikeOutsideLoop");
+        }
+
+        /// <summary>
+        /// Create zeros_like gradient for a loop exit.
+        /// </summary>
+        /// <param name="val"></param>
+        /// <returns></returns>
+        public Tensor ZerosLikeForExit(Tensor val)
+        {
+            Tensor result = null;
+            var val_shape = val.TensorShape;
+            var forward_ctxt = val.op._get_control_flow_context();
+            var outer_forward_ctxt = forward_ctxt.outer_context;
+            if (outer_forward_ctxt != null)
+                outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
+            GradLoopState outer_grad_state = null;
+            if (outer_forward_ctxt != null)
+                outer_grad_state = _map.get(outer_forward_ctxt);
+            // This is a nested loop.
+            if (outer_grad_state != null)
+            {
+                throw new NotImplementedException("ZerosLikeForExit");
+            }
+            else
+            {
+                // If the shape is known statically, just create a zero tensor
+                // with the right shape.
+                if (val_shape.is_fully_defined())
+                    result = array_ops.zeros(val_shape.dims, val.dtype);
+                else
+                    result = array_ops.zeros_like(val, optimize: false);
+            }
+            return result;
+        }
         //  def PostProcessing(self):
         //    """Perform postprocessing at the end of gradients().


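For orientation, a hedged sketch of how a gradients pass might drive the new methods above; `op`, `between_op_list`, `between_ops`, `pending_count`, and `to_ops_set` are assumed to come from the surrounding gradient builder and are not defined here:

    var loop_state = new ControlFlowState();
    // Register the while loop `op` belongs to, pulling all its exits into backprop.
    loop_state.AddWhileContext(op, between_op_list, between_ops);
    // Exits with no pending gradient input are collected as "unused" exits.
    Tensor[] unused = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set);
    // Gradient ops for `op` are then built inside its grad WhileContext.
    loop_state.EnterGradWhileContext(op, before: true);
    var zero = loop_state.ZerosLike(op, 0);  // zero gradient for an unused output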
+90 −180  src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs

@@ -16,41 +16,16 @@
 using System;
 using System.Collections;
+using System.Collections.Generic;
+using static Tensorflow.Binding;

 namespace Tensorflow.Operations.ControlFlows
 {
+    /// <summary>
+    /// The state used for constructing the gradient graph for a while loop.
+    /// </summary>
     public class GradLoopState
     {
-        //class GradLoopState(object):
-        //  """The state used for constructing the gradient graph for a while loop.
-        //  We create a GradLoopState for each while loop in forward and its
-        //  corresponding while loop in backprop. This gives us access to both
-        //  the forward and the backprop WhileContexts.
-        //  During the construction of gradient graph, any time when we detect
-        //  a forward value that is needed for backprop, we create a history
-        //  accumulator and add it to `history_map`. Any time when we backprop
-        //  a loop switch op (in _SwitchGrad), we add the grad merge op in
-        //  `switch_map`.
-        //  """
-        //  def __init__(self, forward_ctxt, outer_grad_state):
-        //    # The grad loop state for the outer while loop.
-        //    self._outer_grad_state = None
-        //    # The while loop context for forward.
-        //    self._forward_context = None
-        //    # The loop counter added by AddForwardLoopCounter. It is the value
-        //    # of the loop counter for the next iteration.
-        //    self._forward_index = None
-        //    # A sync op for forward.
-        //    self._forward_sync = None
-        //    # The while loop context for backprop.
         private WhileContext _grad_context = null;
         public WhileContext grad_context => _grad_context;
@@ -65,156 +40,91 @@ namespace Tensorflow.Operations.ControlFlows
         // # Information needed by backprop.
         private Hashtable _history_map = new Hashtable();
         public Hashtable history_map => _history_map;
-        private Hashtable _switch_map = new Hashtable();
-        public Hashtable switch_map => _switch_map;
-        //    self._unused_exits = []
-        //    self._deferred_exits = []
-        //    self._forward_loop_exits = list(forward_ctxt.loop_exits)
-        //    self._pending_exits_count = len(forward_ctxt.loop_exits)
-        //    self._outer_grad_state = outer_grad_state
-        //    if outer_grad_state:
-        //      outer_forward_ctxt = outer_grad_state.forward_context
-        //    else:
-        //      if not hasattr(forward_ctxt, "outer_context"):
-        //        raise ValueError("Failed to call gradients on a while loop without"
-        //                         "properly serializing graph via MetaGraphDef")
-        //      outer_forward_ctxt = forward_ctxt.outer_context
-        //    # Add the forward loop counter.
-        //    with forward_ctxt._graph.as_default():  # pylint: disable=protected-access
-        //      if outer_forward_ctxt:
-        //        outer_forward_ctxt.Enter()
-        //      cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
-        //      if outer_forward_ctxt:
-        //        outer_forward_ctxt.Exit()
-        //    self._forward_context = forward_ctxt
-        //    self._forward_index = forward_index
-        //    # Add the backprop WhileContext, and the backprop loop counter.
-        //    if outer_grad_state:
-        //      # This is a nested loop. Remember the iteration counts for each
-        //      # execution of this inner loop.
-        //      outer_forward_ctxt.AddName(cnt.name)
-        //      history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
-        //      outer_grad_ctxt = outer_grad_state.grad_context
-        //      outer_grad_ctxt.Enter()
-        //      self._grad_context = WhileContext(
-        //          maximum_iterations=forward_ctxt.maximum_iterations,
-        //          parallel_iterations=forward_ctxt.parallel_iterations,
-        //          back_prop=forward_ctxt.back_prop,
-        //          swap_memory=forward_ctxt.swap_memory,
-        //          name=forward_ctxt.name,
-        //          grad_state=self)
-        //      real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
-        //      self._grad_index = self._grad_context.AddBackpropLoopCounter(
-        //          real_cnt, outer_grad_state)
-        //      outer_grad_ctxt.Exit()
-        //    else:
-        //      if outer_forward_ctxt:
-        //        outer_forward_ctxt.Enter()
-        //      self._grad_context = WhileContext(
-        //          maximum_iterations=forward_ctxt.maximum_iterations,
-        //          parallel_iterations=forward_ctxt.parallel_iterations,
-        //          back_prop=forward_ctxt.back_prop,
-        //          swap_memory=forward_ctxt.swap_memory,
-        //          name=forward_ctxt.name,
-        //          grad_state=self)
-        //      self._grad_index = self._grad_context.AddBackpropLoopCounter(
-        //          cnt, outer_grad_state)
-        //      if outer_forward_ctxt:
-        //        outer_forward_ctxt.Exit()
-        //  @property
-        //  def outer_grad_state(self):
-        //    """The grad loop state for outer loop."""
-        //    return self._outer_grad_state
-        //  @property
-        //  def forward_context(self):
-        //    """The while loop context for forward."""
-        //    return self._forward_context
-        //  @property
-        //  def forward_index(self):
-        //    """The loop index of forward loop."""
-        //    return self._forward_index
-        //  @property
-        //  def forward_sync(self):
-        //    """A control trigger node for synchronization in the forward loop.
-        //    One main use is to keep the push ops of a stack executed in the
-        //    iteration order.
-        //    """
-        //    if self._forward_sync is None:
-        //      with ops.control_dependencies(None):
-        //        self._forward_sync = control_trigger(name="f_sync")
-        //      self._forward_sync._set_control_flow_context(self._forward_context)
-        //      self._forward_index.op._add_control_input(self._forward_sync)
-        //    return self._forward_sync
-        //  @property
-        //  def grad_context(self):
-        //    """The corresponding WhileContext for gradient."""
-        //    return self._grad_context
-        //  @property
-        //  def grad_index(self):
-        //    """The loop index of backprop loop."""
-        //    return self._grad_index
-        //  @property
-        //  def grad_sync(self):
-        //    """A control trigger node for synchronization in the grad loop.
-        //    One main use is to keep the pop ops of a stack executed in the
-        //    iteration order.
-        //    """
-        //    if self._grad_sync is None:
-        //      with ops.control_dependencies(None):
-        //        self._grad_sync = control_trigger(name="b_sync")
-        //      self._grad_sync._set_control_flow_context(self._grad_context)
-        //      self._grad_index.op._add_control_input(self._grad_sync)
-        //      if self._grad_context.outer_context:
-        //        self._grad_context.outer_context.AddInnerOp(self._grad_sync)
-        //    return self._grad_sync
-        //  @property
-        //  def history_map(self):
-        //    """The map that records all the tensors needed for backprop."""
-        //    return self._history_map
-        //  @property
-        //  def switch_map(self):
-        //    """The map that records all the Switch ops for the while loop."""
-        //    return self._switch_map
-        //  @property
-        //  def unused_exits(self):
-        //    """The list of "unused" exits."""
-        //    return self._unused_exits
-        //  @property
-        //  def deferred_exits(self):
-        //    """The list of "deferred" exits."""
-        //    return self._deferred_exits
-        //  @property
-        //  def forward_loop_exits(self):
-        //    """The list of exits of the forward loop."""
-        //    return self._forward_loop_exits
-        //  @property
-        //  def pending_exits_count(self):
-        //    """The number of exits we expect to see but haven't."""
-        //    return self._pending_exits_count
-        //  @pending_exits_count.setter
-        //  def pending_exits_count(self, cnt):
-        //    """Set the pending count to cnt."""
-        //    self._pending_exits_count = cnt
+        Dictionary<Operation, Tensor> _switch_map = new Dictionary<Operation, Tensor>();
+        public Dictionary<Operation, Tensor> switch_map => _switch_map;
+
+        /// <summary>
+        /// The while loop context for forward.
+        /// </summary>
+        WhileContext _forward_context;
+        public WhileContext forward_context => _forward_context;
+
+        /// <summary>
+        /// The grad loop state for the outer while loop.
+        /// </summary>
+        GradLoopState _outer_grad_state;
+        public GradLoopState outer_grad_state => _outer_grad_state;
+
+        Tensor _forward_index;
+        Tensor _grad_index;
+
+        Tensor[] _forward_loop_exits;
+        /// <summary>
+        /// The list of exits of the forward loop.
+        /// </summary>
+        public Tensor[] forward_loop_exits => _forward_loop_exits;
+
+        List<Tensor> _deferred_exits;
+        public List<Tensor> deferred_exits => _deferred_exits;
+
+        List<Tensor> _unused_exits;
+        public List<Tensor> unused_exits => _unused_exits;
+
+        /// <summary>
+        /// The number of exits we expect to see but haven't.
+        /// </summary>
+        public int pending_exits_count { get; set; }
+
+        public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
+        {
+            // Information needed by backprop.
+            _unused_exits = new List<Tensor>();
+            _deferred_exits = new List<Tensor>();
+            _forward_loop_exits = list(forward_ctxt.loop_exits);
+            pending_exits_count = len(forward_ctxt.loop_exits);
+            _outer_grad_state = outer_grad_state_;
+            ControlFlowContext outer_forward_ctxt = null;
+            if (outer_grad_state_ != null)
+                outer_forward_ctxt = outer_grad_state_.forward_context;
+
+            // Add the forward loop counter.
+            // with forward_ctxt._graph.as_default():
+            Tensor cnt, forward_index;
+            {
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt.Enter();
+                (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state);
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt.Exit();
+            }
+            _forward_context = forward_ctxt;
+            _forward_index = forward_index;
+
+            // Add the backprop WhileContext, and the backprop loop counter.
+            if (outer_grad_state != null)
+            {
+                // This is a nested loop. Remember the iteration counts for each
+                // execution of this inner loop.
+                throw new NotImplementedException("GradLoopState");
+            }
+            else
+            {
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt.Enter();
+                _grad_context = new WhileContext(
+                    maximum_iterations: forward_ctxt.maximum_iterations,
+                    parallel_iterations: forward_ctxt.parallel_iterations,
+                    back_prop: forward_ctxt.back_prop,
+                    swap_memory: forward_ctxt.swap_memory,
+                    name: forward_ctxt.name,
+                    grad_state: this);
+                _grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state);
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt.Exit();
+            }
+        }
         /// <summary>
         /// Add an accumulator for each forward tensor that is needed in backprop.


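A brief note on the constructor above: for a non-nested loop it pairs the two counter loops added to WhileContext later in this commit. A hedged sketch of that wiring, with `forward_ctxt` assumed to be the forward WhileContext:

    var grad_state = new GradLoopState(forward_ctxt, null);
    // Inside the constructor this amounts to:
    //   (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(null);  // "f_count"
    //   _grad_index = _grad_context.AddBackpropLoopCounter(cnt, null);    // "b_count"
    // so the backprop loop runs exactly `cnt` iterations, counting down to zero.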
+36 −0  src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs

@@ -0,0 +1,36 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+    public class MergeOutput
+    {
+        Tensor output;
+        Tensor value_index;
+        public MergeOutput(Tensor[] values)
+        {
+            output = values[0];
+            value_index = values[1];
+        }
+
+        public Tensor this[int idx]
+        {
+            get
+            {
+                switch (idx)
+                {
+                    case 0:
+                        return output;
+                    case 1:
+                        return value_index;
+                    default:
+                        return null;
+                }
+            }
+        }
+
+        public static implicit operator Tensor(MergeOutput merge)
+            => merge.output;
+    }
+}

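MergeOutput packs the two outputs of a merge op. A short usage sketch mirroring how this commit consumes it in WhileContext (`x` is assumed to be a Tensor already entered into the loop):

    MergeOutput m = merge(new[] { x, x });
    Tensor output = m[0];        // the merged value; same as the implicit (Tensor)m
    Tensor value_index = m[1];   // which input the merge actually took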
+150 −18  src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

@@ -32,12 +32,17 @@ namespace Tensorflow.Operations
         bool _back_prop = true;
         GradLoopState _grad_state = null;
         Tensor _maximum_iterations;
+        public Tensor maximum_iterations => _maximum_iterations;
         int _parallel_iterations;
+        public int parallel_iterations => _parallel_iterations;
         bool _swap_memory;
+        public bool swap_memory => _swap_memory;
         Tensor _pivot_for_pred;
         Tensor _pivot_for_body;
         List<Tensor> _loop_exits;
+        public List<Tensor> loop_exits => _loop_exits;
         List<Tensor> _loop_enters;
+        public List<Tensor> loop_enters => _loop_enters;
         Graph _graph;

         public override GradLoopState grad_state => _grad_state;
         public override bool back_prop => _back_prop;
@@ -109,7 +114,7 @@ namespace Tensorflow.Operations
         /// <summary>
         /// Add the loop termination condition and body to the graph.
         /// </summary>
-        internal Tensor[] BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
+        internal LoopVar<TItem> BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
             Func<LoopVar<TItem>, LoopVar<TItem>> body,
             LoopVar<TItem> loop_vars,
             TensorShape[] shape_invariants,
@@ -132,14 +137,16 @@ namespace Tensorflow.Operations
                 pred, body, original_loop_vars, loop_vars_tensors, shape_invariants);
             Exit();

-            var flat_result = original_body_result;
+            var flat_result = nest.flatten2(original_body_result)
+                .Select(x => x as ITensorOrTensorArray)
+                .ToArray();

             var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars);
-            var packed_exit_vars = nest.pack_sequence_as(
+            var packed_exit_vars = nest.pack_sequence_as2(
                 structure: original_body_result,
                 flat_sequence: exit_vars_with_tensor_arrays);

-            return packed_exit_vars as Tensor[];
+            return packed_exit_vars;
         }


         private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array)
@@ -167,7 +174,7 @@ namespace Tensorflow.Operations
/// <param name="loop_vars"></param> /// <param name="loop_vars"></param>
/// <param name="shape_invariants"></param> /// <param name="shape_invariants"></param>
/// <returns></returns> /// <returns></returns>
private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
private (LoopVar<TItem>, Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
Func<LoopVar<TItem>, LoopVar<TItem>> body, Func<LoopVar<TItem>, LoopVar<TItem>> body,
LoopVar<TItem> original_loop_vars, LoopVar<TItem> original_loop_vars,
Tensor[] loop_vars, Tensor[] loop_vars,
@@ -221,6 +228,7 @@ namespace Tensorflow.Operations


             var merge_vars = enter_vars
                 .Select(x => merge(new[] { x, x }))
+                .Select(m => (Tensor)m)
                 .ToArray();

             _pivot_for_pred = merge_vars[0];
@@ -250,13 +258,15 @@ namespace Tensorflow.Operations
             var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

             // Store body_result to keep track of TensorArrays returned by body
-            var original_body_result = new[] { body_result };
+            var original_body_result = body_result;
             // Convert TensorArrays returned by body into their flow variables
-            var result = new[] { body_result };
+            var result = nest.flatten2(body_result)
+                .Select(x => _convert_tensorarray_to_flow(x))
+                .ToArray();
+            // result = ops.convert_n_to_tensor_or_composite(result);

             var next_vars = new List<Tensor>();
-            //foreach (var (m, v) in zip(merge_vars, result))
-            //    next_vars.Add(_AddNextAndBackEdge(m, v));
+            foreach (var (m, v) in zip(merge_vars, result))
+                next_vars.Add(_AddNextAndBackEdge(m, v));

             // Add the exit ops.
             var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
@@ -264,7 +274,7 @@ namespace Tensorflow.Operations


             // Exit the loop.
             // ExitResult(exit_vars);
-            return (null, exit_vars.ToArray());
+            return (original_body_result, exit_vars.ToArray());
         }

         private void _FixControlInputsAndContext(Tensor[] enters)
@@ -282,7 +292,18 @@ namespace Tensorflow.Operations
                     var keep_as_control_input = true;
                     var op_ctxt = control_flow_util.GetOutputContext(op);
                     var outer_ctxt = outer_context;
-                    throw new NotImplementedException("");
+                    var outer_while_context = outer_ctxt == null ? null : outer_ctxt.GetWhileContext();
+                    while (outer_ctxt != op_ctxt)
+                    {
+                        if (outer_ctxt == null || outer_ctxt == outer_while_context)
+                        {
+                            keep_as_control_input = false;
+                            break;
+                        }
+                        outer_ctxt = outer_ctxt.outer_context;
+                    }
+                    if (keep_as_control_input)
+                        outer_control_inputs.append(op);
                 }
                 // op for op in control_inputs if self._IsInOuterContext(op)
                 /*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op))
@@ -307,10 +328,21 @@ namespace Tensorflow.Operations


         protected override void _AddOpInternal(Operation op)
         {
+            if (op.name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad")
+            {
+
+            }
             Operation[] external_inputs = new Operation[0];
+            Operation[] control_inputs = new Operation[0];
             if (op.inputs.Length == 0)
             {
-                throw new NotImplementedException("");
+                // Remove any external control dependency on this op
+                (control_inputs, external_inputs) = _RemoveExternalControlEdges(op);
+                if (control_inputs.Length == 0)
+                    op._add_control_input(GetControlPivot().op);
+                foreach (var x in op.outputs)
+                    _values.Add(x.name);
             }
             else
             {
@@ -378,6 +410,93 @@ namespace Tensorflow.Operations
             _AddOpInternal(op);
         }


+        /// <summary>
+        /// Adds a loop that counts the number of iterations.
+        /// </summary>
+        /// <param name="outer_grad_state">The outer grad state. None if not nested.</param>
+        /// <returns>The number of iterations taken by the forward loop and the loop index.</returns>
+        public (Tensor, Tensor) AddForwardLoopCounter(GradLoopState outer_grad_state)
+        {
+            var n = constant_op.constant(0, name: "f_count");
+            if (outer_grad_state != null)
+                throw new NotImplementedException("AddForwardLoopCounter");
+
+            Enter();
+            AddName(n.name);
+            var enter_n = _Enter(n,
+                _name,
+                is_constant: false,
+                parallel_iterations: _parallel_iterations,
+                name: "f_count");
+            _loop_enters.Add(enter_n);
+
+            var m1 = merge(new[] { enter_n, enter_n });
+            var merge_n = m1[0];
+            var switch_n = @switch(merge_n, _pivot);
+
+            var index = math_ops.add(switch_n[1], 1);
+            var next_n = _NextIteration(index);
+            merge_n.op._update_input(1, next_n);
+
+            var total_iterations = exit(switch_n[0], name: "f_count");
+            loop_exits.append(total_iterations);
+            ExitResult(new[] { total_iterations });
+            Exit();
+
+            return (total_iterations, next_n);
+        }
+
+        /// <summary>
+        /// Add the backprop loop that controls the iterations.
+        /// </summary>
+        /// <param name="count">The number of iterations for backprop.</param>
+        /// <param name="outer_grad_state">The outer grad state. None if not nested.</param>
+        /// <returns>The loop index.</returns>
+        public Tensor AddBackpropLoopCounter(Tensor count, GradLoopState outer_grad_state)
+        {
+            Tensor one = null;
+            var in_separate_functions = count.graph != ops.get_default_graph();
+            if (in_separate_functions)
+                // Brings the count into this graph
+                count = array_ops.identity(count);
+            else
+                one = constant_op.constant(1, name: "b_count");
+
+            Enter();
+            AddName(count.name);
+            var enter_count = _Enter(
+                count,
+                _name,
+                is_constant: false,
+                parallel_iterations: _parallel_iterations,
+                name: "b_count");
+            loop_enters.append(enter_count);
+
+            var merge_count = merge(new[] { enter_count, enter_count })[0];
+            _pivot_for_pred = merge_count;
+            if (in_separate_functions)
+                one = constant_op.constant(1, name: "b_count");
+            var pred = math_ops.greater_equal(merge_count, one);
+            _pivot = gen_control_flow_ops.loop_cond(pred, name: "b_count");
+            var switch_count = @switch(merge_count, _pivot);
+
+            var index = math_ops.subtract(switch_count[1], one);
+            _pivot_for_body = index;
+            var next_count = _NextIteration(index);
+            merge_count.op._update_input(1, next_count);
+
+            var final_zero = exit(switch_count[0], name: "b_count");
+            loop_exits.append(final_zero);
+            // Force the stack pops of i-th execution of an inner loop to be ordered
+            // before the pops of (i+1)-th execution of the same inner loop.
+            if (outer_grad_state != null)
+                throw new NotImplementedException("outer_grad_state");
+                //outer_grad_state.grad_sync._add_control_input(final_zero.op);
+            ExitResult(new[] { final_zero });
+            Exit();
+            return next_count;
+        }

         /// <summary>
         /// Add `val` to the current context and its outer context recursively.
         /// </summary>
@@ -401,17 +520,27 @@ namespace Tensorflow.Operations
                 grad_ctxt = grad_ctxt.GetWhileContext();
                 if (grad_ctxt.grad_state != null)
                 {
-                    throw new NotImplementedException("");
+                    var forward_ctxt = val.op.GetWhileContext();
+                    if (control_flow_util.IsLoopExit(val.op))
+                    {
+                        forward_ctxt = forward_ctxt.outer_context as WhileContext;
+                        if (forward_ctxt != null)
+                            forward_ctxt = forward_ctxt.GetWhileContext();
+                        throw new NotImplementedException("control_flow_util.IsLoopExit");
+                    }
+                    if (forward_ctxt == grad_ctxt.grad_state.forward_context)
+                    {
+                        throw new NotImplementedException("forward_ctxt == grad_ctxt.grad_state.forward_context");
+                        /*real_val = grad_ctxt.grad_state.GetRealValue(val);
+                        _external_values[val.name] = real_val;
+                        return real_val;*/
+                    }
                 }
             }


             if (_outer_context != null)
                 result = _outer_context.AddValue(val);

+            if (tf.get_default_graph()._nodes_by_name.Count >= 83)
+            {
+
+            }
             // Create an Enter to make `result` known to this loop context.
             Tensor enter = null;
             tf_with(ops.control_dependencies(null), delegate
@@ -443,6 +572,9 @@ namespace Tensorflow.Operations
             return result;
         }

+        public override bool IsWhileContext()
+            => true;
+
         public override WhileContext GetWhileContext()
         {
             return this;


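With the BuildLoop signature change above, while_loop callers now receive the packed loop variables rather than a flat Tensor[]. A hedged sketch; `pred`, `body`, `loop_vars`, and `shape_invariants` are assumed from the caller, and trailing parameters (truncated in this diff view) are elided:

    LoopVar<TItem> loop_result = while_ctxt.BuildLoop(pred, body, loop_vars, shape_invariants /*, ... */);
    // Tensors and TensorArrays come back in the same structure the body returned,
    // rebuilt via nest.flatten2 / nest.pack_sequence_as2.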