| @@ -129,7 +129,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow | |||||
| Run specific example in shell: | Run specific example in shell: | ||||
| ```cs | ```cs | ||||
| dotnet TensorFlowNET.Examples.dll "EXAMPLE NAME" | |||||
| dotnet TensorFlowNET.Examples.dll -ex "MNIST CNN" | |||||
| ``` | ``` | ||||
| Example runner will download all the required files like training data and model pb files. | Example runner will download all the required files like training data and model pb files. | ||||
| @@ -79,23 +79,12 @@ namespace Tensorflow | |||||
| /// <param name="swap_memory"></param> | /// <param name="swap_memory"></param> | ||||
| /// <param name="time_major"></param> | /// <param name="time_major"></param> | ||||
| /// <returns>A pair (outputs, state)</returns> | /// <returns>A pair (outputs, state)</returns> | ||||
| public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| bool swap_memory = false, bool time_major = false) | |||||
| { | |||||
| with(variable_scope("rnn"), scope => | |||||
| { | |||||
| VariableScope varscope = scope; | |||||
| var flat_input = nest.flatten(inputs); | |||||
| if (!time_major) | |||||
| { | |||||
| flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList(); | |||||
| //flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList(); | |||||
| } | |||||
| }); | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, | |||||
| int? sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | |||||
| => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, | |||||
| parallel_iterations: parallel_iterations, swap_memory: swap_memory, | |||||
| time_major: time_major); | |||||
| public static Tensor elu(Tensor features, string name = null) | public static Tensor elu(Tensor features, string name = null) | ||||
| => gen_nn_ops.elu(features, name: name); | => gen_nn_ops.elu(features, name: name); | ||||
| @@ -27,6 +27,8 @@ namespace Tensorflow | |||||
| int _num_units; | int _num_units; | ||||
| Func<Tensor, string, Tensor> _activation; | Func<Tensor, string, Tensor> _activation; | ||||
| protected override int state_size => _num_units; | |||||
| public BasicRNNCell(int num_units, | public BasicRNNCell(int num_units, | ||||
| Func<Tensor, string, Tensor> activation = null, | Func<Tensor, string, Tensor> activation = null, | ||||
| bool? reuse = null, | bool? reuse = null, | ||||
| @@ -0,0 +1,117 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using System.Linq; | |||||
| using static Tensorflow.Python; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow.Operations | |||||
| { | |||||
| internal class rnn | |||||
| { | |||||
| public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, | |||||
| int? sequence_length = null, Tensor initial_state = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | |||||
| { | |||||
| with(tf.variable_scope("rnn"), scope => | |||||
| { | |||||
| VariableScope varscope = scope; | |||||
| var flat_input = nest.flatten(inputs); | |||||
| if (!time_major) | |||||
| { | |||||
| flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList(); | |||||
| flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList(); | |||||
| } | |||||
| parallel_iterations = parallel_iterations ?? 32; | |||||
| if (sequence_length.HasValue) | |||||
| throw new NotImplementedException("dynamic_rnn sequence_length has value"); | |||||
| var batch_size = _best_effort_input_batch_size(flat_input); | |||||
| if (initial_state != null) | |||||
| { | |||||
| var state = initial_state; | |||||
| } | |||||
| else | |||||
| { | |||||
| cell.get_initial_state(batch_size: batch_size, dtype: dtype); | |||||
| } | |||||
| }); | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| /// <summary> | |||||
| /// Transposes the batch and time dimensions of a Tensor. | |||||
| /// </summary> | |||||
| /// <param name="x"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor _transpose_batch_time(Tensor x) | |||||
| { | |||||
| var x_static_shape = x.TensorShape; | |||||
| if (x_static_shape.NDim == 1) | |||||
| return x; | |||||
| var x_rank = array_ops.rank(x); | |||||
| var con1 = new object[] | |||||
| { | |||||
| new []{1, 0 }, | |||||
| math_ops.range(2, x_rank) | |||||
| }; | |||||
| var x_t = array_ops.transpose(x, array_ops.concat(con1, 0)); | |||||
| var dims = new int[] { x_static_shape.Dimensions[1], x_static_shape.Dimensions[0] } | |||||
| .ToList(); | |||||
| dims.AddRange(x_static_shape.Dimensions.Skip(2)); | |||||
| var shape = new TensorShape(dims.ToArray()); | |||||
| x_t.SetShape(shape); | |||||
| return x_t; | |||||
| } | |||||
| /// <summary> | |||||
| /// Get static input batch size if available, with fallback to the dynamic one. | |||||
| /// </summary> | |||||
| /// <param name="flat_input"></param> | |||||
| /// <returns></returns> | |||||
| private static Tensor _best_effort_input_batch_size(List<Tensor> flat_input) | |||||
| { | |||||
| foreach(var input_ in flat_input) | |||||
| { | |||||
| var shape = input_.TensorShape; | |||||
| if (shape.NDim < 0) | |||||
| continue; | |||||
| if (shape.NDim < 2) | |||||
| throw new ValueError($"Expected input tensor {input_.name} to have rank at least 2"); | |||||
| var batch_size = shape.Dimensions[1]; | |||||
| if (batch_size > -1) | |||||
| throw new ValueError("_best_effort_input_batch_size batch_size > -1"); | |||||
| //return batch_size; | |||||
| } | |||||
| return array_ops.shape(flat_input[0]).slice(1); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -17,6 +17,8 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Python; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -48,6 +50,7 @@ namespace Tensorflow | |||||
| /// difference between TF and Keras RNN cell. | /// difference between TF and Keras RNN cell. | ||||
| /// </summary> | /// </summary> | ||||
| protected bool _is_tf_rnn_cell = false; | protected bool _is_tf_rnn_cell = false; | ||||
| protected virtual int state_size { get; } | |||||
| public RNNCell(bool trainable = true, | public RNNCell(bool trainable = true, | ||||
| string name = null, | string name = null, | ||||
| @@ -59,5 +62,41 @@ namespace Tensorflow | |||||
| { | { | ||||
| _is_tf_rnn_cell = true; | _is_tf_rnn_cell = true; | ||||
| } | } | ||||
| public virtual Tensor get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| { | |||||
| if (inputs != null) | |||||
| throw new NotImplementedException("get_initial_state input is not null"); | |||||
| return zero_state(batch_size, dtype); | |||||
| } | |||||
| /// <summary> | |||||
| /// Return zero-filled state tensor(s). | |||||
| /// </summary> | |||||
| /// <param name="batch_size"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor zero_state(Tensor batch_size, TF_DataType dtype) | |||||
| { | |||||
| Tensor output = null; | |||||
| var state_size = this.state_size; | |||||
| with(ops.name_scope($"{this.GetType().Name}ZeroState", values: new { batch_size }), delegate | |||||
| { | |||||
| output = _zero_state_tensors(state_size, batch_size, dtype); | |||||
| }); | |||||
| return output; | |||||
| } | |||||
| private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype) | |||||
| { | |||||
| nest.map_structure(x => | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| }, state_size); | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -512,6 +512,14 @@ namespace Tensorflow.Util | |||||
| return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList(); | return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList(); | ||||
| } | } | ||||
| public static Tensor map_structure<T>(Func<T, Tensor> func, T structure) | |||||
| { | |||||
| var flat_structure = flatten(structure); | |||||
| var mapped_flat_structure = flat_structure.Select(func).ToList(); | |||||
| return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Same as map_structure, but with only one structure (no combining of multiple structures) | /// Same as map_structure, but with only one structure (no combining of multiple structures) | ||||
| /// </summary> | /// </summary> | ||||
| @@ -2,7 +2,7 @@ | |||||
| <PropertyGroup> | <PropertyGroup> | ||||
| <OutputType>Exe</OutputType> | <OutputType>Exe</OutputType> | ||||
| <TargetFramework>netcoreapp3.0</TargetFramework> | |||||
| <TargetFramework>netcoreapp2.2</TargetFramework> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -30,7 +30,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
| /// </summary> | /// </summary> | ||||
| public class DigitRecognitionRNN : IExample | public class DigitRecognitionRNN : IExample | ||||
| { | { | ||||
| public bool Enabled { get; set; } = false; | |||||
| public bool Enabled { get; set; } = true; | |||||
| public bool IsImportingGraph { get; set; } = false; | public bool IsImportingGraph { get; set; } = false; | ||||
| public string Name => "MNIST RNN"; | public string Name => "MNIST RNN"; | ||||
| @@ -95,7 +95,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
| var init = tf.global_variables_initializer(); | var init = tf.global_variables_initializer(); | ||||
| sess.run(init); | sess.run(init); | ||||
| float loss_val = 100.0f; | |||||
| float loss_val = 100.0f; | |||||
| float accuracy_val = 0f; | float accuracy_val = 0f; | ||||
| foreach (var epoch in range(epochs)) | foreach (var epoch in range(epochs)) | ||||
| @@ -29,8 +29,12 @@ namespace TensorFlowNET.Examples | |||||
| { | { | ||||
| static void Main(string[] args) | static void Main(string[] args) | ||||
| { | { | ||||
| int finished = 0; | |||||
| var errors = new List<string>(); | var errors = new List<string>(); | ||||
| var success = new List<string>(); | var success = new List<string>(); | ||||
| var parsedArgs = ParseArgs(args); | |||||
| var examples = Assembly.GetEntryAssembly().GetTypes() | var examples = Assembly.GetEntryAssembly().GetTypes() | ||||
| .Where(x => x.GetInterfaces().Contains(typeof(IExample))) | .Where(x => x.GetInterfaces().Contains(typeof(IExample))) | ||||
| .Select(x => (IExample)Activator.CreateInstance(x)) | .Select(x => (IExample)Activator.CreateInstance(x)) | ||||
| @@ -38,14 +42,23 @@ namespace TensorFlowNET.Examples | |||||
| .OrderBy(x => x.Name) | .OrderBy(x => x.Name) | ||||
| .ToArray(); | .ToArray(); | ||||
| if (parsedArgs.ContainsKey("ex")) | |||||
| examples = examples.Where(x => x.Name == parsedArgs["ex"]).ToArray(); | |||||
| Console.WriteLine(Environment.OSVersion.ToString(), Color.Yellow); | Console.WriteLine(Environment.OSVersion.ToString(), Color.Yellow); | ||||
| Console.WriteLine($"TensorFlow Binary v{tf.VERSION}", Color.Yellow); | Console.WriteLine($"TensorFlow Binary v{tf.VERSION}", Color.Yellow); | ||||
| Console.WriteLine($"TensorFlow.NET v{Assembly.GetAssembly(typeof(TF_DataType)).GetName().Version}", Color.Yellow); | Console.WriteLine($"TensorFlow.NET v{Assembly.GetAssembly(typeof(TF_DataType)).GetName().Version}", Color.Yellow); | ||||
| for (var i = 0; i < examples.Length; i++) | for (var i = 0; i < examples.Length; i++) | ||||
| Console.WriteLine($"[{i}]: {examples[i].Name}"); | Console.WriteLine($"[{i}]: {examples[i].Name}"); | ||||
| Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow); | |||||
| var key = Console.ReadLine(); | |||||
| var key = "0"; | |||||
| if (examples.Length > 1) | |||||
| { | |||||
| Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow); | |||||
| key = Console.ReadLine(); | |||||
| } | |||||
| var sw = new Stopwatch(); | var sw = new Stopwatch(); | ||||
| for (var i = 0; i < examples.Length; i++) | for (var i = 0; i < examples.Length; i++) | ||||
| @@ -72,14 +85,35 @@ namespace TensorFlowNET.Examples | |||||
| Console.WriteLine(ex); | Console.WriteLine(ex); | ||||
| } | } | ||||
| finished++; | |||||
| Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); | Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); | ||||
| } | } | ||||
| success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green)); | success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green)); | ||||
| errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red)); | errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red)); | ||||
| Console.WriteLine($"{examples.Length} examples are completed."); | |||||
| Console.WriteLine($"{finished} of {examples.Length} example(s) are completed."); | |||||
| Console.ReadLine(); | Console.ReadLine(); | ||||
| } | } | ||||
| private static Dictionary<string, string> ParseArgs(string[] args) | |||||
| { | |||||
| var parsed = new Dictionary<string, string>(); | |||||
| for (int i = 0; i < args.Length; i++) | |||||
| { | |||||
| string key = args[i].Substring(1); | |||||
| switch (key) | |||||
| { | |||||
| case "ex": | |||||
| parsed.Add(key, args[++i]); | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| return parsed; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -6,12 +6,6 @@ | |||||
| <GeneratePackageOnBuild>false</GeneratePackageOnBuild> | <GeneratePackageOnBuild>false</GeneratePackageOnBuild> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | |||||
| <Compile Remove="python\**" /> | |||||
| <EmbeddedResource Remove="python\**" /> | |||||
| <None Remove="python\**" /> | |||||
| </ItemGroup> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Colorful.Console" Version="1.2.9" /> | <PackageReference Include="Colorful.Console" Version="1.2.9" /> | ||||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | ||||