diff --git a/README.md b/README.md index 1a203bb7..30643299 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow Run specific example in shell: ```cs -dotnet TensorFlowNET.Examples.dll "EXAMPLE NAME" +dotnet TensorFlowNET.Examples.dll -ex "MNIST CNN" ``` Example runner will download all the required files like training data and model pb files. diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 025d3d57..cb9b0b97 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -79,23 +79,12 @@ namespace Tensorflow /// /// /// A pair (outputs, state) - public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid, - bool swap_memory = false, bool time_major = false) - { - with(variable_scope("rnn"), scope => - { - VariableScope varscope = scope; - var flat_input = nest.flatten(inputs); - - if (!time_major) - { - flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList(); - //flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList(); - } - }); - - throw new NotImplementedException(""); - } + public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, + int? sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, + int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) + => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, + parallel_iterations: parallel_iterations, swap_memory: swap_memory, + time_major: time_major); public static Tensor elu(Tensor features, string name = null) => gen_nn_ops.elu(features, name: name); diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs index 1bed4773..61086076 100644 --- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs @@ -27,6 +27,8 @@ namespace Tensorflow int _num_units; Func _activation; + protected override int state_size => _num_units; + public BasicRNNCell(int num_units, Func activation = null, bool? reuse = null, diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs new file mode 100644 index 00000000..5f94a0fb --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -0,0 +1,117 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq; +using static Tensorflow.Python; +using Tensorflow.Util; + +namespace Tensorflow.Operations +{ + internal class rnn + { + public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, + int? sequence_length = null, Tensor initial_state = null, + TF_DataType dtype = TF_DataType.DtInvalid, + int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) + { + with(tf.variable_scope("rnn"), scope => + { + VariableScope varscope = scope; + var flat_input = nest.flatten(inputs); + + if (!time_major) + { + flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList(); + flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList(); + } + + parallel_iterations = parallel_iterations ?? 32; + + if (sequence_length.HasValue) + throw new NotImplementedException("dynamic_rnn sequence_length has value"); + + var batch_size = _best_effort_input_batch_size(flat_input); + + if (initial_state != null) + { + var state = initial_state; + } + else + { + cell.get_initial_state(batch_size: batch_size, dtype: dtype); + } + }); + + throw new NotImplementedException(""); + } + + /// + /// Transposes the batch and time dimensions of a Tensor. + /// + /// + /// + public static Tensor _transpose_batch_time(Tensor x) + { + var x_static_shape = x.TensorShape; + if (x_static_shape.NDim == 1) + return x; + + var x_rank = array_ops.rank(x); + var con1 = new object[] + { + new []{1, 0 }, + math_ops.range(2, x_rank) + }; + var x_t = array_ops.transpose(x, array_ops.concat(con1, 0)); + + var dims = new int[] { x_static_shape.Dimensions[1], x_static_shape.Dimensions[0] } + .ToList(); + dims.AddRange(x_static_shape.Dimensions.Skip(2)); + var shape = new TensorShape(dims.ToArray()); + + x_t.SetShape(shape); + + return x_t; + } + + /// + /// Get static input batch size if available, with fallback to the dynamic one. + /// + /// + /// + private static Tensor _best_effort_input_batch_size(List flat_input) + { + foreach(var input_ in flat_input) + { + var shape = input_.TensorShape; + if (shape.NDim < 0) + continue; + if (shape.NDim < 2) + throw new ValueError($"Expected input tensor {input_.name} to have rank at least 2"); + + var batch_size = shape.Dimensions[1]; + if (batch_size > -1) + throw new ValueError("_best_effort_input_batch_size batch_size > -1"); + //return batch_size; + } + + return array_ops.shape(flat_input[0]).slice(1); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/RNNCell.cs b/src/TensorFlowNET.Core/Operations/RNNCell.cs index cbfe7db8..79da7761 100644 --- a/src/TensorFlowNET.Core/Operations/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/RNNCell.cs @@ -17,6 +17,8 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Util; +using static Tensorflow.Python; namespace Tensorflow { @@ -48,6 +50,7 @@ namespace Tensorflow /// difference between TF and Keras RNN cell. /// protected bool _is_tf_rnn_cell = false; + protected virtual int state_size { get; } public RNNCell(bool trainable = true, string name = null, @@ -59,5 +62,41 @@ namespace Tensorflow { _is_tf_rnn_cell = true; } + + public virtual Tensor get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) + { + if (inputs != null) + throw new NotImplementedException("get_initial_state input is not null"); + + return zero_state(batch_size, dtype); + } + + /// + /// Return zero-filled state tensor(s). + /// + /// + /// + /// + public Tensor zero_state(Tensor batch_size, TF_DataType dtype) + { + Tensor output = null; + var state_size = this.state_size; + with(ops.name_scope($"{this.GetType().Name}ZeroState", values: new { batch_size }), delegate + { + output = _zero_state_tensors(state_size, batch_size, dtype); + }); + + return output; + } + + private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype) + { + nest.map_structure(x => + { + throw new NotImplementedException(""); + }, state_size); + + throw new NotImplementedException(""); + } } } diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 47e4e4aa..ae13b31c 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -512,6 +512,14 @@ namespace Tensorflow.Util return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList(); } + public static Tensor map_structure(Func func, T structure) + { + var flat_structure = flatten(structure); + var mapped_flat_structure = flat_structure.Select(func).ToList(); + + return pack_sequence_as(structure, mapped_flat_structure) as Tensor; + } + /// /// Same as map_structure, but with only one structure (no combining of multiple structures) /// diff --git a/test/TensorFlowHub.Examples/TensorFlowHub.Examples.csproj b/test/TensorFlowHub.Examples/TensorFlowHub.Examples.csproj index f9196dcc..1ba60f1e 100644 --- a/test/TensorFlowHub.Examples/TensorFlowHub.Examples.csproj +++ b/test/TensorFlowHub.Examples/TensorFlowHub.Examples.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp3.0 + netcoreapp2.2 diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs index d7cdfb32..2e9160af 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs @@ -30,7 +30,7 @@ namespace TensorFlowNET.Examples.ImageProcess /// public class DigitRecognitionRNN : IExample { - public bool Enabled { get; set; } = false; + public bool Enabled { get; set; } = true; public bool IsImportingGraph { get; set; } = false; public string Name => "MNIST RNN"; @@ -95,7 +95,7 @@ namespace TensorFlowNET.Examples.ImageProcess var init = tf.global_variables_initializer(); sess.run(init); - float loss_val = 100.0f; + float loss_val = 100.0f; float accuracy_val = 0f; foreach (var epoch in range(epochs)) diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs index 5f011105..0a4c6721 100644 --- a/test/TensorFlowNET.Examples/Program.cs +++ b/test/TensorFlowNET.Examples/Program.cs @@ -29,8 +29,12 @@ namespace TensorFlowNET.Examples { static void Main(string[] args) { + int finished = 0; var errors = new List(); var success = new List(); + + var parsedArgs = ParseArgs(args); + var examples = Assembly.GetEntryAssembly().GetTypes() .Where(x => x.GetInterfaces().Contains(typeof(IExample))) .Select(x => (IExample)Activator.CreateInstance(x)) @@ -38,14 +42,23 @@ namespace TensorFlowNET.Examples .OrderBy(x => x.Name) .ToArray(); + if (parsedArgs.ContainsKey("ex")) + examples = examples.Where(x => x.Name == parsedArgs["ex"]).ToArray(); + Console.WriteLine(Environment.OSVersion.ToString(), Color.Yellow); Console.WriteLine($"TensorFlow Binary v{tf.VERSION}", Color.Yellow); Console.WriteLine($"TensorFlow.NET v{Assembly.GetAssembly(typeof(TF_DataType)).GetName().Version}", Color.Yellow); for (var i = 0; i < examples.Length; i++) Console.WriteLine($"[{i}]: {examples[i].Name}"); - Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow); - var key = Console.ReadLine(); + + var key = "0"; + + if (examples.Length > 1) + { + Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow); + key = Console.ReadLine(); + } var sw = new Stopwatch(); for (var i = 0; i < examples.Length; i++) @@ -72,14 +85,35 @@ namespace TensorFlowNET.Examples Console.WriteLine(ex); } + finished++; Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); } success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green)); errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red)); - Console.WriteLine($"{examples.Length} examples are completed."); + Console.WriteLine($"{finished} of {examples.Length} example(s) are completed."); Console.ReadLine(); } + + private static Dictionary ParseArgs(string[] args) + { + var parsed = new Dictionary(); + + for (int i = 0; i < args.Length; i++) + { + string key = args[i].Substring(1); + switch (key) + { + case "ex": + parsed.Add(key, args[++i]); + break; + default: + break; + } + } + + return parsed; + } } } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index c9b02f0f..b42c70b3 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -6,12 +6,6 @@ false - - - - - -