| @@ -2,7 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Operations | |||||
| namespace Tensorflow | |||||
| { | { | ||||
| public interface ICanBeFlattened | public interface ICanBeFlattened | ||||
| { | { | ||||
| @@ -0,0 +1,11 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public interface IPackable | |||||
| { | |||||
| void Pack(object[] sequences); | |||||
| } | |||||
| } | |||||
| @@ -170,7 +170,7 @@ namespace Tensorflow.Operations | |||||
| /// <summary> | /// <summary> | ||||
| /// Add `op` to the current context. | /// Add `op` to the current context. | ||||
| /// </summary> | /// </summary> | ||||
| public void AddOp(Operation op) | |||||
| public virtual void AddOp(Operation op) | |||||
| { | { | ||||
| _AddOpInternal(op); | _AddOpInternal(op); | ||||
| } | } | ||||
| @@ -210,11 +210,6 @@ namespace Tensorflow.Operations | |||||
| /// </summary> | /// </summary> | ||||
| protected virtual void _AddOpInternal(Operation op) | protected virtual void _AddOpInternal(Operation op) | ||||
| { | { | ||||
| if (op.name == "rnn/while/Less") | |||||
| { | |||||
| } | |||||
| if(op == null) | if(op == null) | ||||
| { | { | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| @@ -255,9 +250,34 @@ namespace Tensorflow.Operations | |||||
| throw new NotImplementedException("_IsInOuterContext"); | throw new NotImplementedException("_IsInOuterContext"); | ||||
| } | } | ||||
| protected virtual void _RemoveExternalControlEdges(Operation op) | |||||
| /// <summary> | |||||
| /// Remove any external control dependency on this op. | |||||
| /// </summary> | |||||
| /// <param name="op"></param> | |||||
| protected virtual (Operation[], Operation[]) _RemoveExternalControlEdges(Operation op) | |||||
| { | { | ||||
| var internal_control_inputs = op.control_inputs; | |||||
| var while_ctxt = GetWhileContext(); | |||||
| var internal_control_inputs = new List<Operation>(); | |||||
| // A control input of `op` is internal if it is in the same while | |||||
| // loop context as the enclosing while loop context of self. | |||||
| if (while_ctxt == null) | |||||
| { | |||||
| internal_control_inputs = op.control_inputs.ToList(); | |||||
| } | |||||
| else | |||||
| { | |||||
| foreach(Tensor x in op.control_inputs) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| var external_control_inputs = new List<Operation>(); | |||||
| if (len(internal_control_inputs) != len(op.control_inputs)) | |||||
| throw new NotImplementedException(""); | |||||
| return (internal_control_inputs.ToArray(), external_control_inputs.ToArray()); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -1,13 +1,14 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| { | { | ||||
| internal class LoopVar<TItem> : ICanBeFlattened | |||||
| internal class LoopVar<TItem> : ICanBeFlattened, IPackable | |||||
| { | { | ||||
| public Tensor Counter { get; } | |||||
| public TItem Item { get; } | |||||
| public Tensor Counter { get; set; } | |||||
| public TItem Item { get; set; } | |||||
| public LoopVar(Tensor counter, TItem item) | public LoopVar(Tensor counter, TItem item) | ||||
| { | { | ||||
| @@ -25,6 +26,13 @@ namespace Tensorflow.Operations | |||||
| return elements.ToArray(); | return elements.ToArray(); | ||||
| } | } | ||||
| public void Pack(object[] sequences) | |||||
| { | |||||
| Counter = sequences[0] as Tensor; | |||||
| if (typeof(TItem).GetInterface(typeof(IPackable).Name) != null) | |||||
| (Item as IPackable).Pack(sequences.Skip(1).ToArray()); | |||||
| } | |||||
| public static implicit operator (Tensor, TItem)(LoopVar<TItem> loopVar) | public static implicit operator (Tensor, TItem)(LoopVar<TItem> loopVar) | ||||
| { | { | ||||
| return (loopVar.Counter, loopVar.Item); | return (loopVar.Counter, loopVar.Item); | ||||
| @@ -240,10 +240,13 @@ namespace Tensorflow.Operations | |||||
| // Build the graph for body. | // Build the graph for body. | ||||
| var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | ||||
| _pivot_for_body = vars_for_body[0]; | |||||
| // Convert TensorArray flow variables inside the context back into | // Convert TensorArray flow variables inside the context back into | ||||
| // their associated TensorArrays for calling the body. | // their associated TensorArrays for calling the body. | ||||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||||
| var body_result = body(original_loop_vars); | |||||
| var vars_for_body_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||||
| var packed_vars_for_body = nest.pack_sequence_as2(original_loop_vars, vars_for_body_with_tensor_arrays); | |||||
| var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| var body_result = body(packed_vars_for_body); | |||||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | ||||
| // Store body_result to keep track of TensorArrays returned by body | // Store body_result to keep track of TensorArrays returned by body | ||||
| @@ -267,17 +270,27 @@ namespace Tensorflow.Operations | |||||
| private void _FixControlInputsAndContext(Tensor[] enters) | private void _FixControlInputsAndContext(Tensor[] enters) | ||||
| { | { | ||||
| var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
| foreach(var e in enters) | |||||
| foreach(var x in enters) | |||||
| { | { | ||||
| var inp_op = e.op.inputs[0].op; | |||||
| var inp_op = x.op.inputs[0].op; | |||||
| var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); | var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); | ||||
| var outer_control_inputs = new List<Operation>(); | |||||
| foreach(Operation op in control_inputs) | |||||
| { | |||||
| // We need to keep control inputs that are in any ancestor | |||||
| // ControlFlowContext, and within outer WhileContext. | |||||
| var keep_as_control_input = true; | |||||
| var op_ctxt = control_flow_util.GetOutputContext(op); | |||||
| var outer_ctxt = outer_context; | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| // op for op in control_inputs if self._IsInOuterContext(op) | // op for op in control_inputs if self._IsInOuterContext(op) | ||||
| var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) | |||||
| /*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) | |||||
| .Select(x => x.op) | .Select(x => x.op) | ||||
| .ToArray(); | |||||
| e.op._set_control_flow_context(this); | |||||
| e.op._add_control_inputs(outer_control_inputs); | |||||
| graph._record_op_seen_by_control_dependencies(e.op); | |||||
| .ToArray();*/ | |||||
| x.op._set_control_flow_context(this); | |||||
| x.op._add_control_inputs(outer_control_inputs.ToArray()); | |||||
| graph._record_op_seen_by_control_dependencies(x.op); | |||||
| } | } | ||||
| } | } | ||||
| @@ -288,13 +301,127 @@ namespace Tensorflow.Operations | |||||
| _values.Add(x.name); | _values.Add(x.name); | ||||
| } | } | ||||
| protected override void _AddOpInternal(Operation op) | |||||
| { | |||||
| Operation[] external_inputs = new Operation[0]; | |||||
| if (op == null) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| else | |||||
| { | |||||
| foreach (var index in range(len(op.inputs))) | |||||
| { | |||||
| var x = op.inputs[index]; | |||||
| var real_x = AddValue(x); | |||||
| if (real_x != x) | |||||
| op._update_input(index, real_x); | |||||
| } | |||||
| // Remove any external control dependency on this op. | |||||
| (_, external_inputs) = _RemoveExternalControlEdges(op); | |||||
| // Add a control dependency to prevent loop invariants from | |||||
| // enabling ops that should not be executed. | |||||
| _MaybeAddControlDependency(op); | |||||
| foreach (Tensor x in op.outputs) | |||||
| _values.Add(x.name); | |||||
| } | |||||
| if (external_inputs.Length > 0) | |||||
| { | |||||
| throw new NotImplementedException("external_inputs.Length > 0"); | |||||
| } | |||||
| if (_outer_context != null || !IsLoopExit(op)) | |||||
| foreach (Tensor x in op.outputs) | |||||
| op.graph.prevent_feeding(x); | |||||
| if (_outer_context != null) | |||||
| _outer_context.AddInnerOp(op); | |||||
| } | |||||
| protected void _MaybeAddControlDependency(Operation op) | |||||
| { | |||||
| // Determines if `op` needs a control dependency. | |||||
| Func<Operation, bool> _IsOpFree = (op1) => | |||||
| { | |||||
| if (op1.control_inputs.Length > 0) | |||||
| return false; | |||||
| if (op1.type == "SymbolicGradient") | |||||
| return true; | |||||
| foreach (Tensor x in op1.inputs) | |||||
| if (!control_flow_util.IsLoopConstantEnter(x.op)) | |||||
| return false; | |||||
| return true; | |||||
| }; | |||||
| if (_IsOpFree(op)) | |||||
| op._add_control_input(GetControlPivot().op); | |||||
| } | |||||
| private Tensor GetControlPivot() | |||||
| { | |||||
| if (_pivot_for_body != null) | |||||
| return _pivot_for_body; | |||||
| return _pivot_for_pred; | |||||
| } | |||||
| public override void AddOp(Operation op) | |||||
| { | |||||
| _AddOpInternal(op); | |||||
| } | |||||
| public override Tensor AddValue(Tensor val) | public override Tensor AddValue(Tensor val) | ||||
| { | { | ||||
| var result = val; | var result = val; | ||||
| var new_value = _values.Contains(val.name); | |||||
| var new_value = !_values.Contains(val.name); | |||||
| new_value &= val.op._get_control_flow_context() != this; | new_value &= val.op._get_control_flow_context() != this; | ||||
| if (new_value) | if (new_value) | ||||
| throw new NotImplementedException(""); | |||||
| { | |||||
| _values.Add(val.name); | |||||
| // If we are in a grad context and val is from its forward context, | |||||
| // use GetRealValue(), which adds the logic to save the history of | |||||
| // val in forward. | |||||
| var grad_ctxt = ops.get_default_graph()._get_control_flow_context(); | |||||
| if(grad_ctxt != null) | |||||
| { | |||||
| grad_ctxt = grad_ctxt.GetWhileContext(); | |||||
| if (grad_ctxt.grad_state != null) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| if (_outer_context != null) | |||||
| { | |||||
| result = _outer_context.AddValue(val); | |||||
| } | |||||
| // Create an Enter to make `result` known to this loop context. | |||||
| Tensor enter = null; | |||||
| tf_with(ops.control_dependencies(new ITensorOrOperation[0]), delegate | |||||
| { | |||||
| enter = _Enter( | |||||
| result, | |||||
| _name, | |||||
| is_constant: true, | |||||
| parallel_iterations: _parallel_iterations); | |||||
| enter.graph.prevent_feeding(enter); | |||||
| if (_outer_context != null) | |||||
| _outer_context.AddInnerOp(enter.op); | |||||
| }); | |||||
| // Fix the control inputs and control flow context of these enter ops. | |||||
| _FixControlInputsAndContext(new[] { enter }); | |||||
| // Add `enter` in this context. | |||||
| _values.Add(enter.name); | |||||
| _external_values[val.name] = enter; | |||||
| result = enter; | |||||
| } | |||||
| else | else | ||||
| { | { | ||||
| var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null; | var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null; | ||||
| @@ -4,7 +4,7 @@ using System.Text; | |||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| { | { | ||||
| internal class BodyItemInRnnWhileLoop : ICanBeFlattened | |||||
| internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// int32 scalar Tensor. | /// int32 scalar Tensor. | ||||
| @@ -36,5 +36,12 @@ namespace Tensorflow.Operations | |||||
| elements.Add(state); | elements.Add(state); | ||||
| return elements.ToArray(); | return elements.ToArray(); | ||||
| } | } | ||||
| public void Pack(object[] sequences) | |||||
| { | |||||
| time = sequences[0] as Tensor; | |||||
| output_ta_t = new[] { sequences[1] as TensorArray }; | |||||
| state = sequences[2] as Tensor; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -192,7 +192,12 @@ namespace Tensorflow.Operations | |||||
| // Take a time step of the dynamic RNN. | // Take a time step of the dynamic RNN. | ||||
| Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) => | Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) => | ||||
| { | { | ||||
| throw new NotImplementedException(""); | |||||
| if (in_graph_mode) | |||||
| { | |||||
| input_ta.Select(ta => ta.read(time)).ToArray(); | |||||
| } | |||||
| return item; | |||||
| }; | }; | ||||
| control_flow_ops.while_loop( | control_flow_ops.while_loop( | ||||
| @@ -159,5 +159,20 @@ namespace Tensorflow.Operations | |||||
| { | { | ||||
| _colocate_with.Add(value); | _colocate_with.Add(value); | ||||
| } | } | ||||
| public Tensor read(Tensor index, string name = null) | |||||
| { | |||||
| var value = gen_data_flow_ops.tensor_array_read_v3( | |||||
| handle: _handle, | |||||
| index: index, | |||||
| flow_in: _flow, | |||||
| dtype: _dtype, | |||||
| name: name); | |||||
| if (_element_shape != null) | |||||
| value.set_shape(_element_shape[0].dims); | |||||
| return value; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -648,7 +648,8 @@ namespace Tensorflow | |||||
| body_buildloop = (item) => | body_buildloop = (item) => | ||||
| { | { | ||||
| var (i, lv) = (item.Counter, item.Item); | var (i, lv) = (item.Counter, item.Item); | ||||
| return new LoopVar<TItem>(i + 1, orig_body(lv)); | |||||
| var ob = orig_body(lv); | |||||
| return new LoopVar<TItem>(i + 1, ob); | |||||
| }; | }; | ||||
| } | } | ||||
| try_to_pack = false; | try_to_pack = false; | ||||
| @@ -30,6 +30,26 @@ namespace Tensorflow | |||||
| public static bool IsLoopExit(Operation op) | public static bool IsLoopExit(Operation op) | ||||
| { | { | ||||
| return op.type == "Exit" || op.type == "RefExit"; | return op.type == "Exit" || op.type == "RefExit"; | ||||
| } | |||||
| /// <summary> | |||||
| /// Returns true if `op` is an Enter. | |||||
| /// </summary> | |||||
| /// <param name="op"></param> | |||||
| /// <returns></returns> | |||||
| public static bool IsLoopEnter(Operation op) | |||||
| { | |||||
| return op.type == "Enter" || op.type == "RefEnter"; | |||||
| } | |||||
| /// <summary> | |||||
| /// Return true iff op is a loop invariant. | |||||
| /// </summary> | |||||
| /// <param name="op"></param> | |||||
| /// <returns></returns> | |||||
| public static bool IsLoopConstantEnter(Operation op) | |||||
| { | |||||
| return IsLoopEnter(op) && op.get_attr<bool>("is_constant"); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -198,5 +198,27 @@ namespace Tensorflow | |||||
| return _op.outputs; | return _op.outputs; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Read an element from the TensorArray into output `value`. | |||||
| /// </summary> | |||||
| /// <param name="handle"></param> | |||||
| /// <param name="index"></param> | |||||
| /// <param name="flow_in"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor tensor_array_read_v3(Tensor handle, Tensor index, Tensor flow_in, TF_DataType dtype, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("TensorArrayReadV3", name, new | |||||
| { | |||||
| handle, | |||||
| index, | |||||
| flow_in, | |||||
| dtype | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -58,5 +58,8 @@ namespace Tensorflow | |||||
| public TensorArray unstack(Tensor value, string name = null) | public TensorArray unstack(Tensor value, string name = null) | ||||
| => _implementation.unstack(value, name: name); | => _implementation.unstack(value, name: name); | ||||
| public Tensor read(Tensor index, string name = null) | |||||
| => _implementation.read(index, name: name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -401,6 +401,13 @@ namespace Tensorflow.Util | |||||
| private static int len(IEnumerable<object> x) => x.Count(); | private static int len(IEnumerable<object> x) => x.Count(); | ||||
| public static T pack_sequence_as2<T>(T structure, object[] flat_sequence, bool expand_composites = false) | |||||
| where T : IPackable | |||||
| { | |||||
| structure.Pack(flat_sequence); | |||||
| return structure; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns a given flattened sequence packed into a given structure. | /// Returns a given flattened sequence packed into a given structure. | ||||
| /// If `structure` is a scalar, `flat_sequence` must be a single-element list; | /// If `structure` is a scalar, `flat_sequence` must be a single-element list; | ||||