diff --git a/README.md b/README.md index 421242cd..9ec3432a 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,7 @@ Example runner will download all the required files like training data and model * [Object Detection](test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs) * [Text Classification](test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs) * [CNN Text Classification](test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs) +* [MNIST CNN](test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs) * [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER) * [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs) diff --git a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs index c7a40255..c637e09f 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs @@ -73,7 +73,9 @@ namespace TensorFlowNET.Examples.ImageProcess float accuracy_test = 0f; float loss_test = 1f; - NDArray x_train; + NDArray x_train, y_train; + NDArray x_valid, y_valid; + NDArray x_test, y_test; public bool Run() { @@ -135,6 +137,62 @@ namespace TensorFlowNET.Examples.ImageProcess return graph; } + public void Train(Session sess) + { + // Number of training iterations in each epoch + var num_tr_iter = y_train.len / batch_size; + + var init = tf.global_variables_initializer(); + sess.run(init); + + float loss_val = 100.0f; + float accuracy_val = 0f; + + foreach (var epoch in range(epochs)) + { + print($"Training epoch: {epoch + 1}"); + // Randomly shuffle the training data at the beginning of each epoch + (x_train, y_train) = mnist.Randomize(x_train, y_train); + + foreach (var iteration in range(num_tr_iter)) + { + var start = iteration * batch_size; + var end = (iteration + 1) * batch_size; + var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); + + // Run optimization op (backprop) + sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); + + if (iteration % display_freq == 0) + { + // Calculate and display the batch loss and accuracy + var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); + loss_val = result[0]; + accuracy_val = result[1]; + print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); + } + } + + // Run validation after every epoch + var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid)); + loss_val = results1[0]; + accuracy_val = results1[1]; + print("---------------------------------------------------------"); + print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); + print("---------------------------------------------------------"); + } + } + + public void Test(Session sess) + { + var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test)); + loss_test = result[0]; + accuracy_test = result[1]; + print("---------------------------------------------------------"); + print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); + print("---------------------------------------------------------"); + } + /// /// Create a 2D convolution layer /// @@ -219,6 +277,14 @@ namespace TensorFlowNET.Examples.ImageProcess initializer: initial); } + /// + /// Create a fully-connected layer + /// + /// input from previous layer + /// number of hidden units in the fully-connected layer + /// layer name + /// boolean to add ReLU non-linearity (or not) + /// The output array private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true) { return with(tf.variable_scope(name), delegate @@ -235,81 +301,36 @@ namespace TensorFlowNET.Examples.ImageProcess return layer; }); } - - public Graph ImportGraph() => throw new NotImplementedException(); - - public void Predict(Session sess) => throw new NotImplementedException(); public void PrepareData() { mnist = MNIST.read_data_sets("mnist", one_hot: true); - x_train = Reformat(mnist.train.data, mnist.train.labels); + (x_train, y_train) = Reformat(mnist.train.data, mnist.train.labels); + (x_valid, y_valid) = Reformat(mnist.validation.data, mnist.validation.labels); + (x_test, y_test) = Reformat(mnist.test.data, mnist.test.labels); + print("Size of:"); print($"- Training-set:\t\t{len(mnist.train.data)}"); print($"- Validation-set:\t{len(mnist.validation.data)}"); } - private NDArray Reformat(NDArray x, NDArray y) + /// + /// Reformats the data to the format acceptable for convolutional layers + /// + /// + /// + /// + private (NDArray, NDArray) Reformat(NDArray x, NDArray y) { - var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]), 1, np.unique(np.argmax(y, 1))); - - return x; + var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]), 1, len(np.unique(np.argmax(y, 1)))); + var dataset = x.reshape(x.shape[0], img_size, img_size, num_ch).astype(np.float32); + //y[0] = np.arange(num_class) == y[0]; + //var labels = (np.arange(num_class) == y.reshape(y.shape[0], 1, y.shape[1])).astype(np.float32); + return (dataset, y); } - public void Train(Session sess) - { - // Number of training iterations in each epoch - var num_tr_iter = mnist.train.labels.len / batch_size; - - var init = tf.global_variables_initializer(); - sess.run(init); - - float loss_val = 100.0f; - float accuracy_val = 0f; - - foreach (var epoch in range(epochs)) - { - print($"Training epoch: {epoch + 1}"); - // Randomly shuffle the training data at the beginning of each epoch - var (x_train, y_train) = mnist.Randomize(mnist.train.data, mnist.train.labels); - - foreach (var iteration in range(num_tr_iter)) - { - var start = iteration * batch_size; - var end = (iteration + 1) * batch_size; - var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); - - // Run optimization op (backprop) - sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); - - if (iteration % display_freq == 0) - { - // Calculate and display the batch loss and accuracy - var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); - loss_val = result[0]; - accuracy_val = result[1]; - print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); - } - } - - // Run validation after every epoch - var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels)); - loss_val = results1[0]; - accuracy_val = results1[1]; - print("---------------------------------------------------------"); - print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); - print("---------------------------------------------------------"); - } - } + public Graph ImportGraph() => throw new NotImplementedException(); - public void Test(Session sess) - { - var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels)); - loss_test = result[0]; - accuracy_test = result[1]; - print("---------------------------------------------------------"); - print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); - print("---------------------------------------------------------"); - } + public void Predict(Session sess) => throw new NotImplementedException(); } } diff --git a/test/TensorFlowNET.Examples/Utility/Datasets.cs b/test/TensorFlowNET.Examples/Utility/Datasets.cs index 93ca1869..af57c7cf 100644 --- a/test/TensorFlowNET.Examples/Utility/Datasets.cs +++ b/test/TensorFlowNET.Examples/Utility/Datasets.cs @@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples.Utility var perm = np.random.permutation(y.shape[0]); np.random.shuffle(perm); - return (train.data[perm], train.labels[perm]); + return (x[perm], y[perm]); } ///