diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 0039e0d2..aa645931 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -21,6 +21,9 @@ namespace Tensorflow
public static Tensor log(Tensor x, string name = null)
=> gen_math_ops.log(x, name);
+ public static Tensor equal(Tensor x, Tensor y, string name = null)
+ => gen_math_ops.equal(x, y, name);
+
public static Tensor multiply(Tensor x, Tensor y)
=> gen_math_ops.mul(x, y);
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index a48d60c4..52388845 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -128,6 +128,20 @@ namespace Tensorflow
return _op.outputs[0];
}
+ ///
+ /// Returns the truth value of (x == y) element-wise.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor equal(Tensor x, Tensor y, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("Equal", name, args: new { x, y });
+
+ return _op.outputs[0];
+ }
+
public static Tensor mul(Tensor x, Tensor y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });
diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
index c0de60ee..bd86e8d8 100644
--- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
+++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
@@ -43,6 +43,9 @@ namespace Tensorflow
case NDArray value:
result = value;
break;
+ case float fVal:
+ result = fVal;
+ break;
default:
break;
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index cb8d24db..dffa8ff6 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -168,7 +168,12 @@ namespace Tensorflow
/// A dictionary that maps `Tensor` objects to feed values.
/// The `Session` to be used to evaluate this tensor.
///
- public NDArray eval(FeedItem[] feed_dict = null, Session session = null)
+ public NDArray eval(params FeedItem[] feed_dict)
+ {
+ return ops._eval_using_default_session(this, feed_dict, graph);
+ }
+
+ public NDArray eval(Session session, FeedItem[] feed_dict = null)
{
return ops._eval_using_default_session(this, feed_dict, graph, session);
}
diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs
index 38c124cc..814ea2a8 100644
--- a/test/TensorFlowNET.Examples/LogisticRegression.cs
+++ b/test/TensorFlowNET.Examples/LogisticRegression.cs
@@ -17,17 +17,14 @@ namespace TensorFlowNET.Examples
public class LogisticRegression : Python, IExample
{
private float learning_rate = 0.01f;
- private int training_epochs = 25;
+ private int training_epochs = 5;
private int batch_size = 100;
private int display_step = 1;
public void Run()
{
- PrepareData();
- }
+ var mnist = PrepareData();
- private void PrepareData()
- {
// tf Graph Input
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes
@@ -50,12 +47,12 @@ namespace TensorFlowNET.Examples
with(tf.Session(), sess =>
{
- var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
+
// Run the initializer
sess.run(init);
// Training cycle
- foreach(var epoch in range(training_epochs))
+ foreach (var epoch in range(training_epochs))
{
var avg_cost = 0.0f;
var total_batch = mnist.train.num_examples / batch_size;
@@ -81,7 +78,18 @@ namespace TensorFlowNET.Examples
print("Optimization Finished!");
// Test model
+ var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
+ // Calculate accuracy
+ var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
+ float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels));
+ print($"Accuracy: {acc.ToString("F4")}");
});
}
+
+ private Datasets PrepareData()
+ {
+ var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
+ return mnist;
+ }
}
}