From f462c5592ff62d61383d0333c93f19e10c142835 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 22 Mar 2019 06:53:00 -0500 Subject: [PATCH] print Accuracy of LogisticRegression --- src/TensorFlowNET.Core/APIs/tf.math.cs | 3 +++ .../Operations/gen_math_ops.cs | 14 ++++++++++++ .../Sessions/_ElementFetchMapper.cs | 3 +++ src/TensorFlowNET.Core/Tensors/Tensor.cs | 7 +++++- .../LogisticRegression.cs | 22 +++++++++++++------ 5 files changed, 41 insertions(+), 8 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 0039e0d2..aa645931 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -21,6 +21,9 @@ namespace Tensorflow public static Tensor log(Tensor x, string name = null) => gen_math_ops.log(x, name); + public static Tensor equal(Tensor x, Tensor y, string name = null) + => gen_math_ops.equal(x, y, name); + public static Tensor multiply(Tensor x, Tensor y) => gen_math_ops.mul(x, y); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index a48d60c4..52388845 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -128,6 +128,20 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Returns the truth value of (x == y) element-wise. + /// + /// + /// + /// + /// + public static Tensor equal(Tensor x, Tensor y, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Equal", name, args: new { x, y }); + + return _op.outputs[0]; + } + public static Tensor mul(Tensor x, Tensor y, string name = null) { var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index c0de60ee..bd86e8d8 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -43,6 +43,9 @@ namespace Tensorflow case NDArray value: result = value; break; + case float fVal: + result = fVal; + break; default: break; } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index cb8d24db..dffa8ff6 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -168,7 +168,12 @@ namespace Tensorflow /// A dictionary that maps `Tensor` objects to feed values. /// The `Session` to be used to evaluate this tensor. /// - public NDArray eval(FeedItem[] feed_dict = null, Session session = null) + public NDArray eval(params FeedItem[] feed_dict) + { + return ops._eval_using_default_session(this, feed_dict, graph); + } + + public NDArray eval(Session session, FeedItem[] feed_dict = null) { return ops._eval_using_default_session(this, feed_dict, graph, session); } diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs index 38c124cc..814ea2a8 100644 --- a/test/TensorFlowNET.Examples/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/LogisticRegression.cs @@ -17,17 +17,14 @@ namespace TensorFlowNET.Examples public class LogisticRegression : Python, IExample { private float learning_rate = 0.01f; - private int training_epochs = 25; + private int training_epochs = 5; private int batch_size = 100; private int display_step = 1; public void Run() { - PrepareData(); - } + var mnist = PrepareData(); - private void PrepareData() - { // tf Graph Input var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784 var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes @@ -50,12 +47,12 @@ namespace TensorFlowNET.Examples with(tf.Session(), sess => { - var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); + // Run the initializer sess.run(init); // Training cycle - foreach(var epoch in range(training_epochs)) + foreach (var epoch in range(training_epochs)) { var avg_cost = 0.0f; var total_batch = mnist.train.num_examples / batch_size; @@ -81,7 +78,18 @@ namespace TensorFlowNET.Examples print("Optimization Finished!"); // Test model + var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); + // Calculate accuracy + var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); + float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels)); + print($"Accuracy: {acc.ToString("F4")}"); }); } + + private Datasets PrepareData() + { + var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); + return mnist; + } } }