| @@ -129,7 +129,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow | |||
| Run specific example in shell: | |||
| ```cs | |||
| dotnet TensorFlowNET.Examples.dll "EXAMPLE NAME" | |||
| dotnet TensorFlowNET.Examples.dll -ex "MNIST CNN" | |||
| ``` | |||
| Example runner will download all the required files like training data and model pb files. | |||
| @@ -79,23 +79,12 @@ namespace Tensorflow | |||
| /// <param name="swap_memory"></param> | |||
| /// <param name="time_major"></param> | |||
| /// <returns>A pair (outputs, state)</returns> | |||
| public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid, | |||
| bool swap_memory = false, bool time_major = false) | |||
| { | |||
| with(variable_scope("rnn"), scope => | |||
| { | |||
| VariableScope varscope = scope; | |||
| var flat_input = nest.flatten(inputs); | |||
| if (!time_major) | |||
| { | |||
| flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList(); | |||
| //flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList(); | |||
| } | |||
| }); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, | |||
| int? sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, | |||
| int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | |||
| => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, | |||
| parallel_iterations: parallel_iterations, swap_memory: swap_memory, | |||
| time_major: time_major); | |||
| public static Tensor elu(Tensor features, string name = null) | |||
| => gen_nn_ops.elu(features, name: name); | |||
| @@ -27,6 +27,8 @@ namespace Tensorflow | |||
| int _num_units; | |||
| Func<Tensor, string, Tensor> _activation; | |||
| protected override int state_size => _num_units; | |||
| public BasicRNNCell(int num_units, | |||
| Func<Tensor, string, Tensor> activation = null, | |||
| bool? reuse = null, | |||
| @@ -0,0 +1,117 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using System.Linq; | |||
| using static Tensorflow.Python; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| internal class rnn | |||
| { | |||
| public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, | |||
| int? sequence_length = null, Tensor initial_state = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | |||
| { | |||
| with(tf.variable_scope("rnn"), scope => | |||
| { | |||
| VariableScope varscope = scope; | |||
| var flat_input = nest.flatten(inputs); | |||
| if (!time_major) | |||
| { | |||
| flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList(); | |||
| flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList(); | |||
| } | |||
| parallel_iterations = parallel_iterations ?? 32; | |||
| if (sequence_length.HasValue) | |||
| throw new NotImplementedException("dynamic_rnn sequence_length has value"); | |||
| var batch_size = _best_effort_input_batch_size(flat_input); | |||
| if (initial_state != null) | |||
| { | |||
| var state = initial_state; | |||
| } | |||
| else | |||
| { | |||
| cell.get_initial_state(batch_size: batch_size, dtype: dtype); | |||
| } | |||
| }); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| /// <summary> | |||
| /// Transposes the batch and time dimensions of a Tensor. | |||
| /// </summary> | |||
| /// <param name="x"></param> | |||
| /// <returns></returns> | |||
| public static Tensor _transpose_batch_time(Tensor x) | |||
| { | |||
| var x_static_shape = x.TensorShape; | |||
| if (x_static_shape.NDim == 1) | |||
| return x; | |||
| var x_rank = array_ops.rank(x); | |||
| var con1 = new object[] | |||
| { | |||
| new []{1, 0 }, | |||
| math_ops.range(2, x_rank) | |||
| }; | |||
| var x_t = array_ops.transpose(x, array_ops.concat(con1, 0)); | |||
| var dims = new int[] { x_static_shape.Dimensions[1], x_static_shape.Dimensions[0] } | |||
| .ToList(); | |||
| dims.AddRange(x_static_shape.Dimensions.Skip(2)); | |||
| var shape = new TensorShape(dims.ToArray()); | |||
| x_t.SetShape(shape); | |||
| return x_t; | |||
| } | |||
| /// <summary> | |||
| /// Get static input batch size if available, with fallback to the dynamic one. | |||
| /// </summary> | |||
| /// <param name="flat_input"></param> | |||
| /// <returns></returns> | |||
| private static Tensor _best_effort_input_batch_size(List<Tensor> flat_input) | |||
| { | |||
| foreach(var input_ in flat_input) | |||
| { | |||
| var shape = input_.TensorShape; | |||
| if (shape.NDim < 0) | |||
| continue; | |||
| if (shape.NDim < 2) | |||
| throw new ValueError($"Expected input tensor {input_.name} to have rank at least 2"); | |||
| var batch_size = shape.Dimensions[1]; | |||
| if (batch_size > -1) | |||
| throw new ValueError("_best_effort_input_batch_size batch_size > -1"); | |||
| //return batch_size; | |||
| } | |||
| return array_ops.shape(flat_input[0]).slice(1); | |||
| } | |||
| } | |||
| } | |||
| @@ -17,6 +17,8 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Python; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -48,6 +50,7 @@ namespace Tensorflow | |||
| /// difference between TF and Keras RNN cell. | |||
| /// </summary> | |||
| protected bool _is_tf_rnn_cell = false; | |||
| protected virtual int state_size { get; } | |||
| public RNNCell(bool trainable = true, | |||
| string name = null, | |||
| @@ -59,5 +62,41 @@ namespace Tensorflow | |||
| { | |||
| _is_tf_rnn_cell = true; | |||
| } | |||
| public virtual Tensor get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| if (inputs != null) | |||
| throw new NotImplementedException("get_initial_state input is not null"); | |||
| return zero_state(batch_size, dtype); | |||
| } | |||
| /// <summary> | |||
| /// Return zero-filled state tensor(s). | |||
| /// </summary> | |||
| /// <param name="batch_size"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <returns></returns> | |||
| public Tensor zero_state(Tensor batch_size, TF_DataType dtype) | |||
| { | |||
| Tensor output = null; | |||
| var state_size = this.state_size; | |||
| with(ops.name_scope($"{this.GetType().Name}ZeroState", values: new { batch_size }), delegate | |||
| { | |||
| output = _zero_state_tensors(state_size, batch_size, dtype); | |||
| }); | |||
| return output; | |||
| } | |||
| private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype) | |||
| { | |||
| nest.map_structure(x => | |||
| { | |||
| throw new NotImplementedException(""); | |||
| }, state_size); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| } | |||
| @@ -512,6 +512,14 @@ namespace Tensorflow.Util | |||
| return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList(); | |||
| } | |||
| public static Tensor map_structure<T>(Func<T, Tensor> func, T structure) | |||
| { | |||
| var flat_structure = flatten(structure); | |||
| var mapped_flat_structure = flat_structure.Select(func).ToList(); | |||
| return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | |||
| } | |||
| /// <summary> | |||
| /// Same as map_structure, but with only one structure (no combining of multiple structures) | |||
| /// </summary> | |||
| @@ -2,7 +2,7 @@ | |||
| <PropertyGroup> | |||
| <OutputType>Exe</OutputType> | |||
| <TargetFramework>netcoreapp3.0</TargetFramework> | |||
| <TargetFramework>netcoreapp2.2</TargetFramework> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| @@ -30,7 +30,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||
| /// </summary> | |||
| public class DigitRecognitionRNN : IExample | |||
| { | |||
| public bool Enabled { get; set; } = false; | |||
| public bool Enabled { get; set; } = true; | |||
| public bool IsImportingGraph { get; set; } = false; | |||
| public string Name => "MNIST RNN"; | |||
| @@ -95,7 +95,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||
| var init = tf.global_variables_initializer(); | |||
| sess.run(init); | |||
| float loss_val = 100.0f; | |||
| float loss_val = 100.0f; | |||
| float accuracy_val = 0f; | |||
| foreach (var epoch in range(epochs)) | |||
| @@ -29,8 +29,12 @@ namespace TensorFlowNET.Examples | |||
| { | |||
| static void Main(string[] args) | |||
| { | |||
| int finished = 0; | |||
| var errors = new List<string>(); | |||
| var success = new List<string>(); | |||
| var parsedArgs = ParseArgs(args); | |||
| var examples = Assembly.GetEntryAssembly().GetTypes() | |||
| .Where(x => x.GetInterfaces().Contains(typeof(IExample))) | |||
| .Select(x => (IExample)Activator.CreateInstance(x)) | |||
| @@ -38,14 +42,23 @@ namespace TensorFlowNET.Examples | |||
| .OrderBy(x => x.Name) | |||
| .ToArray(); | |||
| if (parsedArgs.ContainsKey("ex")) | |||
| examples = examples.Where(x => x.Name == parsedArgs["ex"]).ToArray(); | |||
| Console.WriteLine(Environment.OSVersion.ToString(), Color.Yellow); | |||
| Console.WriteLine($"TensorFlow Binary v{tf.VERSION}", Color.Yellow); | |||
| Console.WriteLine($"TensorFlow.NET v{Assembly.GetAssembly(typeof(TF_DataType)).GetName().Version}", Color.Yellow); | |||
| for (var i = 0; i < examples.Length; i++) | |||
| Console.WriteLine($"[{i}]: {examples[i].Name}"); | |||
| Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow); | |||
| var key = Console.ReadLine(); | |||
| var key = "0"; | |||
| if (examples.Length > 1) | |||
| { | |||
| Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow); | |||
| key = Console.ReadLine(); | |||
| } | |||
| var sw = new Stopwatch(); | |||
| for (var i = 0; i < examples.Length; i++) | |||
| @@ -72,14 +85,35 @@ namespace TensorFlowNET.Examples | |||
| Console.WriteLine(ex); | |||
| } | |||
| finished++; | |||
| Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); | |||
| } | |||
| success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green)); | |||
| errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red)); | |||
| Console.WriteLine($"{examples.Length} examples are completed."); | |||
| Console.WriteLine($"{finished} of {examples.Length} example(s) are completed."); | |||
| Console.ReadLine(); | |||
| } | |||
| private static Dictionary<string, string> ParseArgs(string[] args) | |||
| { | |||
| var parsed = new Dictionary<string, string>(); | |||
| for (int i = 0; i < args.Length; i++) | |||
| { | |||
| string key = args[i].Substring(1); | |||
| switch (key) | |||
| { | |||
| case "ex": | |||
| parsed.Add(key, args[++i]); | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| } | |||
| return parsed; | |||
| } | |||
| } | |||
| } | |||
| @@ -6,12 +6,6 @@ | |||
| <GeneratePackageOnBuild>false</GeneratePackageOnBuild> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <Compile Remove="python\**" /> | |||
| <EmbeddedResource Remove="python\**" /> | |||
| <None Remove="python\**" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="Colorful.Console" Version="1.2.9" /> | |||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | |||