diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
index 2d099292..22703fbc 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
@@ -41,11 +41,27 @@ namespace Tensorflow
{
var op_name = Marshal.PtrToStringAnsi(c_api.TF_OperationName(tf_oper));
return _get_operation_by_name_unsafe(op_name);
- }
-
+ }
+
+ ///
+ /// Creates an `Operation` in this graph from the supplied TF_Operation.
+ ///
+ /// This method is like create_op() except the new Operation is constructed
+ /// using `c_op`. The returned Operation will have `c_op` as its _c_op
+ /// field. This is used to create Operation objects around TF_Operations created
+ /// indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
+ ///
+ /// This function does not call Operation._control_flow_post_processing or
+ /// Graph._control_dependencies_for_inputs (since the inputs may not be
+ /// available yet). The caller is responsible for calling these methods.
+ ///
+ /// a wrapped TF_Operation
+ /// (Optional.) If True, device functions will be executed
+ /// to compute the device property of the Operation.
+ /// An `Operation` object.
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true)
{
- var ret = new Operation(c_op);
+ var ret = new Operation(c_op, this);
_add_op(ret);
var name_key = ret.name.ToLower();
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
index f32a6fa7..5fd25faa 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
@@ -16,6 +16,7 @@ namespace Tensorflow.Operations
/// The boolean tensor for the cond predicate
///
private Tensor _pred;
+
public Tensor pred => _pred;
///
@@ -23,11 +24,6 @@ namespace Tensorflow.Operations
///
private int _branch;
- ///
- ///
- ///
- private List _values = new List();
-
private Dictionary _external_values = new Dictionary();
///
@@ -66,72 +62,166 @@ namespace Tensorflow.Operations
}
///
- /// Add the subgraph defined by fn() to the graph.
+ /// Add `val` to the current context and its outer context recursively.
///
- public (T, Tensor) BuildCondBranch(Func fn)
+ ///
+ public override Tensor AddValue(Tensor val)
{
- // Add the subgraph defined by fn() to the graph.
- var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
- var original_result = fn();
- var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+ Tensor result = null;
+ if (_values.Contains(val.name))
+ {
+ // Use the real value if it comes from outer context. This is needed in
+ // particular for nested conds.
+ if (_external_values.ContainsKey(val.name))
+ result = _external_values[val.name];
+ else
+ result = val;
+ }
+ else
+ {
+ result = val;
+ _values.Add(val.name);
+ // TODO: _outer_context
+ if (_outer_context != null)
+ {
+ result = _outer_context.AddValue(val);
+ _values.Add(result.name);
+ _external_values[result.name] = result;
+ }
+ // TODO: how to do 'with' here??
+ //with(ops.control_dependencies(null), ctrl =>
+ //{
+ var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred);
+ result = new[]{r0, r1}[_branch];
+ if (_outer_context != null)
+ _outer_context.AddInnerOp(result.op);
+ //});
- //TODO: port this chunck of missing code:
- /*
- if len(post_summaries) > len(pre_summaries):
- new_summaries = post_summaries[len(pre_summaries):]
- summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
- summary_ref[:] = pre_summaries
- with ops.control_dependencies(new_summaries):
- if original_result is None:
- return no_op(), None
- else:
- original_result = nest.map_structure(array_ops.identity,
- original_result)
- */
- if (original_result == null)
- return (original_result, null);
+ result.op.graph.prevent_fetching(result.op);
+ result.op._set_control_flow_context(this);
- switch (original_result)
- {
- case Tensor result:
- return (original_result, _BuildCondTensor(new[] { result.op }));
- case Operation[] results:
- return (original_result, _BuildCondTensor(results));
- case float[] fv:
+ // Mark Switch output as seen by this context and any outer contexts,
+ // just like what we do for normal op outputs in _AddOpInternal() below.
+ IControlFlowContext ctxt = this;
+ while (ctxt != null)
{
- var result = ops.convert_to_tensor(fv[0]);
- return (original_result, result );
+ ctxt.values.Add(result.name);
+ ctxt = ctxt.outer_context;
}
- default:
- return (original_result, null);
+ _external_values[val.name] = result;
}
- }
-
- public (T[], Tensor[]) BuildCondBranch(Func fn)
+ return result;
+ }
+
+ ///
+ /// Add the subgraph defined by fn() to the graph.
+ ///
+ public (T, Tensor) BuildCondBranch(Func fn)
+ {
+ // Add the subgraph defined by fn() to the graph.
+ var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+ var original_result = fn();
+ var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+
+ //TODO: port this chunk of missing code:
+ /*
+ if len(post_summaries) > len(pre_summaries):
+ new_summaries = post_summaries[len(pre_summaries):]
+ summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
+ summary_ref[:] = pre_summaries
+ with ops.control_dependencies(new_summaries):
+ if original_result is None:
+ return no_op(), None
+ else:
+ original_result = nest.map_structure(array_ops.identity,
+ original_result)
+ */
+ if (original_result == null)
+ return (original_result, null);
+
+ switch (original_result)
+ {
+ case Tensor result:
+ return (original_result, _BuildCondTensor(result));
+ case Operation op:
+ return (original_result, _BuildCondTensor(op));
+ case float[] fv:
+ {
+ var result = ops.convert_to_tensor(fv[0]);
+ return (original_result, _BuildCondTensor(result));
+ }
+ default:
+ return (original_result, null);
+ }
+ }
+
+ public (T[], Tensor[]) BuildCondBranch(Func fn)
+ {
+ // Add the subgraph defined by fn() to the graph.
+ var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+ var original_result = fn();
+ var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+
+ switch (original_result)
+ {
+ case Tensor[] results:
+ return (original_result, results.Select(_BuildCondTensor).ToArray());
+ case Operation[] results:
+ return (original_result, results.Select(_BuildCondTensor).ToArray());
+ case float[] fv:
+ var result = ops.convert_to_tensor(fv[0]);
+ return (original_result, new Tensor[] { result });
+ default:
+ return (original_result, new Tensor[0]);
+ }
+ }
+
+ private Tensor _BuildCondTensor(ITensorOrOperation v)
+ {
+ switch (v)
+ {
+ case Operation op:
+ // Use pivot as the proxy for this op.
+ return control_flow_ops.with_dependencies(new Operation[] { op }, _pivot);
+ case Tensor t:
+ return _ProcessOutputTensor(t);
+ default:
+ return _ProcessOutputTensor(ops.convert_to_tensor(v));
+
+ }
+ }
+
+ ///
+ /// Process an output tensor of a conditional branch.
+ ///
+ private Tensor _ProcessOutputTensor(Tensor val)
{
- // Add the subgraph defined by fn() to the graph.
- var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
- var original_result = fn();
- var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
-
- switch (original_result)
+ var real_val = val;
+ if (!_values.Contains(val.name))
{
- case Tensor[] results:
- return (original_result, new Tensor[] { _BuildCondTensor(results.Select(t=>t.op).ToArray())});
- case Operation[] results:
- return (original_result, new Tensor[] { _BuildCondTensor (results) });
- case float[] fv:
- var result = ops.convert_to_tensor(fv[0]);
- return (original_result, new Tensor[] { result });
- default:
- return (original_result, new Tensor[0]);
+ // Handle the special case of lambda: x
+ _values.Add(val.name);
+ if (_outer_context != null)
+ {
+ real_val = _outer_context.AddValue(val);
+ _values.Add(real_val.name);
+ _external_values[real_val.name] = real_val;
+ }
}
+ else
+ {
+ Tensor external_val = null;
+ if (_external_values.ContainsKey(val.name))
+ external_val = _external_values[val.name];
+ if (external_val != null)
+ real_val = external_val;
+ }
+ return real_val;
}
-
- private Tensor _BuildCondTensor(Operation[] v)
+
+ public override void AddInnerOp(Operation resultOp)
{
- // Use pivot as the proxy for this op.
- return control_flow_ops.with_dependencies(v, _pivot);
+ _outer_context?.AddInnerOp(resultOp); // forward the inner-op notification to the enclosing context, if any, instead of throwing
}
- }
-}
+ }
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
index 9d039d58..b7cc911d 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
namespace Tensorflow.Operations
@@ -29,6 +30,8 @@ namespace Tensorflow.Operations
protected Tensor _pivot;
protected Stack _context_stack;
+ protected IControlFlowContext _outer_context;
+
public ControlFlowContext()
{
_context_stack = new Stack();
@@ -69,23 +72,114 @@ namespace Tensorflow.Operations
graph._set_control_flow_context(last_context);
}
+ ///
+ /// Add `op` to the current context.
+ ///
public void AddOp(Operation op)
{
_AddOpInternal(op);
}
+ public IControlFlowContext outer_context { get { return _outer_context; } }
+ public HashSet values => _values;
+ public virtual Tensor AddValue(Tensor val)
+ {
+ // to be overridden
+ return null;
+ }
+
+ public virtual void AddInnerOp(Operation resultOp)
+ {
+ // to be overridden
+ }
+
+ protected HashSet _values = new HashSet();
+
+ ///
+ /// Add `op` to the current context.
+ ///
protected virtual void _AddOpInternal(Operation op)
{
- if(op.inputs.Length == 0)
+ if (op.inputs.Length == 0)
{
+ //If we're in a while loop, remove any control inputs from outside the
+ // loop.
_RemoveExternalControlEdges(op);
- op._add_control_input(_pivot.op);
+ if (!op.control_inputs.Any(input_op => OpInContext(input_op)))
+ op._add_control_input(_pivot.op);
}
else
{
+ // Make each input to 'op' available in this CondContext. If an input is
+ // already part of this context there's nothing to do, but if it's
+ // external, AddValue() will handle adding the appropriate Switch node and
+ // other bookkeeping.
+ for (int index = 0; index < op.inputs.Length; index++)
+ {
+ var x = op.inputs[index];
+ Tensor real_x = null;
+ if (op.type == "Merge" && x.op.type == "NextIteration")
+ {
+ //# Edge case: if we're importing a while loop inside this CondContext,
+ //# AddValue() will not correctly handle the NextIteration inputs to
+ //# Merge node. The problem is that the NextIteration should also be
+ //# part of this context, but if we're importing it won't have been
+ //# processed and added to the context yet, so AddValue() will try to
+ //# add a Switch which results in an invalid graph. Instead, we use the
+ //# NextIteration input as-is here, and it will eventually be added to
+ //# the context via AddOp().
+ real_x = x;
+ }
+ else
+ {
+ real_x = AddValue(x);
+ }
+ if (real_x != x)
+ op._update_input(index, real_x);
+ }
+ // Remove any external control dependency on this op.
+ _RemoveExternalControlEdges(op);
+ // TODO: implement below code dependencies
+ //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient")
+ // op._add_control_input(_pivot.op);
+ }
+
+ // Mark op's outputs as seen by this context and any outer contexts.
+ var output_names = op.outputs.Select(x => x.name).ToArray();
+ IControlFlowContext ctxt = this;
+ while (ctxt != null)
+ {
+ foreach(var name in output_names)
+ ctxt.values.Add(name);
+ ctxt = ctxt.outer_context;
+ }
+
+ if (_outer_context != null || !control_flow_ops.IsLoopExit(op))
+ op.graph.prevent_fetching(op);
+ if (_outer_context != null)
+ _outer_context.AddInnerOp(op);
+ }
+
+ private bool OpInContext(Operation op)
+ {
+ return IsContainingContext(op._get_control_flow_context(), this);
+ }
+
+ ///
+ /// Returns true if `maybe_containing_ctxt` is or contains `ctxt`.
+ ///
+ public static bool IsContainingContext(IControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt)
+ {
+ while (ctxt != maybe_containing_ctxt)
+ {
+ if (ctxt == null)
+ return false;
+ ctxt = ctxt.outer_context;
}
- }
+ return true;
+ }
+
protected virtual void _RemoveExternalControlEdges(Operation op)
{
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
index 6bd8c6e2..5bc34965 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
@@ -7,5 +7,9 @@ namespace Tensorflow
public interface IControlFlowContext
{
void AddOp(Operation op);
+ IControlFlowContext outer_context { get; }
+ HashSet values { get; }
+ Tensor AddValue(Tensor val);
+ void AddInnerOp(Operation resultOp);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index 812ca9fb..aaf2937c 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -7,16 +7,20 @@ namespace Tensorflow
{
public partial class Operation
{
- private IControlFlowContext _control_flow_context;
-
+ private IControlFlowContext _control_flow_context;
+
///
/// Add this op to its control flow context.
+ ///
+ /// This may add new ops and change this op's inputs. self.inputs must be
+ /// available before calling this method.
///
public void _control_flow_post_processing()
{
foreach(var input_tensor in inputs)
{
-
+ //TODO: implement below code dependency
+ //control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
}
if (_control_flow_context != null)
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index e4bccea1..245e38b5 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -62,16 +62,22 @@ namespace Tensorflow
}
}
- public Operation(IntPtr handle)
+ public Operation(IntPtr handle, Graph g=null)
{
if (handle == IntPtr.Zero)
return;
_handle = handle;
- _graph = ops.get_default_graph();
+ _graph = g ?? ops.get_default_graph();
_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
+
+ // Remember the control flow context reported by the graph at the time this
+ // op is constructed.
+ _control_flow_context = graph._get_control_flow_context();
+
+ // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
}
public Operation(Graph g, string opType, string oper_name)
@@ -81,6 +87,10 @@ namespace Tensorflow
_operDesc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
_handle = c_api.TF_FinishOperation(_operDesc, status);
+
+ // Remember the control flow context reported by the graph at the time this
+ // op is constructed.
+ _control_flow_context = graph._get_control_flow_context();
}
///
@@ -258,6 +268,23 @@ namespace Tensorflow
}
return base.Equals(obj);
+ }
+
+ ///
+ /// Update the input to this operation at the given index.
+ ///
+ /// NOTE: This is for TF internal use only. Please don't use it.
+ ///
+ /// the index of the input to update.
+ /// the Tensor to be used as the input at the given index.
+ public void _update_input(int index, Tensor tensor)
+ {
+ throw new NotImplementedException("_update_input");
+ // TODO: implement below code dependencies
+ //_assert_same_graph( tensor);
+ //// Reset cached inputs.
+ //_inputs_val = null;
+ //c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index));
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index e0f38c95..aebcfaef 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -111,7 +111,7 @@ namespace Tensorflow
return loop_state;
}
- private static bool IsLoopExit(Operation op)
+ public static bool IsLoopExit(Operation op)
{
return op.OpType == "Exit" || op.OpType == "RefExit";
}
@@ -193,20 +193,49 @@ namespace Tensorflow
return gen_array_ops.identity(data, name: name);
}
- ///
- /// Forwards `data` to an output determined by `pred`.
- ///
- ///
- ///
- ///
- ///
+ ///
+ /// Forwards `data` to an output determined by `pred`.
+ /// If `pred` is false, the `data` input is forwarded to the first output.
+ /// Otherwise, the data goes to the second output.
+ ///
+ /// This op handles `Tensor`s and `IndexedSlices`.
+ ///
+ /// The tensor to be forwarded to the appropriate output.
+ /// A scalar that specifies which output port will receive data.
+ /// A name for this operation (optional).
+ ///
+ /// `(output_false, output_true)`: If `pred` is true, data will be forwarded to
+ /// `output_true`, otherwise it goes to `output_false`.
+ ///
public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
{
- data = ops.convert_to_tensor_or_indexed_slices(data, name: "data");
-
+ data = ops.convert_to_tensor_or_indexed_slices(data, name: "data");
+ // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
+ // addresses the following scenario.
+ //
+ // Assume you execute Optimizer.apply_gradients() in a branch of a cond().
+ //
+ // 1. The update op is created inside a `with ops.colocate_with(var):` block
+ //
+ // 2. Some tensor `data` is captured and a switch is created in a
+ // `with ops.colocate_with(data):` block.
+ //
+ // with ops.colocate_with(var):
+ // with ops.colocate_with(data):
+ // op = ...
+ //
+ // var and data may be pinned to different devices, so we want ops
+ // created within ops.colocate_with(data) to ignore the existing stack.
ops.colocate_with(data, ignore_existing: true);
-
- return @switch(data, pred, name: name);
+ {
+ if (data is Tensor)
+ {
+ // TODO: ref_switch
+ //if (data.dtype._is_ref_dtype)
+ // return control_flow_ops.ref_switch(data, pred, name = name);
+ }
+ return @switch(data, pred, name: name);
+ }
}
///
@@ -483,6 +512,8 @@ namespace Tensorflow
}
throw new NotImplementedException("ZerosLikeOutsideLoop");
- }
+ }
+
+
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
index 21447c57..21daf844 100644
--- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
@@ -13,13 +13,35 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("NoOp", name, null);
return _op;
- }
-
+ }
+
+ ///
+ /// Forwards `data` to the output port determined by `pred`.
+ ///
+ /// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
+ /// the data goes to `output_false`.
+ ///
+ /// See also `RefSwitch` and `Merge`.
+ ///
+ /// A `Tensor`. The tensor to be forwarded to the appropriate output.
+ /// A `Tensor` of type `bool`.
+ /// A scalar that specifies which output port will receive data.
+ ///
+ /// A name for the operation (optional).
+ /// A tuple of `Tensor` objects (output_false, output_true).
+ ///
+ /// output_false: A `Tensor`. Has the same type as `data`.
+ /// output_true: A `Tensor`. Has the same type as `data`.
+ ///
public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred });
-
- return (_op.outputs[0], _op.outputs[1]);
+ var _result = (_op.outputs[0], _op.outputs[1]);
+ var _inputs_flat = _op.inputs;
+ var _attrs = ("T", _op.get_attr("T"));
+ // TODO: missing original code
+ //_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name);
+ return _result;
}
public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)
diff --git a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs
index 048d9bac..bace1dde 100644
--- a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs
+++ b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs
@@ -84,6 +84,10 @@ namespace TensorFlowNET.UnitTest.ops_test
control_flow_ops.cond(x < 10, true_fn, () => x);
var op = g.get_operation_by_name("cond/myop");
+
+ tf.train.export_meta_graph(System.IO.Path.Combine(System.IO.Path.GetTempPath(), "sharp.meta.txt"), as_text:true); // debug dump goes to the temp dir, not a machine-specific path
+ tf.train.export_meta_graph(System.IO.Path.Combine(System.IO.Path.GetTempPath(), "sharp.meta"), as_text: false); // binary form of the same meta graph
+
self.assertIsNotNone(op);
self.assertEqual(op.name, "cond/myop");
self.assertEqual(op.type, "Identity");