| @@ -0,0 +1,27 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using TensorFlowNET.Examples.Utility; | |||
| namespace TensorFlowNET.Examples | |||
| { | |||
| /// <summary> | |||
| /// A logistic regression learning algorithm example using TensorFlow library. | |||
| /// This example is using the MNIST database of handwritten digits | |||
| /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py | |||
| /// </summary> | |||
| public class LogisticRegression : Python, IExample | |||
| { | |||
| public void Run() | |||
| { | |||
| PrepareData(); | |||
| } | |||
| private void PrepareData() | |||
| { | |||
| MnistDataSet.read_data_sets("logistic_regression", one_hot: true); | |||
| } | |||
| } | |||
| } | |||
| @@ -6,6 +6,7 @@ | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="DevExpress.Xpo" Version="18.2.6" /> | |||
| <PackageReference Include="NumSharp" Version="0.8.0" /> | |||
| <PackageReference Include="SharpZipLib" Version="1.1.0" /> | |||
| <PackageReference Include="TensorFlow.NET" Version="0.4.2" /> | |||
| @@ -1,4 +1,5 @@ | |||
| using ICSharpCode.SharpZipLib.GZip; | |||
| using ICSharpCode.SharpZipLib.Core; | |||
| using ICSharpCode.SharpZipLib.GZip; | |||
| using ICSharpCode.SharpZipLib.Tar; | |||
| using System; | |||
| using System.IO; | |||
| @@ -11,6 +12,26 @@ namespace TensorFlowNET.Examples.Utility | |||
| { | |||
| public class Compress | |||
| { | |||
| public static void ExtractGZip(string gzipFileName, string targetDir) | |||
| { | |||
| // Use a 4K buffer. Any larger is a waste. | |||
| byte[] dataBuffer = new byte[4096]; | |||
| using (System.IO.Stream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read)) | |||
| { | |||
| using (GZipInputStream gzipStream = new GZipInputStream(fs)) | |||
| { | |||
| // Change this to your needs | |||
| string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName)); | |||
| using (FileStream fsOut = File.Create(fnOut)) | |||
| { | |||
| StreamUtils.Copy(gzipStream, fsOut, dataBuffer); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| public static void UnZip(String gzArchiveName, String destFolder) | |||
| { | |||
| var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; | |||
| @@ -0,0 +1,20 @@ | |||
| using NumSharp.Core; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.Examples.Utility | |||
| { | |||
| public class DataSet | |||
| { | |||
| private int _num_examples; | |||
| public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) | |||
| { | |||
| _num_examples = images.shape[0]; | |||
| images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | |||
| images = np.multiply(images, 1.0f / 255.0f); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,110 @@ | |||
| using ICSharpCode.SharpZipLib.Core; | |||
| using ICSharpCode.SharpZipLib.GZip; | |||
| using NumSharp.Core; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.Examples.Utility | |||
| { | |||
| public class MnistDataSet | |||
| { | |||
| private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"; | |||
| private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz"; | |||
| private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; | |||
| private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | |||
| private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | |||
| public static void read_data_sets(string train_dir, | |||
| bool one_hot = false, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| bool reshape = true, | |||
| int validation_size = 5000, | |||
| string source_url = DEFAULT_SOURCE_URL) | |||
| { | |||
| Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); | |||
| Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); | |||
| var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0])); | |||
| Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS); | |||
| Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir); | |||
| var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot); | |||
| Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES); | |||
| Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir); | |||
| var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0])); | |||
| Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS); | |||
| Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir); | |||
| var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot); | |||
| int end = train_images.shape[0]; | |||
| var validation_images = train_images[np.arange(validation_size)]; | |||
| var validation_labels = train_labels[np.arange(validation_size)]; | |||
| train_images = train_images[np.arange(validation_size, end)]; | |||
| train_labels = train_labels[np.arange(validation_size, end)]; | |||
| var train = new DataSet(train_images, train_labels, dtype, reshape); | |||
| } | |||
| public static NDArray extract_images(string file) | |||
| { | |||
| using (var bytestream = new FileStream(file, FileMode.Open)) | |||
| { | |||
| var magic = _read32(bytestream); | |||
| if (magic != 2051) | |||
| throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}"); | |||
| var num_images = _read32(bytestream); | |||
| var rows = _read32(bytestream); | |||
| var cols = _read32(bytestream); | |||
| var buf = new byte[rows * cols * num_images]; | |||
| bytestream.Read(buf, 0, buf.Length); | |||
| var data = np.frombuffer(buf, np.uint8); | |||
| data = data.reshape((int)num_images, (int)rows, (int)cols, 1); | |||
| return data; | |||
| } | |||
| } | |||
| public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10) | |||
| { | |||
| using (var bytestream = new FileStream(file, FileMode.Open)) | |||
| { | |||
| var magic = _read32(bytestream); | |||
| if (magic != 2049) | |||
| throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}"); | |||
| var num_items = _read32(bytestream); | |||
| var buf = new byte[num_items]; | |||
| bytestream.Read(buf, 0, buf.Length); | |||
| var labels = np.frombuffer(buf, np.uint8); | |||
| if (one_hot) | |||
| return dense_to_one_hot(labels, num_classes); | |||
| return labels; | |||
| } | |||
| } | |||
| private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes) | |||
| { | |||
| var num_labels = labels_dense.shape[0]; | |||
| var index_offset = np.arange(num_labels) * num_classes; | |||
| var labels_one_hot = np.zeros(num_labels, num_classes); | |||
| for(int row = 0; row < num_labels; row++) | |||
| { | |||
| var col = labels_dense.Data<byte>(row); | |||
| labels_one_hot[row, col] = 1; | |||
| } | |||
| return labels_one_hot; | |||
| } | |||
| private static uint _read32(FileStream bytestream) | |||
| { | |||
| var buffer = new byte[sizeof(uint)]; | |||
| var count = bytestream.Read(buffer, 0, 4); | |||
| return np.frombuffer(buffer, ">u4").Data<uint>(0); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,70 @@ | |||
| ''' | |||
| A logistic regression learning algorithm example using TensorFlow library. | |||
| This example is using the MNIST database of handwritten digits | |||
| (http://yann.lecun.com/exdb/mnist/) | |||
| Author: Aymeric Damien | |||
| Project: https://github.com/aymericdamien/TensorFlow-Examples/ | |||
| ''' | |||
| from __future__ import print_function | |||
| import tensorflow as tf | |||
| # Import MNIST data | |||
| from tensorflow.examples.tutorials.mnist import input_data | |||
| mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) | |||
| # Parameters | |||
| learning_rate = 0.01 | |||
| training_epochs = 25 | |||
| batch_size = 100 | |||
| display_step = 1 | |||
| # tf Graph Input | |||
| x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784 | |||
| y = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes | |||
| # Set model weights | |||
| W = tf.Variable(tf.zeros([784, 10])) | |||
| b = tf.Variable(tf.zeros([10])) | |||
| # Construct model | |||
| pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax | |||
| # Minimize error using cross entropy | |||
| cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1)) | |||
| # Gradient Descent | |||
| optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) | |||
| # Initialize the variables (i.e. assign their default value) | |||
| init = tf.global_variables_initializer() | |||
| # Start training | |||
| with tf.Session() as sess: | |||
| # Run the initializer | |||
| sess.run(init) | |||
| # Training cycle | |||
| for epoch in range(training_epochs): | |||
| avg_cost = 0. | |||
| total_batch = int(mnist.train.num_examples/batch_size) | |||
| # Loop over all batches | |||
| for i in range(total_batch): | |||
| batch_xs, batch_ys = mnist.train.next_batch(batch_size) | |||
| # Run optimization op (backprop) and cost op (to get loss value) | |||
| _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, | |||
| y: batch_ys}) | |||
| # Compute average loss | |||
| avg_cost += c / total_batch | |||
| # Display logs per epoch step | |||
| if (epoch+1) % display_step == 0: | |||
| print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost)) | |||
| print("Optimization Finished!") | |||
| # Test model | |||
| correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) | |||
| # Calculate accuracy | |||
| accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) | |||
| print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) | |||