diff --git a/README.md b/README.md
index 1a203bb7..30643299 100644
--- a/README.md
+++ b/README.md
@@ -129,7 +129,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow
Run specific example in shell:
```cs
-dotnet TensorFlowNET.Examples.dll "EXAMPLE NAME"
+dotnet TensorFlowNET.Examples.dll -ex "MNIST CNN"
```
Example runner will download all the required files like training data and model pb files.
diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 025d3d57..cb9b0b97 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -79,23 +79,12 @@ namespace Tensorflow
///
///
/// A pair (outputs, state)
- public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid,
- bool swap_memory = false, bool time_major = false)
- {
- with(variable_scope("rnn"), scope =>
- {
- VariableScope varscope = scope;
- var flat_input = nest.flatten(inputs);
-
- if (!time_major)
- {
- flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList();
- //flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList();
- }
- });
-
- throw new NotImplementedException("");
- }
+ public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
+ int? sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid,
+ int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
+ => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype,
+ parallel_iterations: parallel_iterations, swap_memory: swap_memory,
+ time_major: time_major);
public static Tensor elu(Tensor features, string name = null)
=> gen_nn_ops.elu(features, name: name);
diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
index 1bed4773..61086076 100644
--- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
@@ -27,6 +27,8 @@ namespace Tensorflow
int _num_units;
Func _activation;
+ protected override int state_size => _num_units;
+
public BasicRNNCell(int num_units,
Func activation = null,
bool? reuse = null,
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
new file mode 100644
index 00000000..5f94a0fb
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -0,0 +1,117 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System.Linq;
+using static Tensorflow.Python;
+using Tensorflow.Util;
+
+namespace Tensorflow.Operations
+{
+ internal class rnn
+ {
+ public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
+ int? sequence_length = null, Tensor initial_state = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
+ {
+ with(tf.variable_scope("rnn"), scope =>
+ {
+ VariableScope varscope = scope;
+ var flat_input = nest.flatten(inputs);
+
+ if (!time_major)
+ {
+ flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList();
+ flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList();
+ }
+
+ parallel_iterations = parallel_iterations ?? 32;
+
+ if (sequence_length.HasValue)
+ throw new NotImplementedException("dynamic_rnn sequence_length has value");
+
+ var batch_size = _best_effort_input_batch_size(flat_input);
+
+ if (initial_state != null)
+ {
+ var state = initial_state;
+ }
+ else
+ {
+ cell.get_initial_state(batch_size: batch_size, dtype: dtype);
+ }
+ });
+
+ throw new NotImplementedException("");
+ }
+
+ ///
+ /// Transposes the batch and time dimensions of a Tensor.
+ ///
+ ///
+ ///
+ public static Tensor _transpose_batch_time(Tensor x)
+ {
+ var x_static_shape = x.TensorShape;
+ if (x_static_shape.NDim == 1)
+ return x;
+
+ var x_rank = array_ops.rank(x);
+ var con1 = new object[]
+ {
+ new []{1, 0 },
+ math_ops.range(2, x_rank)
+ };
+ var x_t = array_ops.transpose(x, array_ops.concat(con1, 0));
+
+ var dims = new int[] { x_static_shape.Dimensions[1], x_static_shape.Dimensions[0] }
+ .ToList();
+ dims.AddRange(x_static_shape.Dimensions.Skip(2));
+ var shape = new TensorShape(dims.ToArray());
+
+ x_t.SetShape(shape);
+
+ return x_t;
+ }
+
+ ///
+ /// Get static input batch size if available, with fallback to the dynamic one.
+ ///
+ ///
+ ///
+ private static Tensor _best_effort_input_batch_size(List flat_input)
+ {
+ foreach(var input_ in flat_input)
+ {
+ var shape = input_.TensorShape;
+ if (shape.NDim < 0)
+ continue;
+ if (shape.NDim < 2)
+ throw new ValueError($"Expected input tensor {input_.name} to have rank at least 2");
+
+ var batch_size = shape.Dimensions[1];
+ if (batch_size > -1)
+ throw new ValueError("_best_effort_input_batch_size batch_size > -1");
+ //return batch_size;
+ }
+
+ return array_ops.shape(flat_input[0]).slice(1);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/RNNCell.cs b/src/TensorFlowNET.Core/Operations/RNNCell.cs
index cbfe7db8..79da7761 100644
--- a/src/TensorFlowNET.Core/Operations/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/RNNCell.cs
@@ -17,6 +17,8 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Util;
+using static Tensorflow.Python;
namespace Tensorflow
{
@@ -48,6 +50,7 @@ namespace Tensorflow
/// difference between TF and Keras RNN cell.
///
protected bool _is_tf_rnn_cell = false;
+ protected virtual int state_size { get; }
public RNNCell(bool trainable = true,
string name = null,
@@ -59,5 +62,41 @@ namespace Tensorflow
{
_is_tf_rnn_cell = true;
}
+
+ public virtual Tensor get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
+ {
+ if (inputs != null)
+ throw new NotImplementedException("get_initial_state input is not null");
+
+ return zero_state(batch_size, dtype);
+ }
+
+ ///
+ /// Return zero-filled state tensor(s).
+ ///
+ ///
+ ///
+ ///
+ public Tensor zero_state(Tensor batch_size, TF_DataType dtype)
+ {
+ Tensor output = null;
+ var state_size = this.state_size;
+ with(ops.name_scope($"{this.GetType().Name}ZeroState", values: new { batch_size }), delegate
+ {
+ output = _zero_state_tensors(state_size, batch_size, dtype);
+ });
+
+ return output;
+ }
+
+ private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype)
+ {
+ nest.map_structure(x =>
+ {
+ throw new NotImplementedException("");
+ }, state_size);
+
+ throw new NotImplementedException("");
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs
index 47e4e4aa..ae13b31c 100644
--- a/src/TensorFlowNET.Core/Util/nest.py.cs
+++ b/src/TensorFlowNET.Core/Util/nest.py.cs
@@ -512,6 +512,14 @@ namespace Tensorflow.Util
return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList();
}
+ public static Tensor map_structure(Func func, T structure)
+ {
+ var flat_structure = flatten(structure);
+ var mapped_flat_structure = flat_structure.Select(func).ToList();
+
+ return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
+ }
+
///
/// Same as map_structure, but with only one structure (no combining of multiple structures)
///
diff --git a/test/TensorFlowHub.Examples/TensorFlowHub.Examples.csproj b/test/TensorFlowHub.Examples/TensorFlowHub.Examples.csproj
index f9196dcc..1ba60f1e 100644
--- a/test/TensorFlowHub.Examples/TensorFlowHub.Examples.csproj
+++ b/test/TensorFlowHub.Examples/TensorFlowHub.Examples.csproj
@@ -2,7 +2,7 @@
Exe
- netcoreapp3.0
+ netcoreapp2.2
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs
index d7cdfb32..2e9160af 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs
@@ -30,7 +30,7 @@ namespace TensorFlowNET.Examples.ImageProcess
///
public class DigitRecognitionRNN : IExample
{
- public bool Enabled { get; set; } = false;
+ public bool Enabled { get; set; } = true;
public bool IsImportingGraph { get; set; } = false;
public string Name => "MNIST RNN";
@@ -95,7 +95,7 @@ namespace TensorFlowNET.Examples.ImageProcess
var init = tf.global_variables_initializer();
sess.run(init);
- float loss_val = 100.0f;
+ float loss_val = 100.0f;
float accuracy_val = 0f;
foreach (var epoch in range(epochs))
diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs
index 5f011105..0a4c6721 100644
--- a/test/TensorFlowNET.Examples/Program.cs
+++ b/test/TensorFlowNET.Examples/Program.cs
@@ -29,8 +29,12 @@ namespace TensorFlowNET.Examples
{
static void Main(string[] args)
{
+ int finished = 0;
var errors = new List();
var success = new List();
+
+ var parsedArgs = ParseArgs(args);
+
var examples = Assembly.GetEntryAssembly().GetTypes()
.Where(x => x.GetInterfaces().Contains(typeof(IExample)))
.Select(x => (IExample)Activator.CreateInstance(x))
@@ -38,14 +42,23 @@ namespace TensorFlowNET.Examples
.OrderBy(x => x.Name)
.ToArray();
+ if (parsedArgs.ContainsKey("ex"))
+ examples = examples.Where(x => x.Name == parsedArgs["ex"]).ToArray();
+
Console.WriteLine(Environment.OSVersion.ToString(), Color.Yellow);
Console.WriteLine($"TensorFlow Binary v{tf.VERSION}", Color.Yellow);
Console.WriteLine($"TensorFlow.NET v{Assembly.GetAssembly(typeof(TF_DataType)).GetName().Version}", Color.Yellow);
for (var i = 0; i < examples.Length; i++)
Console.WriteLine($"[{i}]: {examples[i].Name}");
- Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow);
- var key = Console.ReadLine();
+
+ var key = "0";
+
+ if (examples.Length > 1)
+ {
+ Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow);
+ key = Console.ReadLine();
+ }
var sw = new Stopwatch();
for (var i = 0; i < examples.Length; i++)
@@ -72,14 +85,35 @@ namespace TensorFlowNET.Examples
Console.WriteLine(ex);
}
+ finished++;
Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White);
}
success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green));
errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red));
- Console.WriteLine($"{examples.Length} examples are completed.");
+ Console.WriteLine($"{finished} of {examples.Length} example(s) are completed.");
Console.ReadLine();
}
+
+ private static Dictionary ParseArgs(string[] args)
+ {
+ var parsed = new Dictionary();
+
+ for (int i = 0; i < args.Length; i++)
+ {
+ string key = args[i].Substring(1);
+ switch (key)
+ {
+ case "ex":
+ parsed.Add(key, args[++i]);
+ break;
+ default:
+ break;
+ }
+ }
+
+ return parsed;
+ }
}
}
diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
index c9b02f0f..b42c70b3 100644
--- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
+++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
@@ -6,12 +6,6 @@
false
-
-
-
-
-
-