diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs index fa2fe9d3..c313739b 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs @@ -4,22 +4,25 @@ using System.Text; namespace Tensorflow.Operations { - internal class LoopVar + internal class LoopVar : ICanBeFlattened { public Tensor Counter { get; } - public TItem[] Items { get; } public TItem Item { get; } - public LoopVar(Tensor counter, TItem[] items) + public LoopVar(Tensor counter, TItem item) { Counter = counter; - Items = items; + Item = item; } - public LoopVar(Tensor counter, TItem item) + public object[] Flatten() { - Counter = counter; - Item = item; + var elements = new List { Counter }; + if (typeof(TItem).GetInterface(typeof(ICanBeFlattened).Name) != null) + elements.AddRange((Item as ICanBeFlattened).Flatten()); + else + elements.Add(Item); + return elements.ToArray(); } } } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index c00fc2c7..462aca25 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -109,7 +109,7 @@ namespace Tensorflow.Operations /// internal Tensor[] BuildLoop(Func pred, Func> body, - TItem loop_vars, + LoopVar loop_vars, TensorShape shape_invariants, bool return_same_structure) { @@ -143,8 +143,8 @@ namespace Tensorflow.Operations private (Tensor[], Tensor[]) _BuildLoop(Func pred, Func> body, - TItem original_loop_vars, - TItem loop_vars, + LoopVar original_loop_vars, + LoopVar loop_vars, TensorShape shape_invariants) { var flat_loop_vars = original_loop_vars; @@ -152,7 +152,7 @@ namespace Tensorflow.Operations // Convert TensorArrays to their flow variables var loop_vars_tensor = nest.map_structure( _convert_tensorarray_to_flow, - nest.flatten(loop_vars)); + nest.flatten2(loop_vars)); // Let the context know the loop variables so the loop variables // would be added in the outer contexts properly. diff --git a/src/TensorFlowNET.Core/Operations/IFlatten.cs b/src/TensorFlowNET.Core/Operations/IFlatten.cs new file mode 100644 index 00000000..305dc72e --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/IFlatten.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + public interface ICanBeFlattened + { + object[] Flatten(); + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs index f0086793..9ffea25c 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow.Operations { - internal class BodyItemInRnnWhileLoop + internal class BodyItemInRnnWhileLoop : ICanBeFlattened { /// /// int32 scalar Tensor. @@ -28,5 +28,13 @@ namespace Tensorflow.Operations public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item) => (item.time, item.output_ta_t, item.state); + + public object[] Flatten() + { + var elements = new List { time }; + elements.AddRange(output_ta_t); + elements.Add(state); + return elements.ToArray(); + } } } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 27e43153..181b7e71 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -642,7 +642,7 @@ namespace Tensorflow if (loop_context.outer_context == null) ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); - var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants, + var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants, return_same_structure); if (maximum_iterations != null) diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 598969b7..33bba3dc 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.11.8.1 + 0.12.0 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true @@ -16,25 +16,13 @@ https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. -Docs: https://tensorflownet.readthedocs.io - 0.11.8.1 - Changes since v0.10.0: -1. Upgrade NumSharp to v0.20.3. -2. Add DisposableObject class to manage object lifetime. -3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. -4. Change tensorflow to non-static class in order to execute some initialization process. -5. Overload session.run(), make syntax simpler. -6. Add Local Response Normalization. -7. Add tf.image related APIs. -8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor. -9. MultiThread is safe. -10. Support n-dim indexing for tensor. -11. Add RegisterNoGradients -12. Add CumsumGrad, BroadcastToGrad. -13. Return VariableV1 instead of RefVariable. -14. Add Tensor overload to GradientDescentOptimizer. +Building, training and infering deep learning models. +https://tensorflownet.readthedocs.io + 0.12.0.0 + Changes since v0.11.0: + 7.3 - 0.11.8.1 + 0.12.0.0 LICENSE true true diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 54ff358a..9b0af4f6 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -19,6 +19,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; using NumSharp; +using Tensorflow.Operations; namespace Tensorflow.Util { @@ -221,6 +222,11 @@ namespace Tensorflow.Util return list; } + public static object[] flatten2(ICanBeFlattened structure) + { + return structure.Flatten(); + } + private static void _flatten_recursive(T obj, List list) { switch(obj)