diff --git a/README.md b/README.md
index 421242cd..9ec3432a 100644
--- a/README.md
+++ b/README.md
@@ -149,6 +149,7 @@ Example runner will download all the required files like training data and model
* [Object Detection](test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs)
* [Text Classification](test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs)
* [CNN Text Classification](test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs)
+* [MNIST CNN](test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs)
* [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER)
* [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs)
diff --git a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
index c7a40255..c637e09f 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
@@ -73,7 +73,9 @@ namespace TensorFlowNET.Examples.ImageProcess
float accuracy_test = 0f;
float loss_test = 1f;
- NDArray x_train;
+ NDArray x_train, y_train;
+ NDArray x_valid, y_valid;
+ NDArray x_test, y_test;
public bool Run()
{
@@ -135,6 +137,62 @@ namespace TensorFlowNET.Examples.ImageProcess
return graph;
}
+ public void Train(Session sess)
+ {
+ // Number of training iterations in each epoch
+ var num_tr_iter = y_train.len / batch_size;
+
+ var init = tf.global_variables_initializer();
+ sess.run(init);
+
+ float loss_val = 100.0f;
+ float accuracy_val = 0f;
+
+ foreach (var epoch in range(epochs))
+ {
+ print($"Training epoch: {epoch + 1}");
+ // Randomly shuffle the training data at the beginning of each epoch
+ (x_train, y_train) = mnist.Randomize(x_train, y_train);
+
+ foreach (var iteration in range(num_tr_iter))
+ {
+ var start = iteration * batch_size;
+ var end = (iteration + 1) * batch_size;
+ var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);
+
+ // Run optimization op (backprop)
+ sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
+
+ if (iteration % display_freq == 0)
+ {
+ // Calculate and display the batch loss and accuracy
+ var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
+ loss_val = result[0];
+ accuracy_val = result[1];
+ print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
+ }
+ }
+
+ // Run validation after every epoch
+ var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid));
+ loss_val = results1[0];
+ accuracy_val = results1[1];
+ print("---------------------------------------------------------");
+ print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
+ print("---------------------------------------------------------");
+ }
+ }
+
+ public void Test(Session sess)
+ {
+ var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test));
+ loss_test = result[0];
+ accuracy_test = result[1];
+ print("---------------------------------------------------------");
+ print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
+ print("---------------------------------------------------------");
+ }
+
///
/// Create a 2D convolution layer
///
@@ -219,6 +277,14 @@ namespace TensorFlowNET.Examples.ImageProcess
initializer: initial);
}
+ ///
+ /// Create a fully-connected layer
+ ///
+ /// input from previous layer
+ /// number of hidden units in the fully-connected layer
+ /// layer name
+ /// boolean to add ReLU non-linearity (or not)
+ /// The output array
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
{
return with(tf.variable_scope(name), delegate
@@ -235,81 +301,36 @@ namespace TensorFlowNET.Examples.ImageProcess
return layer;
});
}
-
- public Graph ImportGraph() => throw new NotImplementedException();
-
- public void Predict(Session sess) => throw new NotImplementedException();
public void PrepareData()
{
mnist = MNIST.read_data_sets("mnist", one_hot: true);
- x_train = Reformat(mnist.train.data, mnist.train.labels);
+ (x_train, y_train) = Reformat(mnist.train.data, mnist.train.labels);
+ (x_valid, y_valid) = Reformat(mnist.validation.data, mnist.validation.labels);
+ (x_test, y_test) = Reformat(mnist.test.data, mnist.test.labels);
+
print("Size of:");
print($"- Training-set:\t\t{len(mnist.train.data)}");
print($"- Validation-set:\t{len(mnist.validation.data)}");
}
- private NDArray Reformat(NDArray x, NDArray y)
+ ///
+ /// Reformats the data to the format acceptable for convolutional layers
+ ///
+ ///
+ ///
+ ///
+ private (NDArray, NDArray) Reformat(NDArray x, NDArray y)
{
- var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]), 1, np.unique(np.argmax(y, 1)));
-
- return x;
+ var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]), 1, len(np.unique(np.argmax(y, 1))));
+ var dataset = x.reshape(x.shape[0], img_size, img_size, num_ch).astype(np.float32);
+ //y[0] = np.arange(num_class) == y[0];
+ //var labels = (np.arange(num_class) == y.reshape(y.shape[0], 1, y.shape[1])).astype(np.float32);
+ return (dataset, y);
}
- public void Train(Session sess)
- {
- // Number of training iterations in each epoch
- var num_tr_iter = mnist.train.labels.len / batch_size;
-
- var init = tf.global_variables_initializer();
- sess.run(init);
-
- float loss_val = 100.0f;
- float accuracy_val = 0f;
-
- foreach (var epoch in range(epochs))
- {
- print($"Training epoch: {epoch + 1}");
- // Randomly shuffle the training data at the beginning of each epoch
- var (x_train, y_train) = mnist.Randomize(mnist.train.data, mnist.train.labels);
-
- foreach (var iteration in range(num_tr_iter))
- {
- var start = iteration * batch_size;
- var end = (iteration + 1) * batch_size;
- var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);
-
- // Run optimization op (backprop)
- sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
-
- if (iteration % display_freq == 0)
- {
- // Calculate and display the batch loss and accuracy
- var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
- loss_val = result[0];
- accuracy_val = result[1];
- print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
- }
- }
-
- // Run validation after every epoch
- var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels));
- loss_val = results1[0];
- accuracy_val = results1[1];
- print("---------------------------------------------------------");
- print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
- print("---------------------------------------------------------");
- }
- }
+ public Graph ImportGraph() => throw new NotImplementedException();
- public void Test(Session sess)
- {
- var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
- loss_test = result[0];
- accuracy_test = result[1];
- print("---------------------------------------------------------");
- print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
- print("---------------------------------------------------------");
- }
+ public void Predict(Session sess) => throw new NotImplementedException();
}
}
diff --git a/test/TensorFlowNET.Examples/Utility/Datasets.cs b/test/TensorFlowNET.Examples/Utility/Datasets.cs
index 93ca1869..af57c7cf 100644
--- a/test/TensorFlowNET.Examples/Utility/Datasets.cs
+++ b/test/TensorFlowNET.Examples/Utility/Datasets.cs
@@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples.Utility
var perm = np.random.permutation(y.shape[0]);
np.random.shuffle(perm);
- return (train.data[perm], train.labels[perm]);
+ return (x[perm], y[perm]);
}
///