From 8f42762f1c65990f56b56cf0f0742a888ec20a08 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 27 Jan 2019 08:38:21 -0600 Subject: [PATCH] Add control_inputs in Operation #141 --- .../Graphs/Graph.Control.cs | 105 ++++++++++++++++++ src/TensorFlowNET.Core/Graphs/Graph.cs | 12 +- .../Graphs/_ControlDependenciesController.cs | 80 +++++++++++++ .../Operations/Operation.Input.cs | 8 ++ .../Operations/Operation.cs | 42 ++++++- .../Operations/c_api.ops.cs | 8 ++ .../Operations/control_flow_ops.py.cs | 24 ++-- src/TensorFlowNET.Core/Python.cs | 25 +++++ .../Variables/RefVariable.Implicit.cs | 5 + .../Variables/RefVariable.cs | 5 + src/TensorFlowNET.Core/ops.py.cs | 33 +++++- test/TensorFlowNET.UnitTest/VariableTest.cs | 7 +- 12 files changed, 327 insertions(+), 27 deletions(-) create mode 100644 src/TensorFlowNET.Core/Graphs/Graph.Control.cs create mode 100644 src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs new file mode 100644 index 00000000..af92e905 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -0,0 +1,105 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Eager; + +namespace Tensorflow +{ + public partial class Graph + { + public Context _control_flow_context; + + private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>(); + public Queue<_ControlDependenciesController> _control_dependencies_stack + { + get + { + return _graph_control_dependencies_stack; + } + set + { + _graph_control_dependencies_stack = value; + } + } + + /// + /// For an op that takes `input_ops` as inputs, compute control inputs. + /// + /// The data input ops for an op to be created. + /// A list of control inputs for the op to be created. + private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) + { + Operation[] ret = new Operation[0]; + + foreach(var controller in _control_dependencies_stack) + { + bool dominated = false; + // If any of the input_ops already depends on the inputs from controller, + // we say that the new op is dominated (by that input), and we therefore + // do not need to add control dependencies for this controller's inputs. + foreach(var op in input_ops) + { + if (controller.op_in_group(op)) + { + dominated = true; + break; + } + } + + if (!dominated) + ret = controller.control_inputs.Where(x => !input_ops.Contains(x)).ToArray(); + } + + return ret; + } + + public _ControlDependenciesController control_dependencies(Operation[] control_inputs) + { + if (control_inputs == null) + return new _ControlDependenciesController(this, null); + + var control_ops = new List(); + foreach (var c in control_inputs) + { + control_ops.Add(c); + } + + return new _ControlDependenciesController(this, control_ops); + } + + /// + /// Returns the current control flow context. + /// + /// A context object. + public Context _get_control_flow_context() + { + return _control_flow_context; + } + + /// + /// Sets the current control flow context. + /// + /// a context object. + public void _set_control_flow_context(Context ctx) + { + _control_flow_context = ctx; + } + + public void _push_control_dependencies_controller(_ControlDependenciesController controller) + { + _control_dependencies_stack.Enqueue(controller); + } + + public void _pop_control_dependencies_controller(_ControlDependenciesController controller) + { + _control_dependencies_stack.Dequeue(); + } + + public void _record_op_seen_by_control_dependencies(Operation op) + { + foreach (var controller in _control_dependencies_stack) + controller.add_op(op); + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 963a1a40..3d46196d 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -142,19 +142,9 @@ namespace Tensorflow return op; } - /// - /// For an op that takes `input_ops` as inputs, compute control inputs. - /// - /// The data input ops for an op to be created. - /// A list of control inputs for the op to be created. - private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) - { - return new Operation[0]; - } - private void _create_op_helper(Operation op, bool compute_device = true) { - + _record_op_seen_by_control_dependencies(op); } public void _add_op(Operation op) diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs new file mode 100644 index 00000000..08302cc1 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs @@ -0,0 +1,80 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Eager; + +namespace Tensorflow +{ + /// + /// Context manager for `control_dependencies()` + /// + public class _ControlDependenciesController : IPython + { + private Graph _graph; + private List _control_inputs_val; + private List _seen_nodes; + private Queue<_ControlDependenciesController> _old_stack; + private bool _new_stack; + private Context _old_control_flow_context; + + public Operation[] control_inputs => _control_inputs_val.ToArray(); + + public _ControlDependenciesController(Graph graph, List control_inputs) + { + _graph = graph; + if (control_inputs == null) + { + _control_inputs_val = new List(); + _new_stack = true; + } + else + { + _control_inputs_val = control_inputs; + _new_stack = false; + } + + _seen_nodes = new List(); + } + + public void add_op(Operation op) + { + _seen_nodes.Add(op); + } + + public bool op_in_group(Operation op) + { + return _seen_nodes.Contains(op); + } + + public void __enter__() + { + if (_new_stack) + { + // Clear the control_dependencies graph. + _old_stack = _graph._control_dependencies_stack; + _graph._control_dependencies_stack = new Queue<_ControlDependenciesController>(); + + // Clear the control_flow_context too. + _old_control_flow_context = _graph._get_control_flow_context(); + _graph._set_control_flow_context(null); + } + + _graph._push_control_dependencies_controller(this); + } + + public void __exit__() + { + _graph._pop_control_dependencies_controller(this); + if (_new_stack) + { + _graph._control_dependencies_stack = _old_stack; + _graph._set_control_flow_context(_old_control_flow_context); + } + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 55987262..5db34ce9 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -39,6 +39,14 @@ namespace Tensorflow public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); + public Operation[] control_inputs + { + get + { + return GetControlInputs(); + } + } + public unsafe Operation[] GetControlInputs() { var control_inputs = new Operation[NumControlInputs]; diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index ef5d9813..dfc89e67 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -49,15 +49,53 @@ namespace Tensorflow c_api.TF_FinishOperation(desc, status); } - public Operation(NodeDef node_def, Graph g, List inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) + /// + /// Creates an `Operation`. + /// + /// `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`. + /// `Graph`. The parent graph. + /// list of `Tensor` objects. The inputs to this `Operation`. + /// list of `DType` objects. + /// + /// list of operations or tensors from which to have a + /// control dependency. + /// + /// + /// List of `DType` objects representing the + /// types of the tensors accepted by the `Operation`. By default + /// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect + /// reference-typed inputs must specify these explicitly. + /// + /// + /// + public Operation(NodeDef node_def, Graph g, List inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) { Graph = g; + // Build the list of control inputs. + var control_input_ops = new List(); + if(control_inputs != null) + { + foreach(var c in control_inputs) + { + switch (c) + { + case Operation c1: + control_input_ops.Add(c1); + break; + default: + throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); + } + } + } + + // This will be set by self.inputs. + _id_value = Graph._next_id(); if(op_def == null) op_def = g.GetOpDef(node_def.Op); - _handle = ops._create_c_op(g, node_def, inputs); + _handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray()); output_types = new TF_DataType[NumOutputs]; diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 46317bac..0283e013 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -34,6 +34,14 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_AddInput(IntPtr desc, TF_Output input); + /// + /// Call once per control input to `desc`. + /// + /// TF_OperationDescription* + /// TF_Operation* + [DllImport(TensorFlowLibName)] + public static extern void TF_AddControlInput(IntPtr desc, IntPtr input); + /// /// For inputs that take a list of tensors. /// inputs must point to TF_Output[num_inputs]. diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 37b91448..6146e54d 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -13,9 +13,8 @@ namespace Tensorflow { name = namescope; - var ops_on_device = new Dictionary(); - // Sorts *inputs according to their devices. + var ops_on_device = new Dictionary(); foreach (var inp in inputs) { ops_on_device[inp.Device] = new Operation[] { inp }; @@ -24,7 +23,9 @@ namespace Tensorflow // 1-level tree. The root node is the returned NoOp node. if (ops_on_device.Count == 1) { - return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name); + var dev = ops_on_device.Keys.First(); + var deps = ops_on_device.Values.First(); + return _GroupControlDeps(dev, deps, name); } // 2-level tree. The root node is the returned NoOp node. @@ -35,12 +36,21 @@ namespace Tensorflow private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "") { - if (string.IsNullOrEmpty(dev)) + Operation result = null; + + Python.with(ops.control_dependencies(deps), delegate { - return gen_control_flow_ops.no_op(name); - } + if (string.IsNullOrEmpty(dev)) + { + result = gen_control_flow_ops.no_op(name); + } + else + { + result = gen_control_flow_ops.no_op(name); + } + }); - return null; + return result; } } } diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index 2e4d6423..9059be1e 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -13,5 +13,30 @@ namespace Tensorflow { Console.WriteLine(obj.ToString()); } + + public static void with(IPython py, Action action) + { + try + { + py.__enter__(); + action(); + } + catch (Exception ex) + { + throw ex; + } + finally + { + py.__exit__(); + py.Dispose(); + } + } + } + + public interface IPython : IDisposable + { + void __enter__(); + + void __exit__(); } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs index 6e4d28f5..5e7fe2a5 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs @@ -20,5 +20,10 @@ namespace Tensorflow { return var._AsTensor(); } + + public static implicit operator RefVariable(Tensor var) + { + return null; + } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index fc15d97d..5736dd23 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -166,5 +166,10 @@ namespace Tensorflow // Recursively build initializer expressions for inputs. return op; } + + public override string ToString() + { + return $"tf.Variable '{name}' shape={shape} dtype={dtype}"; + } } } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index b8c8dc1c..3440298c 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -78,7 +78,29 @@ namespace Tensorflow } } - public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs) + /// + /// Wrapper for `Graph.control_dependencies()` using the default graph. + /// + /// + public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) + { + return get_default_graph().control_dependencies(control_inputs); + } + + /// + /// Creates a TF_Operation. + /// + /// a `Graph`. + /// `node_def_pb2.NodeDef` for the operation to create. + /// + /// A list of `Tensor`s (corresponding to scalar inputs) and lists of + /// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N", + /// "list(int64)"). The length of the list should be equal to the number of + /// inputs specified by this operation's op def. + /// + /// A list of `Operation`s to set as control dependencies. + /// A wrapped TF_Operation*. + public static IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs, Operation[] control_inputs) { var op_desc = graph.NewOperation(node_def.Op, node_def.Name); @@ -102,6 +124,8 @@ namespace Tensorflow var status = new Status(); // Add control inputs + foreach (var control_input in control_inputs) + c_api.TF_AddControlInput(op_desc, control_input); // Add attrs foreach (var attr in node_def.Attr) @@ -170,8 +194,11 @@ namespace Tensorflow // inner_device_stack = default_graph._device_function_stack // var outer_context = default_graph.as_default; - var outer_graph = get_default_graph(); - // outer_device_stack = None + Python.with(ops.control_dependencies(null), delegate + { + var outer_graph = get_default_graph(); + // outer_device_stack = None + }); } private static int uid_number = 0; diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 69e4f39e..95313d4c 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -46,14 +46,13 @@ namespace TensorFlowNET.UnitTest var x = tf.Variable(10, name: "x"); var model = tf.global_variables_initializer(); - using (var session = tf.Session()) { - session.run(x.initializer); + session.run(model); for(int i = 0; i < 5; i++) { - var x1 = x + 1; - var result = session.run(x1); + x = x + 1; + var result = session.run(x); print(result); } }