
Build RNN graph; BasicRNNCell not implemented yet.

tags/v0.10
Oceania2018 committed 6 years ago
parent commit c3ae8b1f03
8 changed files with 321 additions and 1 deletion:

  1. src/TensorFlowNET.Core/APIs/tf.nn.cs (+2 -0)
  2. src/TensorFlowNET.Core/Operations/BasicRNNCell.cs (+10 -0)
  3. src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs (+14 -0)
  4. test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs (+163 -0)
  5. test/TensorFlowNET.Examples/python/minst_lstm.py (+78 -0)
  6. test/TensorFlowNET.Examples/python/minst_rnn.py (+48 -0)
  7. test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs (+4 -0)
  8. test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs (+2 -1)

src/TensorFlowNET.Core/APIs/tf.nn.cs (+2 -0)

@@ -127,6 +127,8 @@ namespace Tensorflow
                 });
             }
 
+            public static rnn_cell_impl rnn_cell => new rnn_cell_impl();
+
             public static Tensor softmax(Tensor logits, int axis = -1, string name = null)
                 => gen_nn_ops.softmax(logits, name);
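
The two added lines expose an rnn_cell accessor that mirrors Python's tf.nn.rnn_cell module, so the C# call site reads like the Python reference. As used later in this commit's DigitRecognitionRNN example (the call still throws NotImplementedException at this point):

    var cell = tf.nn.rnn_cell.BasicRNNCell(num_units: 128);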



src/TensorFlowNET.Core/Operations/BasicRNNCell.cs (+10 -0)

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    public class BasicRNNCell
    {
    }
}
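
BasicRNNCell is deliberately an empty shell at this commit. For orientation, a minimal sketch of the step such a cell eventually performs, following upstream TensorFlow's BasicRNNCell semantics (output and new state are the same tensor). The Call signature, the _kernel and _bias fields, and the exact TensorFlow.NET op names are assumptions, not part of this commit:

    // Hypothetical sketch only:
    // output = new_state = tanh(matmul(concat([inputs, state], 1), kernel) + bias)
    public Tensor Call(Tensor inputs, Tensor state)
    {
        // _kernel: [n_input + num_units, num_units]; _bias: [num_units] (assumed fields)
        var concat = tf.concat(new[] { inputs, state }, axis: 1);
        return tf.tanh(tf.matmul(concat, _kernel) + _bias);
    }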

src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs (+14 -0)

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
    public class rnn_cell_impl
    {
        public BasicRNNCell BasicRNNCell(int num_units)
        {
            throw new NotImplementedException();
        }
    }
}
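
Once BasicRNNCell stores its unit count, this factory presumably collapses to a one-line forward. A hedged guess (the constructor does not exist yet):

    public BasicRNNCell BasicRNNCell(int num_units)
        => new BasicRNNCell(num_units); // hypothetical constructor taking num_units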

test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs (+163 -0)

@@ -0,0 +1,163 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples.ImageProcess
{
    /// <summary>
    /// Recurrent Neural Network classifier for hand-written digits (MNIST).
    /// Each 28x28 image is fed as a sequence of 28 time steps of 28 pixels.
    /// Ported from the Python reference in python/minst_rnn.py.
    /// </summary>
    public class DigitRecognitionRNN : IExample
    {
        public bool Enabled { get; set; } = false;
        public bool IsImportingGraph { get; set; } = false;

        public string Name => "MNIST RNN";

        string logs_path = "logs";

        // Hyper-parameters
        int n_neurons = 128;
        float learning_rate = 0.001f;
        int batch_size = 128;
        int epochs = 10;

        int n_steps = 28;
        int n_inputs = 28;
        int n_outputs = 10;

        Datasets<DataSetMnist> mnist;

        Tensor x, y;
        Tensor loss, accuracy, cls_prediction;
        Operation optimizer;

        int display_freq = 100;
        float accuracy_test = 0f;
        float loss_test = 1f;

        NDArray x_train, y_train;
        NDArray x_valid, y_valid;
        NDArray x_test, y_test;

        public bool Run()
        {
            PrepareData();
            BuildGraph();

            with(tf.Session(), sess =>
            {
                Train(sess);
                Test(sess);
            });

            return loss_test < 0.09 && accuracy_test > 0.95;
        }

        public Graph BuildGraph()
        {
            var graph = new Graph().as_default();

            var X = tf.placeholder(tf.float32, new[] { -1, n_steps, n_inputs });
            var y = tf.placeholder(tf.int32, new[] { -1 });
            var cell = tf.nn.rnn_cell.BasicRNNCell(num_units: n_neurons);

            return graph;
        }

        public void Train(Session sess)
        {
            // Number of training iterations in each epoch
            var num_tr_iter = y_train.len / batch_size;

            var init = tf.global_variables_initializer();
            sess.run(init);

            float loss_val = 100.0f;
            float accuracy_val = 0f;

            foreach (var epoch in range(epochs))
            {
                print($"Training epoch: {epoch + 1}");
                // Randomly shuffle the training data at the beginning of each epoch
                (x_train, y_train) = mnist.Randomize(x_train, y_train);

                foreach (var iteration in range(num_tr_iter))
                {
                    var start = iteration * batch_size;
                    var end = (iteration + 1) * batch_size;
                    var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);

                    // Run optimization op (backprop)
                    sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));

                    if (iteration % display_freq == 0)
                    {
                        // Calculate and display the batch loss and accuracy
                        var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
                        loss_val = result[0];
                        accuracy_val = result[1];
                        print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
                    }
                }

                // Run validation after every epoch
                var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid));
                loss_val = results1[0];
                accuracy_val = results1[1];
                print("---------------------------------------------------------");
                print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
                print("---------------------------------------------------------");
            }
        }

        public void Test(Session sess)
        {
            var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test));
            loss_test = result[0];
            accuracy_test = result[1];
            print("---------------------------------------------------------");
            print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
            print("---------------------------------------------------------");
        }

        public void PrepareData()
        {
            mnist = MNIST.read_data_sets("mnist", one_hot: true);
            (x_train, y_train) = (mnist.train.data, mnist.train.labels);
            (x_valid, y_valid) = (mnist.validation.data, mnist.validation.labels);
            (x_test, y_test) = (mnist.test.data, mnist.test.labels);

            print("Size of:");
            print($"- Training-set:\t\t{len(mnist.train.data)}");
            print($"- Validation-set:\t{len(mnist.validation.data)}");
        }

        public Graph ImportGraph() => throw new NotImplementedException();

        public void Predict(Session sess) => throw new NotImplementedException();
    }
}
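
BuildGraph stops right after constructing the cell, matching the commit message. The remainder it is heading toward can be read off the Python reference (python/minst_rnn.py, added below); a C# sketch of that remainder, assuming the missing pieces (dynamic_rnn, a dense layer, sparse softmax cross-entropy, in_top_k) eventually get Python-like C# bindings:

    // Sketch only -- each tf.* call below is an assumed future binding.
    var (output, state) = tf.nn.dynamic_rnn(cell, X, dtype: tf.float32);
    var logits = tf.layers.dense(state, n_outputs);
    var cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels: y, logits: logits);
    loss = tf.reduce_mean(cross_entropy);
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss);
    var prediction = tf.nn.in_top_k(logits, y, 1);
    accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32));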

test/TensorFlowNET.Examples/python/minst_lstm.py (+78 -0)

@@ -0,0 +1,78 @@
import tensorflow as tf
from tensorflow.contrib import rnn

#import mnist dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("/tmp/data/",one_hot=True)

#define constants
#unrolled through 28 time steps
time_steps=28
#hidden LSTM units
num_units=128
#rows of 28 pixels
n_input=28
#learning rate for adam
learning_rate=0.001
#mnist is meant to be classified in 10 classes(0-9).
n_classes=10
#size of batch
batch_size=128


#weights and biases of appropriate shape to accomplish above task
out_weights=tf.Variable(tf.random_normal([num_units,n_classes]))
out_bias=tf.Variable(tf.random_normal([n_classes]))

#defining placeholders
#input image placeholder
x=tf.placeholder("float",[None,time_steps,n_input])
#input label placeholder
y=tf.placeholder("float",[None,n_classes])

#processing the input tensor from [batch_size,n_steps,n_input] to "time_steps" number of [batch_size,n_input] tensors
input=tf.unstack(x ,time_steps,1)

#defining the network
lstm_layer=rnn.BasicLSTMCell(num_units,forget_bias=1)
outputs,_=rnn.static_rnn(lstm_layer,input,dtype="float32")

#converting last output of dimension [batch_size,num_units] to [batch_size,n_classes] by out_weight multiplication
prediction=tf.matmul(outputs[-1],out_weights)+out_bias

#loss_function
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
#optimization
opt=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

#model evaluation
correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#initialize variables
init=tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    iter=1
    while iter<800:
        batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)

        batch_x=batch_x.reshape((batch_size,time_steps,n_input))

        sess.run(opt, feed_dict={x: batch_x, y: batch_y})

        if iter %10==0:
            acc=sess.run(accuracy,feed_dict={x:batch_x,y:batch_y})
            los=sess.run(loss,feed_dict={x:batch_x,y:batch_y})
            print("For iter ",iter)
            print("Accuracy ",acc)
            print("Loss ",los)
            print("__________________")

        iter=iter+1

    #calculating test accuracy
    test_data = mnist.test.images[:128].reshape((-1, time_steps, n_input))
    test_label = mnist.test.labels[:128]
    print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))


test/TensorFlowNET.Examples/python/minst_rnn.py (+48 -0)

@@ -0,0 +1,48 @@
import tensorflow as tf

# hyperparameters
n_neurons = 128
learning_rate = 0.001
batch_size = 128
n_epochs = 10
# parameters
n_steps = 28 # 28 rows
n_inputs = 28 # 28 cols
n_outputs = 10 # 10 classes
# build a rnn model
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.int32, [None])
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
output, state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
logits = tf.layers.dense(state, n_outputs)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
loss = tf.reduce_mean(cross_entropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
prediction = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))

# input data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/")
X_test = mnist.test.images # X_test shape: [num_test, 28*28]
X_test = X_test.reshape([-1, n_steps, n_inputs])
y_test = mnist.test.labels

# initialize the variables
init = tf.global_variables_initializer()
# train the model
with tf.Session() as sess:
    sess.run(init)
    n_batches = mnist.train.num_examples // batch_size
    for epoch in range(n_epochs):
        for batch in range(n_batches):
            X_train, y_train = mnist.train.next_batch(batch_size)
            X_train = X_train.reshape([-1, n_steps, n_inputs])
            sess.run(optimizer, feed_dict={X: X_train, y: y_train})
        loss_train, acc_train = sess.run(
            [loss, accuracy], feed_dict={X: X_train, y: y_train})
        print('Epoch: {}, Train Loss: {:.3f}, Train Acc: {:.3f}'.format(
            epoch + 1, loss_train, acc_train))
    loss_test, acc_test = sess.run(
        [loss, accuracy], feed_dict={X: X_test, y: y_test})
    print('Test Loss: {:.3f}, Test Acc: {:.3f}'.format(loss_test, acc_test))

test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs (+4 -0)

@@ -12,6 +12,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
     [TestClass]
     public class CondTestCases : PythonTest
     {
+        [Ignore("need TensorFlow to expose AddControlInput API")]
         [TestMethod]
         public void testCondTrue_ConstOnly()
         {

@@ -31,6 +32,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
             });
         }
+        [Ignore("need TensorFlow to expose AddControlInput API")]
         [TestMethod]
         public void testCondFalse_ConstOnly()
         {

@@ -50,6 +52,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
             });
         }
+        [Ignore("need TensorFlow to expose AddControlInput API")]
         [TestMethod]
         public void testCondTrue()
         {

@@ -66,6 +69,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
             assertEquals(result, 34);
         }
+        [Ignore("need TensorFlow to expose AddControlInput API")]
         [TestMethod]
         public void testCondFalse()
         {


test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs (+2 -1)

@@ -65,11 +65,12 @@ namespace TensorFlowNET.UnitTest.ops_test
             });
         }
+        [Ignore("need TensorFlow to expose UpdateEdge API")]
         [TestMethod]
         public void TestCond()
         {
             var graph = tf.Graph().as_default();
-            with<Graph>(graph, g =>
+            with(graph, g =>
             {
                 var x = constant_op.constant(10);

