|
|
|
@@ -17,17 +17,14 @@ namespace TensorFlowNET.Examples |
|
|
|
public class LogisticRegression : Python, IExample |
|
|
|
{ |
|
|
|
private float learning_rate = 0.01f; |
|
|
|
private int training_epochs = 25; |
|
|
|
private int training_epochs = 5; |
|
|
|
private int batch_size = 100; |
|
|
|
private int display_step = 1; |
|
|
|
|
|
|
|
public void Run() |
|
|
|
{ |
|
|
|
PrepareData(); |
|
|
|
} |
|
|
|
var mnist = PrepareData(); |
|
|
|
|
|
|
|
private void PrepareData() |
|
|
|
{ |
|
|
|
// tf Graph Input |
|
|
|
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784 |
|
|
|
var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes |
|
|
|
@@ -50,12 +47,12 @@ namespace TensorFlowNET.Examples |
|
|
|
|
|
|
|
with(tf.Session(), sess => |
|
|
|
{ |
|
|
|
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); |
|
|
|
|
|
|
|
// Run the initializer |
|
|
|
sess.run(init); |
|
|
|
|
|
|
|
// Training cycle |
|
|
|
foreach(var epoch in range(training_epochs)) |
|
|
|
foreach (var epoch in range(training_epochs)) |
|
|
|
{ |
|
|
|
var avg_cost = 0.0f; |
|
|
|
var total_batch = mnist.train.num_examples / batch_size; |
|
|
|
@@ -81,7 +78,18 @@ namespace TensorFlowNET.Examples |
|
|
|
print("Optimization Finished!"); |
|
|
|
|
|
|
|
// Test model |
|
|
|
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); |
|
|
|
// Calculate accuracy |
|
|
|
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); |
|
|
|
float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels)); |
|
|
|
print($"Accuracy: {acc.ToString("F4")}"); |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
private Datasets PrepareData() |
|
|
|
{ |
|
|
|
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); |
|
|
|
return mnist; |
|
|
|
} |
|
|
|
} |
|
|
|
} |