From 3425aa4a292da4ebc0a148c6fe8518e195203142 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 24 Oct 2019 13:44:04 -0500 Subject: [PATCH] LoopVar --- .../ControlFlows/ControlFlowContext.cs | 27 ++- .../Operations/ControlFlows/LoopVar.cs | 5 + .../Operations/ControlFlows/WhileContext.cs | 211 +++++++++++------- 3 files changed, 160 insertions(+), 83 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index c076cbc7..8a624df2 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Operations.ControlFlows; using static Tensorflow.ControlFlowContextDef; +using static Tensorflow.Binding; namespace Tensorflow.Operations { @@ -72,6 +73,7 @@ namespace Tensorflow.Operations public ControlFlowContext() { _context_stack = new Stack(); + _external_values = new Dictionary(); } public string name { get => _name; } @@ -180,6 +182,11 @@ namespace Tensorflow.Operations public virtual bool back_prop => throw new NotImplementedException("abstract method"); + /// + /// Add `val` to the current context and its outer context recursively. + /// + /// + /// public virtual Tensor AddValue(Tensor val) { // to be overridden @@ -203,7 +210,25 @@ namespace Tensorflow.Operations /// protected virtual void _AddOpInternal(Operation op) { - + if (op.name == "rnn/while/Less") + { + + } + + if(op == null) + { + throw new NotImplementedException(""); + } + else + { + foreach(var index in range(len(op.inputs))) + { + var x = op.inputs[index]; + var real_x = AddValue(x); + if (real_x != x) + op._update_input(index, real_x); + } + } } protected bool OpInContext(Operation op) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs index c313739b..d49d5abf 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs @@ -24,5 +24,10 @@ namespace Tensorflow.Operations elements.Add(Item); return elements.ToArray(); } + + public static implicit operator (Tensor, TItem)(LoopVar loopVar) + { + return (loopVar.Counter, loopVar.Item); + } } } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 462aca25..b40dae11 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -71,6 +71,8 @@ namespace Tensorflow.Operations string name) { _name = ops.get_default_graph().unique_name(name); + _maximum_iterations = maximum_iterations; + _parallel_iterations = parallel_iterations; _back_prop = back_prop; _swap_memory = swap_memory; _loop_exits = new List(); @@ -107,18 +109,27 @@ namespace Tensorflow.Operations /// /// Add the loop termination condition and body to the graph. /// - internal Tensor[] BuildLoop(Func pred, - Func> body, + internal Tensor[] BuildLoop(Func, Tensor> pred, + Func, LoopVar> body, LoopVar loop_vars, - TensorShape shape_invariants, + TensorShape[] shape_invariants, bool return_same_structure) { // Keep original_loop_vars to identify which are TensorArrays var original_loop_vars = loop_vars; // Convert TensorArrays to their flow variables + var loop_vars_tensors = nest.flatten2(loop_vars) + .Select(x => _convert_tensorarray_to_flow(x)) + .ToArray(); + + if (shape_invariants == null) + shape_invariants = loop_vars_tensors + .Select(x => _get_shape_invariant(x as Tensor)) + .ToArray(); + Enter(); var(original_body_result, exit_vars) = _BuildLoop( - pred, body, original_loop_vars, loop_vars, shape_invariants); + pred, body, original_loop_vars, loop_vars_tensors, shape_invariants); Exit(); var flat_result = original_body_result; @@ -131,7 +142,7 @@ namespace Tensorflow.Operations return packed_exit_vars as Tensor[]; } - private Tensor _convert_tensorarray_to_flow(TItem tensor_or_tensor_array) + private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array) { if (tensor_or_tensor_array is TensorArray tensor_array) return tensor_array.flow; @@ -141,97 +152,116 @@ namespace Tensorflow.Operations throw new NotImplementedException("_convert_tensorarray_to_flow"); } - private (Tensor[], Tensor[]) _BuildLoop(Func pred, - Func> body, - LoopVar original_loop_vars, - LoopVar loop_vars, - TensorShape shape_invariants) + private TensorShape _get_shape_invariant(Tensor var, int[] shape = null) { - var flat_loop_vars = original_loop_vars; + return var.TensorShape; + } - // Convert TensorArrays to their flow variables - var loop_vars_tensor = nest.map_structure( - _convert_tensorarray_to_flow, - nest.flatten2(loop_vars)); + /// + /// Add the loop termination condition and body to the graph. + /// + /// + /// + /// + /// + /// + /// + /// + private (Tensor[], Tensor[]) _BuildLoop(Func, Tensor> pred, + Func, LoopVar> body, + LoopVar original_loop_vars, + Tensor[] loop_vars, + TensorShape[] shape_invariants) + { + var flat_loop_vars = nest.flatten2(original_loop_vars) + .Select(x => (ITensorOrTensorArray)x) + .ToArray(); // Let the context know the loop variables so the loop variables // would be added in the outer contexts properly. - if (loop_vars is Tensor[] real_vars) + _InitializeValues(loop_vars); + var real_vars = loop_vars; + Tensor[] enter_vars = null; + tf_with(ops.control_dependencies(null), delegate { - _InitializeValues(real_vars); - Tensor[] enter_vars = null; - tf_with(ops.control_dependencies(null), delegate - { - enter_vars = real_vars.Select(x => _Enter(x, - _name, - is_constant: false, - parallel_iterations: _parallel_iterations, - use_input_shape: shape_invariants == null)) - .ToArray(); - - foreach (var x in enter_vars) - { - x.graph.prevent_feeding(x); - if (_outer_context != null) - _outer_context.AddInnerOp(x.op); - } - }); - - // Finds the closest enclosing non-None control pivot. - var outer_context = _outer_context; - while (outer_context != null) + enter_vars = real_vars.Select(x => _Enter(x, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + use_input_shape: shape_invariants == null)) + .ToArray(); + + foreach (var x in enter_vars) { - + x.graph.prevent_feeding(x); + if (_outer_context != null) + _outer_context.AddInnerOp(x.op); } + }); - _SetShapeInvariants(real_vars, enter_vars, shape_invariants); - - // Fix the control inputs and control flow context of these enter ops. - _FixControlInputsAndContext(enter_vars); - _InitializeValues(enter_vars); - _loop_enters = enter_vars.ToList(); - - var merge_vars = enter_vars - .Select(x => merge(new[] { x, x })) - .ToArray(); + // Finds the closest enclosing non-None control pivot. + var outer_context = _outer_context; + object control_pivot = null; + while (outer_context != null && control_pivot == null) + { - _pivot_for_pred = merge_vars[0]; + } - // Build the graph for pred. - var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); - // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); - var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem))); - _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); - var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) - .ToArray(); + if (control_pivot != null) + { - // Build the graph for body. - var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); - // Convert TensorArray flow variables inside the context back into - // their associated TensorArrays for calling the body. - var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); - /*var body_result = body(packed_vars_for_body[0]); - var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); - - // Store body_result to keep track of TensorArrays returned by body - var original_body_result = new[] { body_result }; - // Convert TensorArrays returned by body into their flow variables - var result = new[] { body_result }; - - var next_vars = new List(); - foreach (var (m, v) in zip(merge_vars, result)) - next_vars.Add(_AddNextAndBackEdge(m, v)); - - // Add the exit ops. - var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); - _loop_exits = exit_vars; - - // Exit the loop. - // ExitResult(exit_vars); - return (original_body_result, exit_vars.ToArray());*/ } - throw new NotImplementedException(""); + _SetShapeInvariants(real_vars, enter_vars, shape_invariants); + + // Fix the control inputs and control flow context of these enter ops. + _FixControlInputsAndContext(enter_vars); + _InitializeValues(enter_vars); + _loop_enters = enter_vars.ToList(); + + var merge_vars = enter_vars + .Select(x => merge(new[] { x, x })) + .ToArray(); + + _pivot_for_pred = merge_vars[0]; + + // Build the graph for pred. + var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); + //var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays, expand_composites: true); + var packed_vars = new LoopVar((Tensor)merge_vars_with_tensor_arrays[0], + (TItem)(object)new BodyItemInRnnWhileLoop((Tensor)merge_vars_with_tensor_arrays[1], + new[] { (TensorArray)merge_vars_with_tensor_arrays[2] }, + (Tensor)merge_vars_with_tensor_arrays[3])); + var pp = pred(packed_vars); + var c = ops.convert_to_tensor(pp); + _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); + var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) + .ToArray(); + + // Build the graph for body. + var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); + // Convert TensorArray flow variables inside the context back into + // their associated TensorArrays for calling the body. + var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); + var body_result = body(original_loop_vars); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + + // Store body_result to keep track of TensorArrays returned by body + var original_body_result = new[] { body_result }; + // Convert TensorArrays returned by body into their flow variables + var result = new[] { body_result }; + + var next_vars = new List(); + //foreach (var (m, v) in zip(merge_vars, result)) + //next_vars.Add(_AddNextAndBackEdge(m, v)); + + // Add the exit ops. + var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); + _loop_exits = exit_vars; + + // Exit the loop. + // ExitResult(exit_vars); + return (null, exit_vars.ToArray()); } private void _FixControlInputsAndContext(Tensor[] enters) @@ -258,6 +288,23 @@ namespace Tensorflow.Operations _values.Add(x.name); } + public override Tensor AddValue(Tensor val) + { + var result = val; + var new_value = _values.Contains(val.name); + new_value &= val.op._get_control_flow_context() != this; + if (new_value) + throw new NotImplementedException(""); + else + { + var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null; + if (actual_val != null) + result = actual_val as Tensor; + } + + return result; + } + public override WhileContext GetWhileContext() { return this;