
CondContext: implemented missing functionality

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
a7b76704fa
9 changed files with 381 additions and 89 deletions
  1. +19  -3   src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  2. +151 -61  src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  3. +97  -3   src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  4. +4   -0   src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
  5. +7   -3   src/TensorFlowNET.Core/Operations/Operation.Control.cs
  6. +29  -2   src/TensorFlowNET.Core/Operations/Operation.cs
  7. +44  -13  src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  8. +26  -4   src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
  9. +4   -0   test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs

+19 -3   src/TensorFlowNET.Core/Graphs/Graph.Operation.cs

@@ -41,11 +41,27 @@ namespace Tensorflow
        {
            var op_name = Marshal.PtrToStringAnsi(c_api.TF_OperationName(tf_oper));
            return _get_operation_by_name_unsafe(op_name);
        }

+       /// <summary>
+       /// Creates an `Operation` in this graph from the supplied TF_Operation.
+       ///
+       /// This method is like create_op() except the new Operation is constructed
+       /// using `c_op`. The returned Operation will have `c_op` as its _c_op
+       /// field. This is used to create Operation objects around TF_Operations created
+       /// indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
+       ///
+       /// This function does not call Operation._control_flow_post_processing or
+       /// Graph._control_dependencies_for_inputs (since the inputs may not be
+       /// available yet). The caller is responsible for calling these methods.
+       /// </summary>
+       /// <param name="c_op">a wrapped TF_Operation</param>
+       /// <param name="compute_device">(Optional.) If True, device functions will be executed
+       /// to compute the device property of the Operation.</param>
+       /// <returns>An `Operation` object.</returns>
        public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true)
        {
-           var ret = new Operation(c_op);
+           var ret = new Operation(c_op, this);
            _add_op(ret);

            var name_key = ret.name.ToLower();

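The doc comment above spells out a calling contract: _create_op_from_tf_operation only wraps the handle and registers it on the graph, while control-flow bookkeeping is deferred to the caller. A minimal sketch of that calling pattern, assuming the Graph and Operation members shown elsewhere in this commit (the `c_op` handle and the importer scenario are illustrative, not part of the diff):

    // Hypothetical caller, e.g. after TF_ImportGraphDef has produced TF_Operation handles.
    Operation op = graph._create_op_from_tf_operation(c_op);
    // ... resolve or wire up op.inputs first ...
    op._control_flow_post_processing();   // now safe: inputs are available (see Operation.Control.cs below)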

+151 -61   src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

@@ -16,6 +16,7 @@ namespace Tensorflow.Operations
        /// The boolean tensor for the cond predicate
        /// </summary>
        private Tensor _pred;
+
        public Tensor pred => _pred;

        /// <summary>
@@ -23,11 +24,6 @@ namespace Tensorflow.Operations
        /// </summary>
        private int _branch;

-       /// <summary>
-       ///
-       /// </summary>
-       private List<string> _values = new List<string>();
-
        private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();

        /// <summary>
@@ -66,72 +62,166 @@ namespace Tensorflow.Operations
        }

        /// <summary>
-       /// Add the subgraph defined by fn() to the graph.
+       /// Add `val` to the current context and its outer context recursively.
        /// </summary>
-       public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
-       {
-           // Add the subgraph defined by fn() to the graph.
-           var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
-           var original_result = fn();
-           var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
-           // TODO: _outer_context
-
-           //TODO: port this chunk of missing code:
-           /*
-           if len(post_summaries) > len(pre_summaries):
-             new_summaries = post_summaries[len(pre_summaries):]
-             summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
-             summary_ref[:] = pre_summaries
-             with ops.control_dependencies(new_summaries):
-               if original_result is None:
-                 return no_op(), None
-               else:
-                 original_result = nest.map_structure(array_ops.identity, original_result)
-           */
-           if (original_result == null)
-               return (original_result, null);
-
-           switch (original_result)
-           {
-               case Tensor result:
-                   return (original_result, _BuildCondTensor(new[] { result.op }));
-               case Operation[] results:
-                   return (original_result, _BuildCondTensor(results));
-               case float[] fv:
-                   {
-                       var result = ops.convert_to_tensor(fv[0]);
-                       return (original_result, result);
-                   }
-               default:
-                   return (original_result, null);
-           }
-       }
-
-       public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
-       {
-           // Add the subgraph defined by fn() to the graph.
-           var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
-           var original_result = fn();
-           var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
-
-           switch (original_result)
-           {
-               case Tensor[] results:
-                   return (original_result, new Tensor[] { _BuildCondTensor(results.Select(t => t.op).ToArray()) });
-               case Operation[] results:
-                   return (original_result, new Tensor[] { _BuildCondTensor(results) });
-               case float[] fv:
-                   var result = ops.convert_to_tensor(fv[0]);
-                   return (original_result, new Tensor[] { result });
-               default:
-                   return (original_result, new Tensor[0]);
-           }
-       }
-
-       private Tensor _BuildCondTensor(Operation[] v)
-       {
-           // Use pivot as the proxy for this op.
-           return control_flow_ops.with_dependencies(v, _pivot);
-       }
+       /// <param name="val"></param>
+       public override Tensor AddValue(Tensor val)
+       {
+           Tensor result = null;
+           if (_values.Contains(val.name))
+           {
+               // Use the real value if it comes from outer context. This is needed in
+               // particular for nested conds.
+               if (_external_values.ContainsKey(val.name))
+                   result = _external_values[val.name];
+               else
+                   result = val;
+           }
+           else
+           {
+               result = val;
+               _values.Add(val.name);
+               if (_outer_context != null)
+               {
+                   result = _outer_context.AddValue(val);
+                   _values.Add(result.name);
+                   _external_values[result.name] = result;
+               }
+               // TODO: how to do 'with' here??
+               //with(ops.control_dependencies(null), ctrl =>
+               //{
+                   var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred);
+                   result = new[] { r0, r1 }[_branch];
+                   if (_outer_context != null)
+                       _outer_context.AddInnerOp(result.op);
+               //});
+
+               result.op.graph.prevent_fetching(result.op);
+               result.op._set_control_flow_context(this);
+
+               // Mark Switch output as seen by this context and any outer contexts,
+               // just like what we do for normal op outputs in _AddOpInternal() below.
+               IControlFlowContext ctxt = this;
+               while (ctxt != null)
+               {
+                   ctxt.values.Add(result.name);
+                   ctxt = ctxt.outer_context;
+               }
+
+               _external_values[val.name] = result;
+           }
+
+           return result;
+       }
+
+       /// <summary>
+       /// Add the subgraph defined by fn() to the graph.
+       /// </summary>
+       public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
+       {
+           // Add the subgraph defined by fn() to the graph.
+           var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+           var original_result = fn();
+           var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+
+           //TODO: port this chunk of missing code:
+           /*
+           if len(post_summaries) > len(pre_summaries):
+             new_summaries = post_summaries[len(pre_summaries):]
+             summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
+             summary_ref[:] = pre_summaries
+             with ops.control_dependencies(new_summaries):
+               if original_result is None:
+                 return no_op(), None
+               else:
+                 original_result = nest.map_structure(array_ops.identity, original_result)
+           */
+           if (original_result == null)
+               return (original_result, null);
+
+           switch (original_result)
+           {
+               case Tensor result:
+                   return (original_result, _BuildCondTensor(result));
+               case Operation op:
+                   return (original_result, _BuildCondTensor(op));
+               case float[] fv:
+                   {
+                       var result = ops.convert_to_tensor(fv[0]);
+                       return (original_result, _BuildCondTensor(result));
+                   }
+               default:
+                   return (original_result, null);
+           }
+       }
+
+       public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
+       {
+           // Add the subgraph defined by fn() to the graph.
+           var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+           var original_result = fn();
+           var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+
+           switch (original_result)
+           {
+               case Tensor[] results:
+                   return (original_result, results.Select(_BuildCondTensor).ToArray());
+               case Operation[] results:
+                   return (original_result, results.Select(_BuildCondTensor).ToArray());
+               case float[] fv:
+                   var result = ops.convert_to_tensor(fv[0]);
+                   return (original_result, new Tensor[] { result });
+               default:
+                   return (original_result, new Tensor[0]);
+           }
+       }
+
+       private Tensor _BuildCondTensor(ITensorOrOperation v)
+       {
+           switch (v)
+           {
+               case Operation op:
+                   // Use pivot as the proxy for this op.
+                   return control_flow_ops.with_dependencies(new Operation[] { op }, _pivot);
+               case Tensor t:
+                   return _ProcessOutputTensor(t);
+               default:
+                   return _ProcessOutputTensor(ops.convert_to_tensor(v));
+           }
+       }
+
+       /// <summary>
+       /// Process an output tensor of a conditional branch.
+       /// </summary>
+       private Tensor _ProcessOutputTensor(Tensor val)
+       {
+           var real_val = val;
+           if (!_values.Contains(val.name))
+           {
+               // Handle the special case of lambda: x
+               _values.Add(val.name);
+               if (_outer_context != null)
+               {
+                   real_val = _outer_context.AddValue(val);
+                   _values.Add(real_val.name);
+                   _external_values[real_val.name] = real_val;
+               }
+           }
+           else
+           {
+               Tensor external_val = null;
+               if (_external_values.ContainsKey(val.name))
+                   external_val = _external_values[val.name];
+               if (external_val != null)
+                   real_val = external_val;
+           }
+           return real_val;
+       }
+
+       public override void AddInnerOp(Operation resultOp)
+       {
+           throw new NotImplementedException();
+       }
    }
}

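AddValue above is what makes externally created tensors usable inside a cond branch: the first time a branch touches such a tensor, it is routed through a Switch on the predicate, cached in _external_values, and registered with every outer context. A rough usage sketch with the cond API exercised by the unit test later in this commit (the constant value is illustrative, and whether gen_array_ops.identity can be called without a name argument is an assumption):

    var x = tf.constant(2);
    // Both branch functions capture `x` from outside the cond. While BuildCondBranch
    // builds each subgraph, every external tensor they touch goes through AddValue,
    // which inserts a Switch on the predicate (x < 10) and records it in _values /
    // _external_values so nested conds reuse the same Switch output.
    var r = control_flow_ops.cond(x < 10,
        () => gen_array_ops.identity(x),
        () => x);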
+97 -3   src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;

 namespace Tensorflow.Operations
@@ -29,6 +30,8 @@ namespace Tensorflow.Operations
        protected Tensor _pivot;

        protected Stack<IControlFlowContext> _context_stack;
+       protected IControlFlowContext _outer_context;
+
        public ControlFlowContext()
        {
            _context_stack = new Stack<IControlFlowContext>();
@@ -69,23 +72,114 @@ namespace Tensorflow.Operations
            graph._set_control_flow_context(last_context);
        }

+       /// <summary>
+       /// Add `op` to the current context.
+       /// </summary>
        public void AddOp(Operation op)
        {
            _AddOpInternal(op);
        }

+       public IControlFlowContext outer_context { get { return _outer_context; } }
+
+       public HashSet<string> values => _values;
+
+       public virtual Tensor AddValue(Tensor val)
+       {
+           // to be overridden
+           return null;
+       }
+
+       public virtual void AddInnerOp(Operation resultOp)
+       {
+           // to be overridden
+       }
+
+       protected HashSet<string> _values = new HashSet<string>();
+
        /// <summary>
        /// Add `op` to the current context.
        /// </summary>
        protected virtual void _AddOpInternal(Operation op)
        {
-           if(op.inputs.Length == 0)
+           if (op.inputs.Length == 0)
            {
+               // If we're in a while loop, remove any control inputs from outside the
+               // loop.
                _RemoveExternalControlEdges(op);
-               op._add_control_input(_pivot.op);
+               if (!op.control_inputs.Any(input_op => OpInContext(input_op)))
+                   op._add_control_input(_pivot.op);
            }
            else
            {
+               // Make each input to 'op' available in this CondContext. If an input is
+               // already part of this context there's nothing to do, but if it's
+               // external, AddValue() will handle adding the appropriate Switch node and
+               // other bookkeeping.
+               for (int index = 0; index < op.inputs.Length; index++)
+               {
+                   var x = op.inputs[index];
+                   Tensor real_x = null;
+                   if (op.type == "Merge" && x.op.type == "NextIteration")
+                   {
+                       // Edge case: if we're importing a while loop inside this CondContext,
+                       // AddValue() will not correctly handle the NextIteration inputs to
+                       // Merge node. The problem is that the NextIteration should also be
+                       // part of this context, but if we're importing it won't have been
+                       // processed and added to the context yet, so AddValue() will try to
+                       // add a Switch which results in an invalid graph. Instead, we use the
+                       // NextIteration input as-is here, and it will eventually be added to
+                       // the context via AddOp().
+                       real_x = x;
+                   }
+                   else
+                   {
+                       real_x = AddValue(x);
+                   }
+                   if (real_x != x)
+                       op._update_input(index, real_x);
+               }
+               // Remove any external control dependency on this op.
+               _RemoveExternalControlEdges(op);
+               // TODO: implement below code dependencies
+               //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient")
+               //    op._add_control_input(_pivot.op);
            }
+
+           // Mark op's outputs as seen by this context and any outer contexts.
+           var output_names = op.outputs.Select(x => x.name).ToArray();
+           IControlFlowContext ctxt = this;
+           while (ctxt != null)
+           {
+               foreach (var name in output_names)
+                   ctxt.values.Add(name);
+               ctxt = ctxt.outer_context;
+           }
+
+           if (_outer_context != null || !control_flow_ops.IsLoopExit(op))
+               op.graph.prevent_fetching(op);
+
+           if (_outer_context != null)
+               _outer_context.AddInnerOp(op);
        }

+       private bool OpInContext(Operation op)
+       {
+           return IsContainingContext(op._get_control_flow_context(), this);
+       }
+
+       /// <summary>
+       /// Returns true if `maybe_containing_ctxt` is or contains `ctxt`.
+       /// </summary>
+       public static bool IsContainingContext(IControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt)
+       {
+           while (ctxt != maybe_containing_ctxt)
+           {
+               if (ctxt == null)
+                   return false;
+               ctxt = ctxt.outer_context;
+           }
+           return true;
+       }

        protected virtual void _RemoveExternalControlEdges(Operation op)
        {


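The pivot-pinning change above only adds a control edge to _pivot when none of the op's existing control inputs already lives inside this context; IsContainingContext implements that check by walking outer_context links. A small illustration, where `inner` and `outer` are hypothetical nested contexts with inner.outer_context == outer (not code from this commit):

    var a = ControlFlowContext.IsContainingContext(inner, outer);   // true: walking inner -> outer reaches `outer`
    var b = ControlFlowContext.IsContainingContext(outer, inner);   // false: outer's chain ends at null before reaching `inner`
    // OpInContext(op) asks the same question for op._get_control_flow_context() against `this`.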
+4 -0   src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs

@@ -7,5 +7,9 @@ namespace Tensorflow
    public interface IControlFlowContext
    {
        void AddOp(Operation op);
+       IControlFlowContext outer_context { get; }
+       HashSet<string> values { get; }
+       Tensor AddValue(Tensor val);
+       void AddInnerOp(Operation resultOp);
    }
}

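For readers tracking what the widened interface now demands of an implementer, a minimal, purely illustrative implementation is sketched below (it is not part of the commit; the real implementers are ControlFlowContext and its CondContext subclass):

    using System.Collections.Generic;

    namespace Tensorflow
    {
        // Hypothetical pass-through context: top level, records names but never rewrites values.
        public class NullControlFlowContext : IControlFlowContext
        {
            public IControlFlowContext outer_context => null;          // no enclosing context
            public HashSet<string> values { get; } = new HashSet<string>();
            public void AddOp(Operation op) { }                        // nothing to post-process
            public Tensor AddValue(Tensor val) => val;                 // external values pass through unchanged
            public void AddInnerOp(Operation resultOp) { }             // nothing to propagate outward
        }
    }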
+7 -3   src/TensorFlowNET.Core/Operations/Operation.Control.cs

@@ -7,16 +7,20 @@ namespace Tensorflow
{
    public partial class Operation
    {
        private IControlFlowContext _control_flow_context;

        /// <summary>
        /// Add this op to its control flow context.
+       ///
+       /// This may add new ops and change this op's inputs. self.inputs must be
+       /// available before calling this method.
        /// </summary>
        public void _control_flow_post_processing()
        {
            foreach (var input_tensor in inputs)
            {
+               //TODO: implement below code dependency
+               //control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
            }

            if (_control_flow_context != null)


+29 -2   src/TensorFlowNET.Core/Operations/Operation.cs

@@ -62,16 +62,22 @@ namespace Tensorflow
            }
        }

-       public Operation(IntPtr handle)
+       public Operation(IntPtr handle, Graph g = null)
        {
            if (handle == IntPtr.Zero)
                return;

            _handle = handle;
-           _graph = ops.get_default_graph();
+           _graph = g ?? ops.get_default_graph();
            _outputs = new Tensor[NumOutputs];
            for (int i = 0; i < NumOutputs; i++)
                _outputs[i] = new Tensor(this, i, OutputType(i));
+
+           // Dict mapping op name to file and line information for op colocation
+           // context managers.
+           _control_flow_context = graph._get_control_flow_context();
+
+           // Note: _control_flow_post_processing() must not be called here; the caller
+           // is responsible for calling it when using this constructor.
        }

        public Operation(Graph g, string opType, string oper_name)
@@ -81,6 +87,10 @@ namespace Tensorflow
            _operDesc = c_api.TF_NewOperation(g, opType, oper_name);
            c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
            _handle = c_api.TF_FinishOperation(_operDesc, status);
+
+           // Dict mapping op name to file and line information for op colocation
+           // context managers.
+           _control_flow_context = graph._get_control_flow_context();
        }

        /// <summary>
@@ -258,6 +268,23 @@ namespace Tensorflow
            }

            return base.Equals(obj);
        }
+
+       /// <summary>
+       /// Update the input to this operation at the given index.
+       ///
+       /// NOTE: This is for TF internal use only. Please don't use it.
+       /// </summary>
+       /// <param name="index">the index of the input to update.</param>
+       /// <param name="tensor">the Tensor to be used as the input at the given index.</param>
+       public void _update_input(int index, Tensor tensor)
+       {
+           throw new NotImplementedException("_update_input");
+           // TODO: implement below code dependencies
+           //_assert_same_graph(tensor);
+           //// Reset cached inputs.
+           //_inputs_val = null;
+           //c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index));
+       }
    }
}

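The new _update_input stub already spells out its intended body in the commented TODO. A hedged sketch of that shape only, using just the members the comment names (whether _assert_same_graph, _inputs_val, _as_tf_output, _tf_input and c_api.UpdateEdge all exist yet in this code base is exactly what the TODO flags; note that the underlying C API call, TF_UpdateEdge, also expects a TF_Status argument that the comment omits):

    public void _update_input(int index, Tensor tensor)
    {
        _assert_same_graph(tensor);          // both endpoints must belong to this graph
        _inputs_val = null;                  // reset cached inputs so they are re-read from the C graph
        c_api.UpdateEdge(_graph._c_graph,    // rewire the edge in the underlying TF_Graph
                         tensor._as_tf_output(),
                         _tf_input(index));
    }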
+44 -13   src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

@@ -111,7 +111,7 @@ namespace Tensorflow
            return loop_state;
        }

-       private static bool IsLoopExit(Operation op)
+       public static bool IsLoopExit(Operation op)
        {
            return op.OpType == "Exit" || op.OpType == "RefExit";
        }
@@ -193,20 +193,49 @@ namespace Tensorflow
            return gen_array_ops.identity(data, name: name);
        }

        /// <summary>
        /// Forwards `data` to an output determined by `pred`.
+       ///
+       /// If `pred` is false, the `data` input is forwarded to the first output.
+       /// Otherwise, the data goes to the second output.
+       ///
+       /// This op handles `Tensor`s and `IndexedSlices`.
        /// </summary>
-       /// <param name="data"></param>
-       /// <param name="pred"></param>
-       /// <param name="name"></param>
-       /// <returns></returns>
+       /// <param name="data">The tensor to be forwarded to the appropriate output.</param>
+       /// <param name="pred">A scalar that specifies which output port will receive data.</param>
+       /// <param name="name">A name for this operation (optional).</param>
+       /// <returns>
+       /// `(output_false, output_true)`: If `pred` is true, data will be forwarded to
+       /// `output_true`, otherwise it goes to `output_false`.
+       /// </returns>
        public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
        {
            data = ops.convert_to_tensor_or_indexed_slices(data, name: "data");
+
+           // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
+           // addresses the following scenario.
+           //
+           // Assume you execute Optimizer.apply_gradients() in a branch of a cond().
+           //
+           // 1. The update op is created inside a `with ops.colocate(var):` block.
+           //
+           // 2. Some tensor `data` is captured and a switch is created in a
+           //    `with ops.colocate_with(data):` block.
+           //
+           // with ops.colocate_with(var):
+           //   with ops.colocate_with(data):
+           //     op = ...
+           //
+           // var and data may be pinned to different devices, so we want the ops
+           // created within ops.colocate_with(data) to ignore the existing stack.
            ops.colocate_with(data, ignore_existing: true);
-
-           return @switch(data, pred, name: name);
+           {
+               if (data is Tensor)
+               {
+                   // TODO: ref_switch
+                   //if (data.dtype._is_ref_dtype)
+                   //    return control_flow_ops.ref_switch(data, pred, name = name);
+               }
+               return @switch(data, pred, name: name);
+           }
        }

        /// <summary>
@@ -483,6 +512,8 @@ namespace Tensorflow
        }

            throw new NotImplementedException("ZerosLikeOutsideLoop");
+           }
+       }
    }
}

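Switch (wrapped by _SwitchRefOrTensor above) and Merge are the two primitives that cond() is built from. A minimal sketch of that pattern using only helpers that appear in this commit (`data` and `pred` are assumed to be existing tensors; the branch subgraphs are elided):

    var (output_false, output_true) = control_flow_ops._SwitchRefOrTensor(data, pred);
    // At run time only the side selected by `pred` carries a value; each branch
    // subgraph is built on top of its respective switch output.
    var (merged, value_index) = gen_control_flow_ops.merge(new Tensor[] { output_false, output_true });
    // `merged` forwards whichever branch actually produced a value.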
+26 -4   src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs

@@ -13,13 +13,35 @@ namespace Tensorflow
            var _op = _op_def_lib._apply_op_helper("NoOp", name, null);

            return _op;
        }

+       /// <summary>
+       /// Forwards `data` to the output port determined by `pred`.
+       ///
+       /// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
+       /// the data goes to `output_false`.
+       ///
+       /// See also `RefSwitch` and `Merge`.
+       /// </summary>
+       /// <param name="data">A `Tensor`. The tensor to be forwarded to the appropriate output.</param>
+       /// <param name="pred">A `Tensor` of type `bool`.
+       /// A scalar that specifies which output port will receive data.
+       /// </param>
+       /// <param name="name">A name for the operation (optional).</param>
+       /// <returns>A tuple of `Tensor` objects (output_false, output_true).
+       ///
+       /// output_false: A `Tensor`. Has the same type as `data`.
+       /// output_true: A `Tensor`. Has the same type as `data`.
+       /// </returns>
        public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null)
        {
            var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred });

-           return (_op.outputs[0], _op.outputs[1]);
+           var _result = (_op.outputs[0], _op.outputs[1]);
+           var _inputs_flat = _op.inputs;
+           var _attrs = ("T", _op.get_attr("T"));
+           // TODO: missing original code
+           //_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name);
+           return _result;
        }

        public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)


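A brief usage note for @switch as documented above (`data` and `pred` are assumed to already exist as tensors in the current graph):

    var (output_false, output_true) = gen_control_flow_ops.@switch(data, pred);
    // Exactly one of the two outputs carries `data` at run time; the other is a
    // dead tensor that downstream Merge nodes are designed to ignore.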
+4 -0   test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs

@@ -84,6 +84,10 @@ namespace TensorFlowNET.UnitTest.ops_test
            control_flow_ops.cond(x < 10, true_fn, () => x);
            var op = g.get_operation_by_name("cond/myop");
+
+           tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta.txt", as_text: true);
+           tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
+
            self.assertIsNotNone(op);
            self.assertEqual(op.name, "cond/myop");
            self.assertEqual(op.type, "Identity");

