You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

LogisticRegression.cs 3.7 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. using Newtonsoft.Json;
  2. using NumSharp.Core;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow;
  8. using TensorFlowNET.Examples.Utility;
  9. namespace TensorFlowNET.Examples
  10. {
  11. /// <summary>
  12. /// A logistic regression learning algorithm example using TensorFlow library.
  13. /// This example is using the MNIST database of handwritten digits
  14. /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py
  15. /// </summary>
  16. public class LogisticRegression : Python, IExample
  17. {
  18. private float learning_rate = 0.01f;
  19. private int training_epochs = 5;
  20. private int batch_size = 100;
  21. private int display_step = 1;
  22. public void Run()
  23. {
  24. var mnist = PrepareData();
  25. // tf Graph Input
  26. var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
  27. var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes
  28. // Set model weights
  29. var W = tf.Variable(tf.zeros(new Shape(784, 10)));
  30. var b = tf.Variable(tf.zeros(new Shape(10)));
  31. // Construct model
  32. var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax
  33. // Minimize error using cross entropy
  34. var cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices: 1));
  35. // Gradient Descent
  36. var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
  37. // Initialize the variables (i.e. assign their default value)
  38. var init = tf.global_variables_initializer();
  39. with(tf.Session(), sess =>
  40. {
  41. // Run the initializer
  42. sess.run(init);
  43. // Training cycle
  44. foreach (var epoch in range(training_epochs))
  45. {
  46. var avg_cost = 0.0f;
  47. var total_batch = mnist.train.num_examples / batch_size;
  48. // Loop over all batches
  49. foreach (var i in range(total_batch))
  50. {
  51. var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
  52. // Run optimization op (backprop) and cost op (to get loss value)
  53. var result = sess.run(new object[] { optimizer, cost },
  54. new FeedItem(x, batch_xs),
  55. new FeedItem(y, batch_ys));
  56. var c = (float)result[1];
  57. // Compute average loss
  58. avg_cost += c / total_batch;
  59. }
  60. // Display logs per epoch step
  61. if ((epoch + 1) % display_step == 0)
  62. print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}");
  63. }
  64. print("Optimization Finished!");
  65. // Test model
  66. var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
  67. // Calculate accuracy
  68. var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
  69. float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels));
  70. print($"Accuracy: {acc.ToString("F4")}");
  71. });
  72. }
  73. private Datasets PrepareData()
  74. {
  75. var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
  76. return mnist;
  77. }
  78. }
  79. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。