Browse Source

LoopVar<T>

tags/v0.12
Oceania2018 6 years ago
parent
commit
3425aa4a29
3 changed files with 160 additions and 83 deletions
  1. +26
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs
  3. +129
    -82
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

+ 26
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Operations.ControlFlows; using Tensorflow.Operations.ControlFlows;
using static Tensorflow.ControlFlowContextDef; using static Tensorflow.ControlFlowContextDef;
using static Tensorflow.Binding;


namespace Tensorflow.Operations namespace Tensorflow.Operations
{ {
@@ -72,6 +73,7 @@ namespace Tensorflow.Operations
public ControlFlowContext() public ControlFlowContext()
{ {
_context_stack = new Stack<ControlFlowContext>(); _context_stack = new Stack<ControlFlowContext>();
_external_values = new Dictionary<string, ITensorOrOperation>();
} }


public string name { get => _name; } public string name { get => _name; }
@@ -180,6 +182,11 @@ namespace Tensorflow.Operations


public virtual bool back_prop => throw new NotImplementedException("abstract method"); public virtual bool back_prop => throw new NotImplementedException("abstract method");


/// <summary>
/// Add `val` to the current context and its outer context recursively.
/// </summary>
/// <param name="val"></param>
/// <returns></returns>
public virtual Tensor AddValue(Tensor val) public virtual Tensor AddValue(Tensor val)
{ {
// to be overridden // to be overridden
@@ -203,7 +210,25 @@ namespace Tensorflow.Operations
/// </summary> /// </summary>
protected virtual void _AddOpInternal(Operation op) protected virtual void _AddOpInternal(Operation op)
{ {
if (op.name == "rnn/while/Less")
{

}

if(op == null)
{
throw new NotImplementedException("");
}
else
{
foreach(var index in range(len(op.inputs)))
{
var x = op.inputs[index];
var real_x = AddValue(x);
if (real_x != x)
op._update_input(index, real_x);
}
}
} }


protected bool OpInContext(Operation op) protected bool OpInContext(Operation op)


+ 5
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs View File

@@ -24,5 +24,10 @@ namespace Tensorflow.Operations
elements.Add(Item); elements.Add(Item);
return elements.ToArray(); return elements.ToArray();
} }

public static implicit operator (Tensor, TItem)(LoopVar<TItem> loopVar)
{
return (loopVar.Counter, loopVar.Item);
}
} }
} }

+ 129
- 82
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -71,6 +71,8 @@ namespace Tensorflow.Operations
string name) string name)
{ {
_name = ops.get_default_graph().unique_name(name); _name = ops.get_default_graph().unique_name(name);
_maximum_iterations = maximum_iterations;
_parallel_iterations = parallel_iterations;
_back_prop = back_prop; _back_prop = back_prop;
_swap_memory = swap_memory; _swap_memory = swap_memory;
_loop_exits = new List<Tensor>(); _loop_exits = new List<Tensor>();
@@ -107,18 +109,27 @@ namespace Tensorflow.Operations
/// <summary> /// <summary>
/// Add the loop termination condition and body to the graph. /// Add the loop termination condition and body to the graph.
/// </summary> /// </summary>
internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
Func<Tensor, TItem, LoopVar<TItem>> body,
internal Tensor[] BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
Func<LoopVar<TItem>, LoopVar<TItem>> body,
LoopVar<TItem> loop_vars, LoopVar<TItem> loop_vars,
TensorShape shape_invariants,
TensorShape[] shape_invariants,
bool return_same_structure) bool return_same_structure)
{ {
// Keep original_loop_vars to identify which are TensorArrays // Keep original_loop_vars to identify which are TensorArrays
var original_loop_vars = loop_vars; var original_loop_vars = loop_vars;
// Convert TensorArrays to their flow variables // Convert TensorArrays to their flow variables
var loop_vars_tensors = nest.flatten2(loop_vars)
.Select(x => _convert_tensorarray_to_flow(x))
.ToArray();

if (shape_invariants == null)
shape_invariants = loop_vars_tensors
.Select(x => _get_shape_invariant(x as Tensor))
.ToArray();

Enter(); Enter();
var(original_body_result, exit_vars) = _BuildLoop( var(original_body_result, exit_vars) = _BuildLoop(
pred, body, original_loop_vars, loop_vars, shape_invariants);
pred, body, original_loop_vars, loop_vars_tensors, shape_invariants);
Exit(); Exit();


var flat_result = original_body_result; var flat_result = original_body_result;
@@ -131,7 +142,7 @@ namespace Tensorflow.Operations
return packed_exit_vars as Tensor[]; return packed_exit_vars as Tensor[];
} }


private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array)
private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array)
{ {
if (tensor_or_tensor_array is TensorArray tensor_array) if (tensor_or_tensor_array is TensorArray tensor_array)
return tensor_array.flow; return tensor_array.flow;
@@ -141,97 +152,116 @@ namespace Tensorflow.Operations
throw new NotImplementedException("_convert_tensorarray_to_flow"); throw new NotImplementedException("_convert_tensorarray_to_flow");
} }


private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
Func<Tensor, TItem, LoopVar<TItem>> body,
LoopVar<TItem> original_loop_vars,
LoopVar<TItem> loop_vars,
TensorShape shape_invariants)
private TensorShape _get_shape_invariant(Tensor var, int[] shape = null)
{ {
var flat_loop_vars = original_loop_vars;
return var.TensorShape;
}


// Convert TensorArrays to their flow variables
var loop_vars_tensor = nest.map_structure(
_convert_tensorarray_to_flow,
nest.flatten2(loop_vars));
/// <summary>
/// Add the loop termination condition and body to the graph.
/// </summary>
/// <typeparam name="TItem"></typeparam>
/// <param name="pred"></param>
/// <param name="body"></param>
/// <param name="original_loop_vars"></param>
/// <param name="loop_vars"></param>
/// <param name="shape_invariants"></param>
/// <returns></returns>
private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
Func<LoopVar<TItem>, LoopVar<TItem>> body,
LoopVar<TItem> original_loop_vars,
Tensor[] loop_vars,
TensorShape[] shape_invariants)
{
var flat_loop_vars = nest.flatten2(original_loop_vars)
.Select(x => (ITensorOrTensorArray)x)
.ToArray();


// Let the context know the loop variables so the loop variables // Let the context know the loop variables so the loop variables
// would be added in the outer contexts properly. // would be added in the outer contexts properly.
if (loop_vars is Tensor[] real_vars)
_InitializeValues(loop_vars);
var real_vars = loop_vars;
Tensor[] enter_vars = null;
tf_with(ops.control_dependencies(null), delegate
{ {
_InitializeValues(real_vars);
Tensor[] enter_vars = null;
tf_with(ops.control_dependencies(null), delegate
{
enter_vars = real_vars.Select(x => _Enter(x,
_name,
is_constant: false,
parallel_iterations: _parallel_iterations,
use_input_shape: shape_invariants == null))
.ToArray();

foreach (var x in enter_vars)
{
x.graph.prevent_feeding(x);
if (_outer_context != null)
_outer_context.AddInnerOp(x.op);
}
});

// Finds the closest enclosing non-None control pivot.
var outer_context = _outer_context;
while (outer_context != null)
enter_vars = real_vars.Select(x => _Enter(x,
_name,
is_constant: false,
parallel_iterations: _parallel_iterations,
use_input_shape: shape_invariants == null))
.ToArray();

foreach (var x in enter_vars)
{ {

x.graph.prevent_feeding(x);
if (_outer_context != null)
_outer_context.AddInnerOp(x.op);
} }
});


_SetShapeInvariants(real_vars, enter_vars, shape_invariants);

// Fix the control inputs and control flow context of these enter ops.
_FixControlInputsAndContext(enter_vars);
_InitializeValues(enter_vars);
_loop_enters = enter_vars.ToList();

var merge_vars = enter_vars
.Select(x => merge(new[] { x, x }))
.ToArray();
// Finds the closest enclosing non-None control pivot.
var outer_context = _outer_context;
object control_pivot = null;
while (outer_context != null && control_pivot == null)
{


_pivot_for_pred = merge_vars[0];
}


// Build the graph for pred.
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem)));
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
.ToArray();
if (control_pivot != null)
{


// Build the graph for body.
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
// Convert TensorArray flow variables inside the context back into
// their associated TensorArrays for calling the body.
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
/*var body_result = body(packed_vars_for_body[0]);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

// Store body_result to keep track of TensorArrays returned by body
var original_body_result = new[] { body_result };
// Convert TensorArrays returned by body into their flow variables
var result = new[] { body_result };

var next_vars = new List<Tensor>();
foreach (var (m, v) in zip(merge_vars, result))
next_vars.Add(_AddNextAndBackEdge(m, v));

// Add the exit ops.
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
_loop_exits = exit_vars;

// Exit the loop.
// ExitResult(exit_vars);
return (original_body_result, exit_vars.ToArray());*/
} }


throw new NotImplementedException("");
_SetShapeInvariants(real_vars, enter_vars, shape_invariants);

// Fix the control inputs and control flow context of these enter ops.
_FixControlInputsAndContext(enter_vars);
_InitializeValues(enter_vars);
_loop_enters = enter_vars.ToList();

var merge_vars = enter_vars
.Select(x => merge(new[] { x, x }))
.ToArray();

_pivot_for_pred = merge_vars[0];

// Build the graph for pred.
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
//var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays, expand_composites: true);
var packed_vars = new LoopVar<TItem>((Tensor)merge_vars_with_tensor_arrays[0],
(TItem)(object)new BodyItemInRnnWhileLoop((Tensor)merge_vars_with_tensor_arrays[1],
new[] { (TensorArray)merge_vars_with_tensor_arrays[2] },
(Tensor)merge_vars_with_tensor_arrays[3]));
var pp = pred(packed_vars);
var c = ops.convert_to_tensor(pp);
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
.ToArray();

// Build the graph for body.
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
// Convert TensorArray flow variables inside the context back into
// their associated TensorArrays for calling the body.
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
var body_result = body(original_loop_vars);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

// Store body_result to keep track of TensorArrays returned by body
var original_body_result = new[] { body_result };
// Convert TensorArrays returned by body into their flow variables
var result = new[] { body_result };

var next_vars = new List<Tensor>();
//foreach (var (m, v) in zip(merge_vars, result))
//next_vars.Add(_AddNextAndBackEdge(m, v));

// Add the exit ops.
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
_loop_exits = exit_vars;

// Exit the loop.
// ExitResult(exit_vars);
return (null, exit_vars.ToArray());
} }


private void _FixControlInputsAndContext(Tensor[] enters) private void _FixControlInputsAndContext(Tensor[] enters)
@@ -258,6 +288,23 @@ namespace Tensorflow.Operations
_values.Add(x.name); _values.Add(x.name);
} }


public override Tensor AddValue(Tensor val)
{
var result = val;
var new_value = _values.Contains(val.name);
new_value &= val.op._get_control_flow_context() != this;
if (new_value)
throw new NotImplementedException("");
else
{
var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null;
if (actual_val != null)
result = actual_val as Tensor;
}

return result;
}

public override WhileContext GetWhileContext() public override WhileContext GetWhileContext()
{ {
return this; return this;


Loading…
Cancel
Save