Browse Source

add dynamic_rnn, _transpose_batch_time for RNN.

tags/v0.10
Oceania2018 6 years ago
parent
commit
d674d51df6
10 changed files with 213 additions and 30 deletions
  1. +1
    -1
      README.md
  2. +6
    -17
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  3. +2
    -0
      src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
  4. +117
    -0
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  5. +39
    -0
      src/TensorFlowNET.Core/Operations/RNNCell.cs
  6. +8
    -0
      src/TensorFlowNET.Core/Util/nest.py.cs
  7. +1
    -1
      test/TensorFlowHub.Examples/TensorFlowHub.Examples.csproj
  8. +2
    -2
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs
  9. +37
    -3
      test/TensorFlowNET.Examples/Program.cs
  10. +0
    -6
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj

+ 1
- 1
README.md View File

@@ -129,7 +129,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow
Run specific example in shell: Run specific example in shell:


```cs ```cs
dotnet TensorFlowNET.Examples.dll "EXAMPLE NAME"
dotnet TensorFlowNET.Examples.dll -ex "MNIST CNN"
``` ```


Example runner will download all the required files like training data and model pb files. Example runner will download all the required files like training data and model pb files.


+ 6
- 17
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -79,23 +79,12 @@ namespace Tensorflow
/// <param name="swap_memory"></param> /// <param name="swap_memory"></param>
/// <param name="time_major"></param> /// <param name="time_major"></param>
/// <returns>A pair (outputs, state)</returns> /// <returns>A pair (outputs, state)</returns>
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid,
bool swap_memory = false, bool time_major = false)
{
with(variable_scope("rnn"), scope =>
{
VariableScope varscope = scope;
var flat_input = nest.flatten(inputs);

if (!time_major)
{
flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList();
//flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList();
}
});

throw new NotImplementedException("");
}
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
int? sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid,
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
=> rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype,
parallel_iterations: parallel_iterations, swap_memory: swap_memory,
time_major: time_major);


public static Tensor elu(Tensor features, string name = null) public static Tensor elu(Tensor features, string name = null)
=> gen_nn_ops.elu(features, name: name); => gen_nn_ops.elu(features, name: name);


+ 2
- 0
src/TensorFlowNET.Core/Operations/BasicRNNCell.cs View File

@@ -27,6 +27,8 @@ namespace Tensorflow
int _num_units; int _num_units;
Func<Tensor, string, Tensor> _activation; Func<Tensor, string, Tensor> _activation;


protected override int state_size => _num_units;

public BasicRNNCell(int num_units, public BasicRNNCell(int num_units,
Func<Tensor, string, Tensor> activation = null, Func<Tensor, string, Tensor> activation = null,
bool? reuse = null, bool? reuse = null,


+ 117
- 0
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -0,0 +1,117 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using System.Linq;
using static Tensorflow.Python;
using Tensorflow.Util;

namespace Tensorflow.Operations
{
internal class rnn
{
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
int? sequence_length = null, Tensor initial_state = null,
TF_DataType dtype = TF_DataType.DtInvalid,
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
{
with(tf.variable_scope("rnn"), scope =>
{
VariableScope varscope = scope;
var flat_input = nest.flatten(inputs);

if (!time_major)
{
flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList();
flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList();
}

parallel_iterations = parallel_iterations ?? 32;

if (sequence_length.HasValue)
throw new NotImplementedException("dynamic_rnn sequence_length has value");

var batch_size = _best_effort_input_batch_size(flat_input);

if (initial_state != null)
{
var state = initial_state;
}
else
{
cell.get_initial_state(batch_size: batch_size, dtype: dtype);
}
});

throw new NotImplementedException("");
}

/// <summary>
/// Transposes the batch and time dimensions of a Tensor.
/// </summary>
/// <param name="x"></param>
/// <returns></returns>
public static Tensor _transpose_batch_time(Tensor x)
{
var x_static_shape = x.TensorShape;
if (x_static_shape.NDim == 1)
return x;

var x_rank = array_ops.rank(x);
var con1 = new object[]
{
new []{1, 0 },
math_ops.range(2, x_rank)
};
var x_t = array_ops.transpose(x, array_ops.concat(con1, 0));

var dims = new int[] { x_static_shape.Dimensions[1], x_static_shape.Dimensions[0] }
.ToList();
dims.AddRange(x_static_shape.Dimensions.Skip(2));
var shape = new TensorShape(dims.ToArray());

x_t.SetShape(shape);

return x_t;
}

/// <summary>
/// Get static input batch size if available, with fallback to the dynamic one.
/// </summary>
/// <param name="flat_input"></param>
/// <returns></returns>
private static Tensor _best_effort_input_batch_size(List<Tensor> flat_input)
{
foreach(var input_ in flat_input)
{
var shape = input_.TensorShape;
if (shape.NDim < 0)
continue;
if (shape.NDim < 2)
throw new ValueError($"Expected input tensor {input_.name} to have rank at least 2");

var batch_size = shape.Dimensions[1];
if (batch_size > -1)
throw new ValueError("_best_effort_input_batch_size batch_size > -1");
//return batch_size;
}

return array_ops.shape(flat_input[0]).slice(1);
}
}
}

+ 39
- 0
src/TensorFlowNET.Core/Operations/RNNCell.cs View File

@@ -17,6 +17,8 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Util;
using static Tensorflow.Python;


namespace Tensorflow namespace Tensorflow
{ {
@@ -48,6 +50,7 @@ namespace Tensorflow
/// difference between TF and Keras RNN cell. /// difference between TF and Keras RNN cell.
/// </summary> /// </summary>
protected bool _is_tf_rnn_cell = false; protected bool _is_tf_rnn_cell = false;
protected virtual int state_size { get; }


public RNNCell(bool trainable = true, public RNNCell(bool trainable = true,
string name = null, string name = null,
@@ -59,5 +62,41 @@ namespace Tensorflow
{ {
_is_tf_rnn_cell = true; _is_tf_rnn_cell = true;
} }

public virtual Tensor get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
if (inputs != null)
throw new NotImplementedException("get_initial_state input is not null");

return zero_state(batch_size, dtype);
}

/// <summary>
/// Return zero-filled state tensor(s).
/// </summary>
/// <param name="batch_size"></param>
/// <param name="dtype"></param>
/// <returns></returns>
public Tensor zero_state(Tensor batch_size, TF_DataType dtype)
{
Tensor output = null;
var state_size = this.state_size;
with(ops.name_scope($"{this.GetType().Name}ZeroState", values: new { batch_size }), delegate
{
output = _zero_state_tensors(state_size, batch_size, dtype);
});

return output;
}

private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype)
{
nest.map_structure(x =>
{
throw new NotImplementedException("");
}, state_size);

throw new NotImplementedException("");
}
} }
} }

+ 8
- 0
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -512,6 +512,14 @@ namespace Tensorflow.Util
return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList(); return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList();
} }
public static Tensor map_structure<T>(Func<T, Tensor> func, T structure)
{
var flat_structure = flatten(structure);
var mapped_flat_structure = flat_structure.Select(func).ToList();
return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
}
/// <summary> /// <summary>
/// Same as map_structure, but with only one structure (no combining of multiple structures) /// Same as map_structure, but with only one structure (no combining of multiple structures)
/// </summary> /// </summary>


+ 1
- 1
test/TensorFlowHub.Examples/TensorFlowHub.Examples.csproj View File

@@ -2,7 +2,7 @@


<PropertyGroup> <PropertyGroup>
<OutputType>Exe</OutputType> <OutputType>Exe</OutputType>
<TargetFramework>netcoreapp3.0</TargetFramework>
<TargetFramework>netcoreapp2.2</TargetFramework>
</PropertyGroup> </PropertyGroup>


<ItemGroup> <ItemGroup>


+ 2
- 2
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs View File

@@ -30,7 +30,7 @@ namespace TensorFlowNET.Examples.ImageProcess
/// </summary> /// </summary>
public class DigitRecognitionRNN : IExample public class DigitRecognitionRNN : IExample
{ {
public bool Enabled { get; set; } = false;
public bool Enabled { get; set; } = true;
public bool IsImportingGraph { get; set; } = false; public bool IsImportingGraph { get; set; } = false;


public string Name => "MNIST RNN"; public string Name => "MNIST RNN";
@@ -95,7 +95,7 @@ namespace TensorFlowNET.Examples.ImageProcess
var init = tf.global_variables_initializer(); var init = tf.global_variables_initializer();
sess.run(init); sess.run(init);


float loss_val = 100.0f;
float loss_val = 100.0f;
float accuracy_val = 0f; float accuracy_val = 0f;


foreach (var epoch in range(epochs)) foreach (var epoch in range(epochs))


+ 37
- 3
test/TensorFlowNET.Examples/Program.cs View File

@@ -29,8 +29,12 @@ namespace TensorFlowNET.Examples
{ {
static void Main(string[] args) static void Main(string[] args)
{ {
int finished = 0;
var errors = new List<string>(); var errors = new List<string>();
var success = new List<string>(); var success = new List<string>();

var parsedArgs = ParseArgs(args);

var examples = Assembly.GetEntryAssembly().GetTypes() var examples = Assembly.GetEntryAssembly().GetTypes()
.Where(x => x.GetInterfaces().Contains(typeof(IExample))) .Where(x => x.GetInterfaces().Contains(typeof(IExample)))
.Select(x => (IExample)Activator.CreateInstance(x)) .Select(x => (IExample)Activator.CreateInstance(x))
@@ -38,14 +42,23 @@ namespace TensorFlowNET.Examples
.OrderBy(x => x.Name) .OrderBy(x => x.Name)
.ToArray(); .ToArray();


if (parsedArgs.ContainsKey("ex"))
examples = examples.Where(x => x.Name == parsedArgs["ex"]).ToArray();

Console.WriteLine(Environment.OSVersion.ToString(), Color.Yellow); Console.WriteLine(Environment.OSVersion.ToString(), Color.Yellow);
Console.WriteLine($"TensorFlow Binary v{tf.VERSION}", Color.Yellow); Console.WriteLine($"TensorFlow Binary v{tf.VERSION}", Color.Yellow);
Console.WriteLine($"TensorFlow.NET v{Assembly.GetAssembly(typeof(TF_DataType)).GetName().Version}", Color.Yellow); Console.WriteLine($"TensorFlow.NET v{Assembly.GetAssembly(typeof(TF_DataType)).GetName().Version}", Color.Yellow);


for (var i = 0; i < examples.Length; i++) for (var i = 0; i < examples.Length; i++)
Console.WriteLine($"[{i}]: {examples[i].Name}"); Console.WriteLine($"[{i}]: {examples[i].Name}");
Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow);
var key = Console.ReadLine();

var key = "0";
if (examples.Length > 1)
{
Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow);
key = Console.ReadLine();
}


var sw = new Stopwatch(); var sw = new Stopwatch();
for (var i = 0; i < examples.Length; i++) for (var i = 0; i < examples.Length; i++)
@@ -72,14 +85,35 @@ namespace TensorFlowNET.Examples
Console.WriteLine(ex); Console.WriteLine(ex);
} }


finished++;
Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White);
} }


success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green)); success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green));
errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red)); errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red));


Console.WriteLine($"{examples.Length} examples are completed.");
Console.WriteLine($"{finished} of {examples.Length} example(s) are completed.");
Console.ReadLine(); Console.ReadLine();
} }

private static Dictionary<string, string> ParseArgs(string[] args)
{
var parsed = new Dictionary<string, string>();

for (int i = 0; i < args.Length; i++)
{
string key = args[i].Substring(1);
switch (key)
{
case "ex":
parsed.Add(key, args[++i]);
break;
default:
break;
}
}

return parsed;
}
} }
} }

+ 0
- 6
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -6,12 +6,6 @@
<GeneratePackageOnBuild>false</GeneratePackageOnBuild> <GeneratePackageOnBuild>false</GeneratePackageOnBuild>
</PropertyGroup> </PropertyGroup>


<ItemGroup>
<Compile Remove="python\**" />
<EmbeddedResource Remove="python\**" />
<None Remove="python\**" />
</ItemGroup>

<ItemGroup> <ItemGroup>
<PackageReference Include="Colorful.Console" Version="1.2.9" /> <PackageReference Include="Colorful.Console" Version="1.2.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> <PackageReference Include="Newtonsoft.Json" Version="12.0.2" />


Loading…
Cancel
Save