| @@ -2,7 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| namespace Tensorflow | |||
| { | |||
| public interface ICanBeFlattened | |||
| { | |||
| @@ -0,0 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public interface IPackable | |||
| { | |||
| void Pack(object[] sequences); | |||
| } | |||
| } | |||
| @@ -170,7 +170,7 @@ namespace Tensorflow.Operations | |||
| /// <summary> | |||
| /// Add `op` to the current context. | |||
| /// </summary> | |||
| public void AddOp(Operation op) | |||
| public virtual void AddOp(Operation op) | |||
| { | |||
| _AddOpInternal(op); | |||
| } | |||
| @@ -210,11 +210,6 @@ namespace Tensorflow.Operations | |||
| /// </summary> | |||
| protected virtual void _AddOpInternal(Operation op) | |||
| { | |||
| if (op.name == "rnn/while/Less") | |||
| { | |||
| } | |||
| if(op == null) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| @@ -255,9 +250,34 @@ namespace Tensorflow.Operations | |||
| throw new NotImplementedException("_IsInOuterContext"); | |||
| } | |||
| protected virtual void _RemoveExternalControlEdges(Operation op) | |||
| /// <summary> | |||
| /// Remove any external control dependency on this op. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| protected virtual (Operation[], Operation[]) _RemoveExternalControlEdges(Operation op) | |||
| { | |||
| var internal_control_inputs = op.control_inputs; | |||
| var while_ctxt = GetWhileContext(); | |||
| var internal_control_inputs = new List<Operation>(); | |||
| // A control input of `op` is internal if it is in the same while | |||
| // loop context as the enclosing while loop context of self. | |||
| if (while_ctxt == null) | |||
| { | |||
| internal_control_inputs = op.control_inputs.ToList(); | |||
| } | |||
| else | |||
| { | |||
| foreach(Tensor x in op.control_inputs) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| var external_control_inputs = new List<Operation>(); | |||
| if (len(internal_control_inputs) != len(op.control_inputs)) | |||
| throw new NotImplementedException(""); | |||
| return (internal_control_inputs.ToArray(), external_control_inputs.ToArray()); | |||
| } | |||
| /// <summary> | |||
| @@ -1,13 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| internal class LoopVar<TItem> : ICanBeFlattened | |||
| internal class LoopVar<TItem> : ICanBeFlattened, IPackable | |||
| { | |||
| public Tensor Counter { get; } | |||
| public TItem Item { get; } | |||
| public Tensor Counter { get; set; } | |||
| public TItem Item { get; set; } | |||
| public LoopVar(Tensor counter, TItem item) | |||
| { | |||
| @@ -25,6 +26,13 @@ namespace Tensorflow.Operations | |||
| return elements.ToArray(); | |||
| } | |||
| public void Pack(object[] sequences) | |||
| { | |||
| Counter = sequences[0] as Tensor; | |||
| if (typeof(TItem).GetInterface(typeof(IPackable).Name) != null) | |||
| (Item as IPackable).Pack(sequences.Skip(1).ToArray()); | |||
| } | |||
| public static implicit operator (Tensor, TItem)(LoopVar<TItem> loopVar) | |||
| { | |||
| return (loopVar.Counter, loopVar.Item); | |||
| @@ -240,10 +240,13 @@ namespace Tensorflow.Operations | |||
| // Build the graph for body. | |||
| var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | |||
| _pivot_for_body = vars_for_body[0]; | |||
| // Convert TensorArray flow variables inside the context back into | |||
| // their associated TensorArrays for calling the body. | |||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||
| var body_result = body(original_loop_vars); | |||
| var vars_for_body_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||
| var packed_vars_for_body = nest.pack_sequence_as2(original_loop_vars, vars_for_body_with_tensor_arrays); | |||
| var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| var body_result = body(packed_vars_for_body); | |||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| // Store body_result to keep track of TensorArrays returned by body | |||
| @@ -267,17 +270,27 @@ namespace Tensorflow.Operations | |||
| private void _FixControlInputsAndContext(Tensor[] enters) | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| foreach(var e in enters) | |||
| foreach(var x in enters) | |||
| { | |||
| var inp_op = e.op.inputs[0].op; | |||
| var inp_op = x.op.inputs[0].op; | |||
| var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); | |||
| var outer_control_inputs = new List<Operation>(); | |||
| foreach(Operation op in control_inputs) | |||
| { | |||
| // We need to keep control inputs that are in any ancestor | |||
| // ControlFlowContext, and within outer WhileContext. | |||
| var keep_as_control_input = true; | |||
| var op_ctxt = control_flow_util.GetOutputContext(op); | |||
| var outer_ctxt = outer_context; | |||
| throw new NotImplementedException(""); | |||
| } | |||
| // op for op in control_inputs if self._IsInOuterContext(op) | |||
| var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) | |||
| /*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) | |||
| .Select(x => x.op) | |||
| .ToArray(); | |||
| e.op._set_control_flow_context(this); | |||
| e.op._add_control_inputs(outer_control_inputs); | |||
| graph._record_op_seen_by_control_dependencies(e.op); | |||
| .ToArray();*/ | |||
| x.op._set_control_flow_context(this); | |||
| x.op._add_control_inputs(outer_control_inputs.ToArray()); | |||
| graph._record_op_seen_by_control_dependencies(x.op); | |||
| } | |||
| } | |||
| @@ -288,13 +301,127 @@ namespace Tensorflow.Operations | |||
| _values.Add(x.name); | |||
| } | |||
| protected override void _AddOpInternal(Operation op) | |||
| { | |||
| Operation[] external_inputs = new Operation[0]; | |||
| if (op == null) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| else | |||
| { | |||
| foreach (var index in range(len(op.inputs))) | |||
| { | |||
| var x = op.inputs[index]; | |||
| var real_x = AddValue(x); | |||
| if (real_x != x) | |||
| op._update_input(index, real_x); | |||
| } | |||
| // Remove any external control dependency on this op. | |||
| (_, external_inputs) = _RemoveExternalControlEdges(op); | |||
| // Add a control dependency to prevent loop invariants from | |||
| // enabling ops that should not be executed. | |||
| _MaybeAddControlDependency(op); | |||
| foreach (Tensor x in op.outputs) | |||
| _values.Add(x.name); | |||
| } | |||
| if (external_inputs.Length > 0) | |||
| { | |||
| throw new NotImplementedException("external_inputs.Length > 0"); | |||
| } | |||
| if (_outer_context != null || !IsLoopExit(op)) | |||
| foreach (Tensor x in op.outputs) | |||
| op.graph.prevent_feeding(x); | |||
| if (_outer_context != null) | |||
| _outer_context.AddInnerOp(op); | |||
| } | |||
| protected void _MaybeAddControlDependency(Operation op) | |||
| { | |||
| // Determines if `op` needs a control dependency. | |||
| Func<Operation, bool> _IsOpFree = (op1) => | |||
| { | |||
| if (op1.control_inputs.Length > 0) | |||
| return false; | |||
| if (op1.type == "SymbolicGradient") | |||
| return true; | |||
| foreach (Tensor x in op1.inputs) | |||
| if (!control_flow_util.IsLoopConstantEnter(x.op)) | |||
| return false; | |||
| return true; | |||
| }; | |||
| if (_IsOpFree(op)) | |||
| op._add_control_input(GetControlPivot().op); | |||
| } | |||
| private Tensor GetControlPivot() | |||
| { | |||
| if (_pivot_for_body != null) | |||
| return _pivot_for_body; | |||
| return _pivot_for_pred; | |||
| } | |||
| public override void AddOp(Operation op) | |||
| { | |||
| _AddOpInternal(op); | |||
| } | |||
| public override Tensor AddValue(Tensor val) | |||
| { | |||
| var result = val; | |||
| var new_value = _values.Contains(val.name); | |||
| var new_value = !_values.Contains(val.name); | |||
| new_value &= val.op._get_control_flow_context() != this; | |||
| if (new_value) | |||
| throw new NotImplementedException(""); | |||
| { | |||
| _values.Add(val.name); | |||
| // If we are in a grad context and val is from its forward context, | |||
| // use GetRealValue(), which adds the logic to save the history of | |||
| // val in forward. | |||
| var grad_ctxt = ops.get_default_graph()._get_control_flow_context(); | |||
| if(grad_ctxt != null) | |||
| { | |||
| grad_ctxt = grad_ctxt.GetWhileContext(); | |||
| if (grad_ctxt.grad_state != null) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| if (_outer_context != null) | |||
| { | |||
| result = _outer_context.AddValue(val); | |||
| } | |||
| // Create an Enter to make `result` known to this loop context. | |||
| Tensor enter = null; | |||
| tf_with(ops.control_dependencies(new ITensorOrOperation[0]), delegate | |||
| { | |||
| enter = _Enter( | |||
| result, | |||
| _name, | |||
| is_constant: true, | |||
| parallel_iterations: _parallel_iterations); | |||
| enter.graph.prevent_feeding(enter); | |||
| if (_outer_context != null) | |||
| _outer_context.AddInnerOp(enter.op); | |||
| }); | |||
| // Fix the control inputs and control flow context of these enter ops. | |||
| _FixControlInputsAndContext(new[] { enter }); | |||
| // Add `enter` in this context. | |||
| _values.Add(enter.name); | |||
| _external_values[val.name] = enter; | |||
| result = enter; | |||
| } | |||
| else | |||
| { | |||
| var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null; | |||
| @@ -4,7 +4,7 @@ using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| internal class BodyItemInRnnWhileLoop : ICanBeFlattened | |||
| internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable | |||
| { | |||
| /// <summary> | |||
| /// int32 scalar Tensor. | |||
| @@ -36,5 +36,12 @@ namespace Tensorflow.Operations | |||
| elements.Add(state); | |||
| return elements.ToArray(); | |||
| } | |||
| public void Pack(object[] sequences) | |||
| { | |||
| time = sequences[0] as Tensor; | |||
| output_ta_t = new[] { sequences[1] as TensorArray }; | |||
| state = sequences[2] as Tensor; | |||
| } | |||
| } | |||
| } | |||
| @@ -192,7 +192,12 @@ namespace Tensorflow.Operations | |||
| // Take a time step of the dynamic RNN. | |||
| Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) => | |||
| { | |||
| throw new NotImplementedException(""); | |||
| if (in_graph_mode) | |||
| { | |||
| input_ta.Select(ta => ta.read(time)).ToArray(); | |||
| } | |||
| return item; | |||
| }; | |||
| control_flow_ops.while_loop( | |||
| @@ -159,5 +159,20 @@ namespace Tensorflow.Operations | |||
| { | |||
| _colocate_with.Add(value); | |||
| } | |||
| public Tensor read(Tensor index, string name = null) | |||
| { | |||
| var value = gen_data_flow_ops.tensor_array_read_v3( | |||
| handle: _handle, | |||
| index: index, | |||
| flow_in: _flow, | |||
| dtype: _dtype, | |||
| name: name); | |||
| if (_element_shape != null) | |||
| value.set_shape(_element_shape[0].dims); | |||
| return value; | |||
| } | |||
| } | |||
| } | |||
| @@ -648,7 +648,8 @@ namespace Tensorflow | |||
| body_buildloop = (item) => | |||
| { | |||
| var (i, lv) = (item.Counter, item.Item); | |||
| return new LoopVar<TItem>(i + 1, orig_body(lv)); | |||
| var ob = orig_body(lv); | |||
| return new LoopVar<TItem>(i + 1, ob); | |||
| }; | |||
| } | |||
| try_to_pack = false; | |||
| @@ -30,6 +30,26 @@ namespace Tensorflow | |||
| public static bool IsLoopExit(Operation op) | |||
| { | |||
| return op.type == "Exit" || op.type == "RefExit"; | |||
| } | |||
| /// <summary> | |||
| /// Returns true if `op` is an Enter. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| /// <returns></returns> | |||
| public static bool IsLoopEnter(Operation op) | |||
| { | |||
| return op.type == "Enter" || op.type == "RefEnter"; | |||
| } | |||
| /// <summary> | |||
| /// Return true iff op is a loop invariant. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| /// <returns></returns> | |||
| public static bool IsLoopConstantEnter(Operation op) | |||
| { | |||
| return IsLoopEnter(op) && op.get_attr<bool>("is_constant"); | |||
| } | |||
| /// <summary> | |||
| @@ -198,5 +198,27 @@ namespace Tensorflow | |||
| return _op.outputs; | |||
| } | |||
| /// <summary> | |||
| /// Read an element from the TensorArray into output `value`. | |||
| /// </summary> | |||
| /// <param name="handle"></param> | |||
| /// <param name="index"></param> | |||
| /// <param name="flow_in"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor tensor_array_read_v3(Tensor handle, Tensor index, Tensor flow_in, TF_DataType dtype, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("TensorArrayReadV3", name, new | |||
| { | |||
| handle, | |||
| index, | |||
| flow_in, | |||
| dtype | |||
| }); | |||
| return _op.output; | |||
| } | |||
| } | |||
| } | |||
| @@ -58,5 +58,8 @@ namespace Tensorflow | |||
| public TensorArray unstack(Tensor value, string name = null) | |||
| => _implementation.unstack(value, name: name); | |||
| public Tensor read(Tensor index, string name = null) | |||
| => _implementation.read(index, name: name); | |||
| } | |||
| } | |||
| @@ -401,6 +401,13 @@ namespace Tensorflow.Util | |||
| private static int len(IEnumerable<object> x) => x.Count(); | |||
| public static T pack_sequence_as2<T>(T structure, object[] flat_sequence, bool expand_composites = false) | |||
| where T : IPackable | |||
| { | |||
| structure.Pack(flat_sequence); | |||
| return structure; | |||
| } | |||
| /// <summary> | |||
| /// Returns a given flattened sequence packed into a given structure. | |||
| /// If `structure` is a scalar, `flat_sequence` must be a single-element list; | |||