From a65d881213b7ab11a3223fef3996a10005ec3627 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 24 Oct 2019 13:45:02 -0500 Subject: [PATCH] CheckInputFromValidContext --- .../Operations/NnOps/rnn.cs | 9 ++-- .../Operations/Operation.Control.cs | 5 +- .../Operations/Operation.cs | 1 + .../Operations/_GraphTensorArray.cs | 10 ++-- .../Operations/control_flow_ops.cs | 52 +++++++++++++------ .../Operations/control_flow_util.py.cs | 22 ++++++++ .../Operations/gen_math_ops.cs | 2 + .../Operations/tensor_array_ops.cs | 33 ++++++++++++ .../TensorFlowNET.Core.csproj | 3 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 4 +- src/TensorFlowNET.Core/Util/nest.py.cs | 9 ++-- 11 files changed, 117 insertions(+), 33 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/tensor_array_ops.cs diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index e058c077..475dd0ff 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -172,7 +172,8 @@ namespace Tensorflow.Operations for (int i = 0; i < input_ta.Count; i++) { - var (ta, input_) = (input_ta[0], flat_input[0]); + var (ta, input_) = (input_ta[i], flat_input[i]); + ta.unstack(input_); } } @@ -185,16 +186,16 @@ namespace Tensorflow.Operations Func cond = (item) => { - return time < loop_bound; + return item.time < loop_bound; }; // Take a time step of the dynamic RNN. Func _time_step = (item) => { - return item; + throw new NotImplementedException(""); }; - control_flow_ops.while_loop( + control_flow_ops.while_loop( cond: cond, body: _time_step, loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 2f61f954..5e93cfd0 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -30,10 +30,9 @@ namespace Tensorflow /// public void _control_flow_post_processing() { - foreach(var input_tensor in inputs) + foreach(Tensor input_tensor in inputs) { - //TODO: implement below code dependency - //control_flow_util.CheckInputFromValidContext(this, input_tensor.op); + control_flow_util.CheckInputFromValidContext(this, input_tensor.op); } if (_control_flow_context != null) diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 3b40c95a..db001e51 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -23,6 +23,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using Tensorflow.Util; +using static Tensorflow.Binding; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index 5a667560..ebc88230 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -25,6 +25,7 @@ namespace Tensorflow.Operations internal class _GraphTensorArray { internal TF_DataType _dtype; + public TF_DataType dtype => _dtype; /// /// Used to keep track of what tensors the TensorArray should be @@ -32,14 +33,17 @@ namespace Tensorflow.Operations /// first tensor written to it. /// bool _colocate_with_first_write_call; + public bool colocate_with_first_write_call => _colocate_with_first_write_call; bool _infer_shape; - bool _dynamic_size; - List _element_shape; + public bool infer_shape => _infer_shape; + public bool _dynamic_size; + public List _element_shape; - List _colocate_with; + public List _colocate_with; internal Tensor _handle; + public Tensor handle => _handle; internal Tensor _flow; public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 181b7e71..6c286fc1 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -21,6 +21,7 @@ using Tensorflow.Operations; using Tensorflow.Operations.ControlFlows; using util = Tensorflow.control_flow_util; using static Tensorflow.Binding; +using Tensorflow.Util; namespace Tensorflow { @@ -251,12 +252,16 @@ namespace Tensorflow return gen_array_ops.identity(data, name: name); } - public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null) + public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape[] shapes = null) { if (shapes == null) return; - throw new NotImplementedException("_SetShapeInvariants"); + var flat_shapes = nest.flatten2(shapes); + foreach (var (inp, var, shape) in zip(input_vars, enter_vars, flat_shapes)) + { + var.set_shape(shape); + } } /// @@ -428,12 +433,12 @@ namespace Tensorflow .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) .ToArray(); - merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); + var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); - return merges[0]; + return new Tensor(IntPtr.Zero); }); } @@ -473,22 +478,28 @@ namespace Tensorflow var res_f_flat = res_f; var merges = zip(res_f_flat, res_t_flat) - .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) + .Select(pair => merge(new [] { pair.Item1, pair.Item2 })) .ToArray(); - merges = _convert_flows_to_tensorarrays(orig_res_t, merges); + var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); - return merges; + return new[] { new Tensor(IntPtr.Zero) }; }); } - public static Tensor[] _convert_flows_to_tensorarrays(T tensors_or_tensorarrays, Tensor[] tensors_or_flows) + public static ITensorOrTensorArray[] _convert_flows_to_tensorarrays(ITensorOrTensorArray[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) { - // zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray(); - return tensors_or_flows; + return zip(tensors_or_tensorarrays, tensors_or_flows).Select(x => + { + var (ta, t_or_flow) = (x.Item1, x.Item2); + if (ta is TensorArray ta_1) + return tensor_array_ops.build_ta_with_new_flow(ta_1, t_or_flow) as ITensorOrTensorArray; + else + return t_or_flow as ITensorOrTensorArray; + }).ToArray(); } /// @@ -592,7 +603,7 @@ namespace Tensorflow /// /// public static Tensor while_loop(Func cond, Func body, TItem loop_vars, - TensorShape shape_invariants = null, + TensorShape[] shape_invariants = null, int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, @@ -617,8 +628,8 @@ namespace Tensorflow var orig_body = body; LoopVar loop_vars_1 = null; - Func> body_buildloop = null; - Func cond_buildloop = null; + Func, LoopVar> body_buildloop = null; + Func, Tensor> cond_buildloop = null; if (try_to_pack) { @@ -627,9 +638,18 @@ namespace Tensorflow else { loop_vars_1 = new LoopVar(counter, loop_vars); - cond_buildloop = (i, lv) => - math_ops.logical_and(i < maximum_iterations, orig_cond(lv)); - body_buildloop = (i, lv) => new LoopVar(i + 1, orig_body(lv)); + cond_buildloop = (item) => + { + var (i, lv) = (item.Counter, item.Item); + var oc = orig_cond(lv); + return math_ops.logical_and(i < maximum_iterations, oc); + }; + + body_buildloop = (item) => + { + var (i, lv) = (item.Counter, item.Item); + return new LoopVar(i + 1, orig_body(lv)); + }; } try_to_pack = false; diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs index 4ae03e42..5377eb5b 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs @@ -14,7 +14,9 @@ limitations under the License. ******************************************************************************/ +using System; using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow { @@ -53,5 +55,25 @@ namespace Tensorflow ctxt = ctxt.outer_context; return ctxt; } + + public static void CheckInputFromValidContext(Operation op, Operation input_op) + { + var op_ctxt = op._get_control_flow_context(); + var input_ctxt = GetOutputContext(input_op); + var valid = false; + if (input_ctxt == null) + valid = true; + else if (op_ctxt == input_ctxt) + valid = true; + else + { + throw new NotImplementedException(""); + } + + if (!valid) + { + throw new NotImplementedException(""); + } + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 7192dc57..e1225cc9 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using static Tensorflow.Binding; + namespace Tensorflow { public static class gen_math_ops diff --git a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs new file mode 100644 index 00000000..8ce3b5c7 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class tensor_array_ops + { + /// + /// Builds a TensorArray with a new `flow` tensor. + /// + /// + /// + /// + public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) + { + var impl = old_ta._implementation; + + var new_ta = new TensorArray( + dtype: impl.dtype, + handle: impl.handle, + flow: flow, + infer_shape: impl.infer_shape, + colocate_with_first_write_call: impl.colocate_with_first_write_call); + + var new_impl = new_ta._implementation; + new_impl._dynamic_size = impl._dynamic_size; + new_impl._colocate_with = impl._colocate_with; + new_impl._element_shape = impl._element_shape; + return new_ta; + } + } +} diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 33bba3dc..fbad178e 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -20,7 +20,8 @@ Building, training and infering deep learning models. https://tensorflownet.readthedocs.io 0.12.0.0 Changes since v0.11.0: - +1: Add ICanBeFlattened for nest.flatten2. +2: 7.3 0.12.0.0 LICENSE diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 161696a1..943edaaf 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -39,7 +39,7 @@ namespace Tensorflow /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// [SuppressMessage("ReSharper", "ConvertToAutoProperty")] - public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike + public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike, ITensorOrTensorArray { private readonly int _id; private readonly Operation _op; @@ -178,7 +178,7 @@ namespace Tensorflow /// public void set_shape(TensorShape shape) { - this.shape = shape.rank > 0 ? shape.dims : null; + this.shape = shape.rank >= 0 ? shape.dims : null; } /// diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 9b0af4f6..28f9ba03 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -223,9 +223,10 @@ namespace Tensorflow.Util } public static object[] flatten2(ICanBeFlattened structure) - { - return structure.Flatten(); - } + => structure.Flatten(); + + public static T[] flatten2(T[] structure) + => structure; private static void _flatten_recursive(T obj, List list) { @@ -423,7 +424,7 @@ namespace Tensorflow.Util /// `flat_sequence` converted to have the same recursive structure as /// `structure`. /// - public static object pack_sequence_as(object structure, IEnumerable flat_sequence) + public static object pack_sequence_as(object structure, IEnumerable flat_sequence, bool expand_composites = false) { List flat = null; if (flat_sequence is List)