| @@ -0,0 +1,105 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Graph | |||
| { | |||
| public Context _control_flow_context; | |||
| private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>(); | |||
| public Queue<_ControlDependenciesController> _control_dependencies_stack | |||
| { | |||
| get | |||
| { | |||
| return _graph_control_dependencies_stack; | |||
| } | |||
| set | |||
| { | |||
| _graph_control_dependencies_stack = value; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// For an op that takes `input_ops` as inputs, compute control inputs. | |||
| /// </summary> | |||
| /// <param name="input_ops">The data input ops for an op to be created.</param> | |||
| /// <returns>A list of control inputs for the op to be created.</returns> | |||
| private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) | |||
| { | |||
| Operation[] ret = new Operation[0]; | |||
| foreach(var controller in _control_dependencies_stack) | |||
| { | |||
| bool dominated = false; | |||
| // If any of the input_ops already depends on the inputs from controller, | |||
| // we say that the new op is dominated (by that input), and we therefore | |||
| // do not need to add control dependencies for this controller's inputs. | |||
| foreach(var op in input_ops) | |||
| { | |||
| if (controller.op_in_group(op)) | |||
| { | |||
| dominated = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!dominated) | |||
| ret = controller.control_inputs.Where(x => !input_ops.Contains(x)).ToArray(); | |||
| } | |||
| return ret; | |||
| } | |||
| public _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||
| { | |||
| if (control_inputs == null) | |||
| return new _ControlDependenciesController(this, null); | |||
| var control_ops = new List<Operation>(); | |||
| foreach (var c in control_inputs) | |||
| { | |||
| control_ops.Add(c); | |||
| } | |||
| return new _ControlDependenciesController(this, control_ops); | |||
| } | |||
| /// <summary> | |||
| /// Returns the current control flow context. | |||
| /// </summary> | |||
| /// <returns>A context object.</returns> | |||
| public Context _get_control_flow_context() | |||
| { | |||
| return _control_flow_context; | |||
| } | |||
| /// <summary> | |||
| /// Sets the current control flow context. | |||
| /// </summary> | |||
| /// <param name="ctx">a context object.</param> | |||
| public void _set_control_flow_context(Context ctx) | |||
| { | |||
| _control_flow_context = ctx; | |||
| } | |||
| public void _push_control_dependencies_controller(_ControlDependenciesController controller) | |||
| { | |||
| _control_dependencies_stack.Enqueue(controller); | |||
| } | |||
| public void _pop_control_dependencies_controller(_ControlDependenciesController controller) | |||
| { | |||
| _control_dependencies_stack.Dequeue(); | |||
| } | |||
| public void _record_op_seen_by_control_dependencies(Operation op) | |||
| { | |||
| foreach (var controller in _control_dependencies_stack) | |||
| controller.add_op(op); | |||
| } | |||
| } | |||
| } | |||
| @@ -142,19 +142,9 @@ namespace Tensorflow | |||
| return op; | |||
| } | |||
| /// <summary> | |||
| /// For an op that takes `input_ops` as inputs, compute control inputs. | |||
| /// </summary> | |||
| /// <param name="input_ops">The data input ops for an op to be created.</param> | |||
| /// <returns>A list of control inputs for the op to be created.</returns> | |||
| private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) | |||
| { | |||
| return new Operation[0]; | |||
| } | |||
| private void _create_op_helper(Operation op, bool compute_device = true) | |||
| { | |||
| _record_op_seen_by_control_dependencies(op); | |||
| } | |||
| public void _add_op(Operation op) | |||
| @@ -0,0 +1,80 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Context manager for `control_dependencies()` | |||
| /// </summary> | |||
| public class _ControlDependenciesController : IPython | |||
| { | |||
| private Graph _graph; | |||
| private List<Operation> _control_inputs_val; | |||
| private List<Operation> _seen_nodes; | |||
| private Queue<_ControlDependenciesController> _old_stack; | |||
| private bool _new_stack; | |||
| private Context _old_control_flow_context; | |||
| public Operation[] control_inputs => _control_inputs_val.ToArray(); | |||
| public _ControlDependenciesController(Graph graph, List<Operation> control_inputs) | |||
| { | |||
| _graph = graph; | |||
| if (control_inputs == null) | |||
| { | |||
| _control_inputs_val = new List<Operation>(); | |||
| _new_stack = true; | |||
| } | |||
| else | |||
| { | |||
| _control_inputs_val = control_inputs; | |||
| _new_stack = false; | |||
| } | |||
| _seen_nodes = new List<Operation>(); | |||
| } | |||
| public void add_op(Operation op) | |||
| { | |||
| _seen_nodes.Add(op); | |||
| } | |||
| public bool op_in_group(Operation op) | |||
| { | |||
| return _seen_nodes.Contains(op); | |||
| } | |||
| public void __enter__() | |||
| { | |||
| if (_new_stack) | |||
| { | |||
| // Clear the control_dependencies graph. | |||
| _old_stack = _graph._control_dependencies_stack; | |||
| _graph._control_dependencies_stack = new Queue<_ControlDependenciesController>(); | |||
| // Clear the control_flow_context too. | |||
| _old_control_flow_context = _graph._get_control_flow_context(); | |||
| _graph._set_control_flow_context(null); | |||
| } | |||
| _graph._push_control_dependencies_controller(this); | |||
| } | |||
| public void __exit__() | |||
| { | |||
| _graph._pop_control_dependencies_controller(this); | |||
| if (_new_stack) | |||
| { | |||
| _graph._control_dependencies_stack = _old_stack; | |||
| _graph._set_control_flow_context(_old_control_flow_context); | |||
| } | |||
| } | |||
| public void Dispose() | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -39,6 +39,14 @@ namespace Tensorflow | |||
| public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||
| public Operation[] control_inputs | |||
| { | |||
| get | |||
| { | |||
| return GetControlInputs(); | |||
| } | |||
| } | |||
| public unsafe Operation[] GetControlInputs() | |||
| { | |||
| var control_inputs = new Operation[NumControlInputs]; | |||
| @@ -49,15 +49,53 @@ namespace Tensorflow | |||
| c_api.TF_FinishOperation(desc, status); | |||
| } | |||
| public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
| /// <summary> | |||
| /// Creates an `Operation`. | |||
| /// </summary> | |||
| /// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param> | |||
| /// <param name="g">`Graph`. The parent graph.</param> | |||
| /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param> | |||
| /// <param name="output_types">list of `DType` objects.</param> | |||
| /// <param name="control_inputs"> | |||
| /// list of operations or tensors from which to have a | |||
| /// control dependency. | |||
| /// </param> | |||
| /// <param name="input_types"> | |||
| /// List of `DType` objects representing the | |||
| /// types of the tensors accepted by the `Operation`. By default | |||
| /// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect | |||
| /// reference-typed inputs must specify these explicitly. | |||
| /// </param> | |||
| /// <param name="original_op"></param> | |||
| /// <param name="op_def"></param> | |||
| public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
| { | |||
| Graph = g; | |||
| // Build the list of control inputs. | |||
| var control_input_ops = new List<Operation>(); | |||
| if(control_inputs != null) | |||
| { | |||
| foreach(var c in control_inputs) | |||
| { | |||
| switch (c) | |||
| { | |||
| case Operation c1: | |||
| control_input_ops.Add(c1); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); | |||
| } | |||
| } | |||
| } | |||
| // This will be set by self.inputs. | |||
| _id_value = Graph._next_id(); | |||
| if(op_def == null) | |||
| op_def = g.GetOpDef(node_def.Op); | |||
| _handle = ops._create_c_op(g, node_def, inputs); | |||
| _handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray()); | |||
| output_types = new TF_DataType[NumOutputs]; | |||
| @@ -34,6 +34,14 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_AddInput(IntPtr desc, TF_Output input); | |||
| /// <summary> | |||
| /// Call once per control input to `desc`. | |||
| /// </summary> | |||
| /// <param name="desc">TF_OperationDescription*</param> | |||
| /// <param name="input">TF_Operation*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_AddControlInput(IntPtr desc, IntPtr input); | |||
| /// <summary> | |||
| /// For inputs that take a list of tensors. | |||
| /// inputs must point to TF_Output[num_inputs]. | |||
| @@ -13,9 +13,8 @@ namespace Tensorflow | |||
| { | |||
| name = namescope; | |||
| var ops_on_device = new Dictionary<string, Operation[]>(); | |||
| // Sorts *inputs according to their devices. | |||
| var ops_on_device = new Dictionary<string, Operation[]>(); | |||
| foreach (var inp in inputs) | |||
| { | |||
| ops_on_device[inp.Device] = new Operation[] { inp }; | |||
| @@ -24,7 +23,9 @@ namespace Tensorflow | |||
| // 1-level tree. The root node is the returned NoOp node. | |||
| if (ops_on_device.Count == 1) | |||
| { | |||
| return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name); | |||
| var dev = ops_on_device.Keys.First(); | |||
| var deps = ops_on_device.Values.First(); | |||
| return _GroupControlDeps(dev, deps, name); | |||
| } | |||
| // 2-level tree. The root node is the returned NoOp node. | |||
| @@ -35,12 +36,21 @@ namespace Tensorflow | |||
| private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "") | |||
| { | |||
| if (string.IsNullOrEmpty(dev)) | |||
| Operation result = null; | |||
| Python.with(ops.control_dependencies(deps), delegate | |||
| { | |||
| return gen_control_flow_ops.no_op(name); | |||
| } | |||
| if (string.IsNullOrEmpty(dev)) | |||
| { | |||
| result = gen_control_flow_ops.no_op(name); | |||
| } | |||
| else | |||
| { | |||
| result = gen_control_flow_ops.no_op(name); | |||
| } | |||
| }); | |||
| return null; | |||
| return result; | |||
| } | |||
| } | |||
| } | |||
| @@ -13,5 +13,30 @@ namespace Tensorflow | |||
| { | |||
| Console.WriteLine(obj.ToString()); | |||
| } | |||
| public static void with(IPython py, Action action) | |||
| { | |||
| try | |||
| { | |||
| py.__enter__(); | |||
| action(); | |||
| } | |||
| catch (Exception ex) | |||
| { | |||
| throw ex; | |||
| } | |||
| finally | |||
| { | |||
| py.__exit__(); | |||
| py.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| public interface IPython : IDisposable | |||
| { | |||
| void __enter__(); | |||
| void __exit__(); | |||
| } | |||
| } | |||
| @@ -20,5 +20,10 @@ namespace Tensorflow | |||
| { | |||
| return var._AsTensor(); | |||
| } | |||
| public static implicit operator RefVariable(Tensor var) | |||
| { | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| @@ -166,5 +166,10 @@ namespace Tensorflow | |||
| // Recursively build initializer expressions for inputs. | |||
| return op; | |||
| } | |||
| public override string ToString() | |||
| { | |||
| return $"tf.Variable '{name}' shape={shape} dtype={dtype}"; | |||
| } | |||
| } | |||
| } | |||
| @@ -78,7 +78,29 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) | |||
| /// <summary> | |||
| /// Wrapper for `Graph.control_dependencies()` using the default graph. | |||
| /// </summary> | |||
| /// <param name="control_inputs"></param> | |||
| public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||
| { | |||
| return get_default_graph().control_dependencies(control_inputs); | |||
| } | |||
| /// <summary> | |||
| /// Creates a TF_Operation. | |||
| /// </summary> | |||
| /// <param name="graph">a `Graph`.</param> | |||
| /// <param name="node_def">`node_def_pb2.NodeDef` for the operation to create.</param> | |||
| /// <param name="inputs"> | |||
| /// A list of `Tensor`s (corresponding to scalar inputs) and lists of | |||
| /// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N", | |||
| /// "list(int64)"). The length of the list should be equal to the number of | |||
| /// inputs specified by this operation's op def. | |||
| /// </param> | |||
| /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | |||
| /// <returns>A wrapped TF_Operation*.</returns> | |||
| public static IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs, Operation[] control_inputs) | |||
| { | |||
| var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||
| @@ -102,6 +124,8 @@ namespace Tensorflow | |||
| var status = new Status(); | |||
| // Add control inputs | |||
| foreach (var control_input in control_inputs) | |||
| c_api.TF_AddControlInput(op_desc, control_input); | |||
| // Add attrs | |||
| foreach (var attr in node_def.Attr) | |||
| @@ -170,8 +194,11 @@ namespace Tensorflow | |||
| // inner_device_stack = default_graph._device_function_stack | |||
| // var outer_context = default_graph.as_default; | |||
| var outer_graph = get_default_graph(); | |||
| // outer_device_stack = None | |||
| Python.with(ops.control_dependencies(null), delegate | |||
| { | |||
| var outer_graph = get_default_graph(); | |||
| // outer_device_stack = None | |||
| }); | |||
| } | |||
| private static int uid_number = 0; | |||
| @@ -46,14 +46,13 @@ namespace TensorFlowNET.UnitTest | |||
| var x = tf.Variable(10, name: "x"); | |||
| var model = tf.global_variables_initializer(); | |||
| using (var session = tf.Session()) | |||
| { | |||
| session.run(x.initializer); | |||
| session.run(model); | |||
| for(int i = 0; i < 5; i++) | |||
| { | |||
| var x1 = x + 1; | |||
| var result = session.run(x1); | |||
| x = x + 1; | |||
| var result = session.run(x); | |||
| print(result); | |||
| } | |||
| } | |||