diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index 2d099292..22703fbc 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -41,11 +41,27 @@ namespace Tensorflow { var op_name = Marshal.PtrToStringAnsi(c_api.TF_OperationName(tf_oper)); return _get_operation_by_name_unsafe(op_name); - } - + } + + /// + /// Creates an `Operation` in this graph from the supplied TF_Operation. + /// + /// This method is like create_op() except the new Operation is constructed + /// using `c_op`. The returned Operation will have `c_op` as its _c_op + /// field.This is used to create Operation objects around TF_Operations created + /// indirectly by the C API(e.g.by TF_ImportGraphDef, TF_FinishWhile). + /// + /// This function does not call Operation._control_flow_post_processing or + /// Graph._control_dependencies_for_inputs (since the inputs may not be + /// available yet). The caller is responsible for calling these methods. + /// + /// a wrapped TF_Operation + /// (Optional.) If True, device functions will be executed + /// to compute the device property of the Operation. + /// An `Operation` object. public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) { - var ret = new Operation(c_op); + var ret = new Operation(c_op, this); _add_op(ret); var name_key = ret.name.ToLower(); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index f32a6fa7..5fd25faa 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -16,6 +16,7 @@ namespace Tensorflow.Operations /// The boolean tensor for the cond predicate /// private Tensor _pred; + public Tensor pred => _pred; /// @@ -23,11 +24,6 @@ namespace Tensorflow.Operations /// private int _branch; - /// - /// - /// - private List _values = new List(); - private Dictionary _external_values = new Dictionary(); /// @@ -66,72 +62,166 @@ namespace Tensorflow.Operations } /// - /// Add the subgraph defined by fn() to the graph. + /// Add `val` to the current context and its outer context recursively. /// - public (T, Tensor) BuildCondBranch(Func fn) + /// + public override Tensor AddValue(Tensor val) { - // Add the subgraph defined by fn() to the graph. - var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); - var original_result = fn(); - var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + Tensor result = null; + if (_values.Contains(val.name)) + { + // Use the real value if it comes from outer context. This is needed in + // particular for nested conds. + if (_external_values.ContainsKey(val.name)) + result = _external_values[val.name]; + else + result = val; + } + else + { + result = val; + _values.Add(val.name); + // TODO: _outer_context + if (_outer_context != null) + { + result = _outer_context.AddValue(val); + _values.Add(result.name); + _external_values[result.name] = result; + } + // TODO: how to do 'with' here?? + //with(ops.control_dependencies(null), ctrl => + //{ + var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred); + result = new[]{r0, r1}[_branch]; + if (_outer_context != null) + _outer_context.AddInnerOp(result.op); + //}); - //TODO: port this chunck of missing code: - /* - if len(post_summaries) > len(pre_summaries): - new_summaries = post_summaries[len(pre_summaries):] - summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access - summary_ref[:] = pre_summaries - with ops.control_dependencies(new_summaries): - if original_result is None: - return no_op(), None - else: - original_result = nest.map_structure(array_ops.identity, - original_result) - */ - if (original_result == null) - return (original_result, null); + result.op.graph.prevent_fetching(result.op); + result.op._set_control_flow_context(this); - switch (original_result) - { - case Tensor result: - return (original_result, _BuildCondTensor(new[] { result.op })); - case Operation[] results: - return (original_result, _BuildCondTensor(results)); - case float[] fv: + // Mark Switch output as seen by this context and any outer contexts, + // just like what we do for normal op outputs in _AddOpInternal() below. + IControlFlowContext ctxt = this; + while (ctxt != null) { - var result = ops.convert_to_tensor(fv[0]); - return (original_result, result ); + ctxt.values.Add(result.name); + ctxt = ctxt.outer_context; } - default: - return (original_result, null); + _external_values[val.name] = result; } - } - - public (T[], Tensor[]) BuildCondBranch(Func fn) + return result; + } + + /// + /// Add the subgraph defined by fn() to the graph. + /// + public (T, Tensor) BuildCondBranch(Func fn) + { + // Add the subgraph defined by fn() to the graph. + var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var original_result = fn(); + var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + + //TODO: port this chunck of missing code: + /* + if len(post_summaries) > len(pre_summaries): + new_summaries = post_summaries[len(pre_summaries):] + summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access + summary_ref[:] = pre_summaries + with ops.control_dependencies(new_summaries): + if original_result is None: + return no_op(), None + else: + original_result = nest.map_structure(array_ops.identity, + original_result) + */ + if (original_result == null) + return (original_result, null); + + switch (original_result) + { + case Tensor result: + return (original_result, _BuildCondTensor(result)); + case Operation op: + return (original_result, _BuildCondTensor(op)); + case float[] fv: + { + var result = ops.convert_to_tensor(fv[0]); + return (original_result, _BuildCondTensor(result)); + } + default: + return (original_result, null); + } + } + + public (T[], Tensor[]) BuildCondBranch(Func fn) + { + // Add the subgraph defined by fn() to the graph. + var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var original_result = fn(); + var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + + switch (original_result) + { + case Tensor[] results: + return (original_result, results.Select(_BuildCondTensor).ToArray()); + case Operation[] results: + return (original_result, results.Select(_BuildCondTensor).ToArray()); + case float[] fv: + var result = ops.convert_to_tensor(fv[0]); + return (original_result, new Tensor[] { result }); + default: + return (original_result, new Tensor[0]); + } + } + + private Tensor _BuildCondTensor(ITensorOrOperation v) + { + switch (v) + { + case Operation op: + // Use pivot as the proxy for this op. + return control_flow_ops.with_dependencies(new Operation[] { op }, _pivot); + case Tensor t: + return _ProcessOutputTensor(t); + default: + return _ProcessOutputTensor(ops.convert_to_tensor(v)); + + } + } + + /// + /// Process an output tensor of a conditional branch. + /// + private Tensor _ProcessOutputTensor(Tensor val) { - // Add the subgraph defined by fn() to the graph. - var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); - var original_result = fn(); - var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); - - switch (original_result) + var real_val = val; + if (!_values.Contains(val.name)) { - case Tensor[] results: - return (original_result, new Tensor[] { _BuildCondTensor(results.Select(t=>t.op).ToArray())}); - case Operation[] results: - return (original_result, new Tensor[] { _BuildCondTensor (results) }); - case float[] fv: - var result = ops.convert_to_tensor(fv[0]); - return (original_result, new Tensor[] { result }); - default: - return (original_result, new Tensor[0]); + // Handle the special case of lambda: x + _values.Add(val.name); + if (_outer_context != null) + { + real_val = _outer_context.AddValue(val); + _values.Add(real_val.name); + _external_values[real_val.name] = real_val; + } } + else + { + Tensor external_val = null; + if (_external_values.ContainsKey(val.name)) + external_val = _external_values[val.name]; + if (external_val != null) + real_val = external_val; + } + return real_val; } - - private Tensor _BuildCondTensor(Operation[] v) + + public override void AddInnerOp(Operation resultOp) { - // Use pivot as the proxy for this op. - return control_flow_ops.with_dependencies(v, _pivot); + throw new NotImplementedException(); } - } -} + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 9d039d58..b7cc911d 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow.Operations @@ -29,6 +30,8 @@ namespace Tensorflow.Operations protected Tensor _pivot; protected Stack _context_stack; + protected IControlFlowContext _outer_context; + public ControlFlowContext() { _context_stack = new Stack(); @@ -69,23 +72,114 @@ namespace Tensorflow.Operations graph._set_control_flow_context(last_context); } + /// + /// Add `op` to the current context. + /// public void AddOp(Operation op) { _AddOpInternal(op); } + public IControlFlowContext outer_context { get { return _outer_context; } } + public HashSet values => _values; + public virtual Tensor AddValue(Tensor val) + { + // to be overridden + return null; + } + + public virtual void AddInnerOp(Operation resultOp) + { + // to be overridden + } + + protected HashSet _values = new HashSet(); + + /// + /// Add `op` to the current context. + /// protected virtual void _AddOpInternal(Operation op) { - if(op.inputs.Length == 0) + if (op.inputs.Length == 0) { + //If we're in a while loop, remove any control inputs from outside the + // loop. _RemoveExternalControlEdges(op); - op._add_control_input(_pivot.op); + if (!op.control_inputs.Any(input_op => OpInContext(input_op))) + op._add_control_input(_pivot.op); } else { + // Make each input to 'op' available in this CondContext. If an input is + // already part of this context there's nothing to do, but if it's + // external, AddValue() will handle adding the appropriate Switch node and + // other bookkeeping. + for (int index = 0; index < op.inputs.Length; index++) + { + var x = op.inputs[index]; + Tensor real_x = null; + if (op.type == "Merge" && x.op.type == "NextIteration") + { + //# Edge case: if we're importing a while loop inside this CondContext, + //# AddValue() will not correctly handle the NextIteration inputs to + //# Merge node. The problem is that the NextIteration should also be + //# part of this context, but if we're importing it won't have been + //# processed and added to the context yet, so AddValue() will try to + //# add a Switch which results in an invalid graph. Instead, we use the + //# NextIteration input as-is here, and it will eventually be added to + //# the context via AddOp(). + real_x = x; + } + else + { + real_x = AddValue(x); + } + if (real_x != x) + op._update_input(index, real_x); + } + // Remove any external control dependency on this op. + _RemoveExternalControlEdges(op); + // TODO: implement below code dependencies + //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") + // op._add_control_input(_pivot.op); + } + + // Mark op's outputs as seen by this context and any outer contexts. + var output_names = op.outputs.Select(x => x.name).ToArray(); + IControlFlowContext ctxt = this; + while (ctxt != null) + { + foreach(var name in output_names) + ctxt.values.Add(name); + ctxt = ctxt.outer_context; + } + + if (_outer_context != null || !control_flow_ops.IsLoopExit(op)) + op.graph.prevent_fetching(op); + if (_outer_context != null) + _outer_context.AddInnerOp(op); + } + + private bool OpInContext(Operation op) + { + return IsContainingContext(op._get_control_flow_context(), this); + } + + /// + /// Returns true if `maybe_containing_ctxt` is or contains `ctxt`. + /// + public static bool IsContainingContext(IControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) + { + while (ctxt != maybe_containing_ctxt) + { + if (ctxt == null) + return false; + ctxt = ctxt.outer_context; } - } + return true; + } + protected virtual void _RemoveExternalControlEdges(Operation op) { diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs index 6bd8c6e2..5bc34965 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs @@ -7,5 +7,9 @@ namespace Tensorflow public interface IControlFlowContext { void AddOp(Operation op); + IControlFlowContext outer_context { get; } + HashSet values { get; } + Tensor AddValue(Tensor val); + void AddInnerOp(Operation resultOp); } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 812ca9fb..aaf2937c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -7,16 +7,20 @@ namespace Tensorflow { public partial class Operation { - private IControlFlowContext _control_flow_context; - + private IControlFlowContext _control_flow_context; + /// /// Add this op to its control flow context. + /// + /// This may add new ops and change this op's inputs. self.inputs must be + /// available before calling this method. /// public void _control_flow_post_processing() { foreach(var input_tensor in inputs) { - + //TODO: implement below code dependency + //control_flow_util.CheckInputFromValidContext(this, input_tensor.op); } if (_control_flow_context != null) diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index e4bccea1..245e38b5 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -62,16 +62,22 @@ namespace Tensorflow } } - public Operation(IntPtr handle) + public Operation(IntPtr handle, Graph g=null) { if (handle == IntPtr.Zero) return; _handle = handle; - _graph = ops.get_default_graph(); + _graph = g ?? ops.get_default_graph(); _outputs = new Tensor[NumOutputs]; for (int i = 0; i < NumOutputs; i++) _outputs[i] = new Tensor(this, i, OutputType(i)); + + // Dict mapping op name to file and line information for op colocation + // context managers. + _control_flow_context = graph._get_control_flow_context(); + + // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. } public Operation(Graph g, string opType, string oper_name) @@ -81,6 +87,10 @@ namespace Tensorflow _operDesc = c_api.TF_NewOperation(g, opType, oper_name); c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); _handle = c_api.TF_FinishOperation(_operDesc, status); + + // Dict mapping op name to file and line information for op colocation + // context managers. + _control_flow_context = graph._get_control_flow_context(); } /// @@ -258,6 +268,23 @@ namespace Tensorflow } return base.Equals(obj); + } + + /// + /// Update the input to this operation at the given index. + /// + /// NOTE: This is for TF internal use only.Please don't use it. + /// + /// the index of the input to update. + /// the Tensor to be used as the input at the given index. + public void _update_input(int index, Tensor tensor) + { + throw new NotImplementedException("_update_input"); + // TODO: implement below code dependencies + //_assert_same_graph( tensor); + //// Reset cached inputs. + //_inputs_val = null; + //c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index)); } } } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index e0f38c95..aebcfaef 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -111,7 +111,7 @@ namespace Tensorflow return loop_state; } - private static bool IsLoopExit(Operation op) + public static bool IsLoopExit(Operation op) { return op.OpType == "Exit" || op.OpType == "RefExit"; } @@ -193,20 +193,49 @@ namespace Tensorflow return gen_array_ops.identity(data, name: name); } - /// - /// Forwards `data` to an output determined by `pred`. - /// - /// - /// - /// - /// + /// + /// Forwards `data` to an output determined by `pred`. + /// If `pred` is false, the `data` input is forwarded to the first output. + /// Otherwise, the data goes to the second output. + /// + /// This op handles `Tensor`s and `IndexedSlices`. + /// + /// The tensor to be forwarded to the appropriate output. + /// A scalar that specifies which output port will receive data. + /// A name for this operation (optional). + /// + /// `(output_false, output_true)`: If `pred` is true, data will be forwarded to + /// `output_true`, otherwise it goes to `output_false`. + /// public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") { - data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); - + data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); + // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below + // addresses the following scenario. + // + // Assume you execute Optimizer.apply_gradients() in a branch of a cond(). + // + // 1. The update op is created inside a `with ops.colocate(var):` block + // + // 2. Some tensor `data` is captured and a switch is created in a + // `with ops.colocate_with(data):` block. + // + // with ops.colocate_with(var): + // with ops.colocate_with(data): + // op = ... + // + // var and data may be pinned to different devices, so we want to ops + // created within ops.colocate_with(data) to ignore the existing stack. ops.colocate_with(data, ignore_existing: true); - - return @switch(data, pred, name: name); + { + if (data is Tensor) + { + // TODO: ref_switch + //if (data.dtype._is_ref_dtype) + // return control_flow_ops.ref_switch(data, pred, name = name); + } + return @switch(data, pred, name: name); + } } /// @@ -483,6 +512,8 @@ namespace Tensorflow } throw new NotImplementedException("ZerosLikeOutsideLoop"); - } + } + + } } diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs index 21447c57..21daf844 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs @@ -13,13 +13,35 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("NoOp", name, null); return _op; - } - + } + + /// + /// Forwards `data` to the output port determined by `pred`. + /// + /// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise, + /// the data goes to `output_false`. + /// + /// See also `RefSwitch` and `Merge`. + /// + /// A `Tensor`. The tensor to be forwarded to the appropriate output. + /// A `Tensor` of type `bool`. + /// A scalar that specifies which output port will receive data. + /// + /// A name for the operation (optional). + /// A tuple of `Tensor` objects (output_false, output_true). + /// + /// output_false: A `Tensor`. Has the same type as `data`. + /// output_true: A `Tensor`. Has the same type as `data`. + /// public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null) { var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); - - return (_op.outputs[0], _op.outputs[1]); + var _result = (_op.outputs[0], _op.outputs[1]); + var _inputs_flat = _op.inputs; + var _attrs = ("T", _op.get_attr("T")); + // TODO: missing original code + //_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name); + return _result; } public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) diff --git a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs index 048d9bac..bace1dde 100644 --- a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs @@ -84,6 +84,10 @@ namespace TensorFlowNET.UnitTest.ops_test control_flow_ops.cond(x < 10, true_fn, () => x); var op = g.get_operation_by_name("cond/myop"); + + tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta.txt", as_text:true); + tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); + self.assertIsNotNone(op); self.assertEqual(op.name, "cond/myop"); self.assertEqual(op.type, "Identity");