Browse Source

IsLoopConstantEnter

tags/v0.12
Oceania2018 6 years ago
parent
commit
8243807ede
13 changed files with 272 additions and 26 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Interfaces/IFlatten.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Interfaces/IPackable.cs
  3. +28
    -8
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  4. +11
    -3
      src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs
  5. +138
    -11
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
  6. +8
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs
  7. +6
    -1
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  8. +15
    -0
      src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
  9. +2
    -1
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  10. +20
    -0
      src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
  11. +22
    -0
      src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
  12. +3
    -0
      src/TensorFlowNET.Core/Tensors/TensorArray.cs
  13. +7
    -0
      src/TensorFlowNET.Core/Util/nest.py.cs

src/TensorFlowNET.Core/Util/IFlatten.cs → src/TensorFlowNET.Core/Interfaces/IFlatten.cs View File

@@ -2,7 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Operations
namespace Tensorflow
{ {
public interface ICanBeFlattened public interface ICanBeFlattened
{ {

+ 11
- 0
src/TensorFlowNET.Core/Interfaces/IPackable.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public interface IPackable
{
void Pack(object[] sequences);
}
}

+ 28
- 8
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -170,7 +170,7 @@ namespace Tensorflow.Operations
/// <summary> /// <summary>
/// Add `op` to the current context. /// Add `op` to the current context.
/// </summary> /// </summary>
public void AddOp(Operation op)
public virtual void AddOp(Operation op)
{ {
_AddOpInternal(op); _AddOpInternal(op);
} }
@@ -210,11 +210,6 @@ namespace Tensorflow.Operations
/// </summary> /// </summary>
protected virtual void _AddOpInternal(Operation op) protected virtual void _AddOpInternal(Operation op)
{ {
if (op.name == "rnn/while/Less")
{

}

if(op == null) if(op == null)
{ {
throw new NotImplementedException(""); throw new NotImplementedException("");
@@ -255,9 +250,34 @@ namespace Tensorflow.Operations
throw new NotImplementedException("_IsInOuterContext"); throw new NotImplementedException("_IsInOuterContext");
} }


protected virtual void _RemoveExternalControlEdges(Operation op)
/// <summary>
/// Remove any external control dependency on this op.
/// </summary>
/// <param name="op"></param>
protected virtual (Operation[], Operation[]) _RemoveExternalControlEdges(Operation op)
{ {
var internal_control_inputs = op.control_inputs;
var while_ctxt = GetWhileContext();

var internal_control_inputs = new List<Operation>();
// A control input of `op` is internal if it is in the same while
// loop context as the enclosing while loop context of self.
if (while_ctxt == null)
{
internal_control_inputs = op.control_inputs.ToList();
}
else
{
foreach(Tensor x in op.control_inputs)
{
throw new NotImplementedException("");
}
}

var external_control_inputs = new List<Operation>();
if (len(internal_control_inputs) != len(op.control_inputs))
throw new NotImplementedException("");

return (internal_control_inputs.ToArray(), external_control_inputs.ToArray());
} }


/// <summary> /// <summary>


+ 11
- 3
src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs View File

@@ -1,13 +1,14 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;


namespace Tensorflow.Operations namespace Tensorflow.Operations
{ {
internal class LoopVar<TItem> : ICanBeFlattened
internal class LoopVar<TItem> : ICanBeFlattened, IPackable
{ {
public Tensor Counter { get; }
public TItem Item { get; }
public Tensor Counter { get; set; }
public TItem Item { get; set; }


public LoopVar(Tensor counter, TItem item) public LoopVar(Tensor counter, TItem item)
{ {
@@ -25,6 +26,13 @@ namespace Tensorflow.Operations
return elements.ToArray(); return elements.ToArray();
} }


public void Pack(object[] sequences)
{
Counter = sequences[0] as Tensor;
if (typeof(TItem).GetInterface(typeof(IPackable).Name) != null)
(Item as IPackable).Pack(sequences.Skip(1).ToArray());
}

public static implicit operator (Tensor, TItem)(LoopVar<TItem> loopVar) public static implicit operator (Tensor, TItem)(LoopVar<TItem> loopVar)
{ {
return (loopVar.Counter, loopVar.Item); return (loopVar.Counter, loopVar.Item);


+ 138
- 11
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -240,10 +240,13 @@ namespace Tensorflow.Operations


// Build the graph for body. // Build the graph for body.
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
_pivot_for_body = vars_for_body[0];
// Convert TensorArray flow variables inside the context back into // Convert TensorArray flow variables inside the context back into
// their associated TensorArrays for calling the body. // their associated TensorArrays for calling the body.
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
var body_result = body(original_loop_vars);
var vars_for_body_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
var packed_vars_for_body = nest.pack_sequence_as2(original_loop_vars, vars_for_body_with_tensor_arrays);
var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
var body_result = body(packed_vars_for_body);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);


// Store body_result to keep track of TensorArrays returned by body // Store body_result to keep track of TensorArrays returned by body
@@ -267,17 +270,27 @@ namespace Tensorflow.Operations
private void _FixControlInputsAndContext(Tensor[] enters) private void _FixControlInputsAndContext(Tensor[] enters)
{ {
var graph = ops.get_default_graph(); var graph = ops.get_default_graph();
foreach(var e in enters)
foreach(var x in enters)
{ {
var inp_op = e.op.inputs[0].op;
var inp_op = x.op.inputs[0].op;
var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op });
var outer_control_inputs = new List<Operation>();
foreach(Operation op in control_inputs)
{
// We need to keep control inputs that are in any ancestor
// ControlFlowContext, and within outer WhileContext.
var keep_as_control_input = true;
var op_ctxt = control_flow_util.GetOutputContext(op);
var outer_ctxt = outer_context;
throw new NotImplementedException("");
}
// op for op in control_inputs if self._IsInOuterContext(op) // op for op in control_inputs if self._IsInOuterContext(op)
var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op))
/*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op))
.Select(x => x.op) .Select(x => x.op)
.ToArray();
e.op._set_control_flow_context(this);
e.op._add_control_inputs(outer_control_inputs);
graph._record_op_seen_by_control_dependencies(e.op);
.ToArray();*/
x.op._set_control_flow_context(this);
x.op._add_control_inputs(outer_control_inputs.ToArray());
graph._record_op_seen_by_control_dependencies(x.op);
} }
} }


@@ -288,13 +301,127 @@ namespace Tensorflow.Operations
_values.Add(x.name); _values.Add(x.name);
} }


protected override void _AddOpInternal(Operation op)
{
Operation[] external_inputs = new Operation[0];
if (op == null)
{
throw new NotImplementedException("");
}
else
{
foreach (var index in range(len(op.inputs)))
{
var x = op.inputs[index];
var real_x = AddValue(x);
if (real_x != x)
op._update_input(index, real_x);
}

// Remove any external control dependency on this op.
(_, external_inputs) = _RemoveExternalControlEdges(op);
// Add a control dependency to prevent loop invariants from
// enabling ops that should not be executed.
_MaybeAddControlDependency(op);
foreach (Tensor x in op.outputs)
_values.Add(x.name);
}

if (external_inputs.Length > 0)
{
throw new NotImplementedException("external_inputs.Length > 0");
}

if (_outer_context != null || !IsLoopExit(op))
foreach (Tensor x in op.outputs)
op.graph.prevent_feeding(x);

if (_outer_context != null)
_outer_context.AddInnerOp(op);
}

protected void _MaybeAddControlDependency(Operation op)
{
// Determines if `op` needs a control dependency.
Func<Operation, bool> _IsOpFree = (op1) =>
{
if (op1.control_inputs.Length > 0)
return false;

if (op1.type == "SymbolicGradient")
return true;

foreach (Tensor x in op1.inputs)
if (!control_flow_util.IsLoopConstantEnter(x.op))
return false;

return true;
};

if (_IsOpFree(op))
op._add_control_input(GetControlPivot().op);
}

private Tensor GetControlPivot()
{
if (_pivot_for_body != null)
return _pivot_for_body;
return _pivot_for_pred;
}

public override void AddOp(Operation op)
{
_AddOpInternal(op);
}

public override Tensor AddValue(Tensor val) public override Tensor AddValue(Tensor val)
{ {
var result = val; var result = val;
var new_value = _values.Contains(val.name);
var new_value = !_values.Contains(val.name);
new_value &= val.op._get_control_flow_context() != this; new_value &= val.op._get_control_flow_context() != this;
if (new_value) if (new_value)
throw new NotImplementedException("");
{
_values.Add(val.name);

// If we are in a grad context and val is from its forward context,
// use GetRealValue(), which adds the logic to save the history of
// val in forward.
var grad_ctxt = ops.get_default_graph()._get_control_flow_context();
if(grad_ctxt != null)
{
grad_ctxt = grad_ctxt.GetWhileContext();
if (grad_ctxt.grad_state != null)
{
throw new NotImplementedException("");
}
}

if (_outer_context != null)
{
result = _outer_context.AddValue(val);
}

// Create an Enter to make `result` known to this loop context.
Tensor enter = null;
tf_with(ops.control_dependencies(new ITensorOrOperation[0]), delegate
{
enter = _Enter(
result,
_name,
is_constant: true,
parallel_iterations: _parallel_iterations);
enter.graph.prevent_feeding(enter);
if (_outer_context != null)
_outer_context.AddInnerOp(enter.op);
});

// Fix the control inputs and control flow context of these enter ops.
_FixControlInputsAndContext(new[] { enter });
// Add `enter` in this context.
_values.Add(enter.name);
_external_values[val.name] = enter;
result = enter;
}
else else
{ {
var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null; var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null;


+ 8
- 1
src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs View File

@@ -4,7 +4,7 @@ using System.Text;


namespace Tensorflow.Operations namespace Tensorflow.Operations
{ {
internal class BodyItemInRnnWhileLoop : ICanBeFlattened
internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable
{ {
/// <summary> /// <summary>
/// int32 scalar Tensor. /// int32 scalar Tensor.
@@ -36,5 +36,12 @@ namespace Tensorflow.Operations
elements.Add(state); elements.Add(state);
return elements.ToArray(); return elements.ToArray();
} }

public void Pack(object[] sequences)
{
time = sequences[0] as Tensor;
output_ta_t = new[] { sequences[1] as TensorArray };
state = sequences[2] as Tensor;
}
} }
} }

+ 6
- 1
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -192,7 +192,12 @@ namespace Tensorflow.Operations
// Take a time step of the dynamic RNN. // Take a time step of the dynamic RNN.
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) => Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) =>
{ {
throw new NotImplementedException("");
if (in_graph_mode)
{
input_ta.Select(ta => ta.read(time)).ToArray();
}

return item;
}; };


control_flow_ops.while_loop( control_flow_ops.while_loop(


+ 15
- 0
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs View File

@@ -159,5 +159,20 @@ namespace Tensorflow.Operations
{ {
_colocate_with.Add(value); _colocate_with.Add(value);
} }

public Tensor read(Tensor index, string name = null)
{
var value = gen_data_flow_ops.tensor_array_read_v3(
handle: _handle,
index: index,
flow_in: _flow,
dtype: _dtype,
name: name);

if (_element_shape != null)
value.set_shape(_element_shape[0].dims);

return value;
}
} }
} }

+ 2
- 1
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -648,7 +648,8 @@ namespace Tensorflow
body_buildloop = (item) => body_buildloop = (item) =>
{ {
var (i, lv) = (item.Counter, item.Item); var (i, lv) = (item.Counter, item.Item);
return new LoopVar<TItem>(i + 1, orig_body(lv));
var ob = orig_body(lv);
return new LoopVar<TItem>(i + 1, ob);
}; };
} }
try_to_pack = false; try_to_pack = false;


+ 20
- 0
src/TensorFlowNET.Core/Operations/control_flow_util.py.cs View File

@@ -30,6 +30,26 @@ namespace Tensorflow
public static bool IsLoopExit(Operation op) public static bool IsLoopExit(Operation op)
{ {
return op.type == "Exit" || op.type == "RefExit"; return op.type == "Exit" || op.type == "RefExit";
}
/// <summary>
/// Returns true if `op` is an Enter.
/// </summary>
/// <param name="op"></param>
/// <returns></returns>
public static bool IsLoopEnter(Operation op)
{
return op.type == "Enter" || op.type == "RefEnter";
}

/// <summary>
/// Return true iff op is a loop invariant.
/// </summary>
/// <param name="op"></param>
/// <returns></returns>
public static bool IsLoopConstantEnter(Operation op)
{
return IsLoopEnter(op) && op.get_attr<bool>("is_constant");
} }


/// <summary> /// <summary>


+ 22
- 0
src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs View File

@@ -198,5 +198,27 @@ namespace Tensorflow


return _op.outputs; return _op.outputs;
} }

/// <summary>
/// Read an element from the TensorArray into output `value`.
/// </summary>
/// <param name="handle"></param>
/// <param name="index"></param>
/// <param name="flow_in"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor tensor_array_read_v3(Tensor handle, Tensor index, Tensor flow_in, TF_DataType dtype, string name = null)
{
var _op = _op_def_lib._apply_op_helper("TensorArrayReadV3", name, new
{
handle,
index,
flow_in,
dtype
});

return _op.output;
}
} }
} }

+ 3
- 0
src/TensorFlowNET.Core/Tensors/TensorArray.cs View File

@@ -58,5 +58,8 @@ namespace Tensorflow


public TensorArray unstack(Tensor value, string name = null) public TensorArray unstack(Tensor value, string name = null)
=> _implementation.unstack(value, name: name); => _implementation.unstack(value, name: name);

public Tensor read(Tensor index, string name = null)
=> _implementation.read(index, name: name);
} }
} }

+ 7
- 0
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -401,6 +401,13 @@ namespace Tensorflow.Util
private static int len(IEnumerable<object> x) => x.Count(); private static int len(IEnumerable<object> x) => x.Count();
public static T pack_sequence_as2<T>(T structure, object[] flat_sequence, bool expand_composites = false)
where T : IPackable
{
structure.Pack(flat_sequence);
return structure;
}
/// <summary> /// <summary>
/// Returns a given flattened sequence packed into a given structure. /// Returns a given flattened sequence packed into a given structure.
/// If `structure` is a scalar, `flat_sequence` must be a single-element list; /// If `structure` is a scalar, `flat_sequence` must be a single-element list;


Loading…
Cancel
Save