Browse Source

print Accuracy of LogisticRegression

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
f462c5592f
5 changed files with 41 additions and 8 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +14
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  3. +3
    -0
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  4. +6
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  5. +15
    -7
      test/TensorFlowNET.Examples/LogisticRegression.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -21,6 +21,9 @@ namespace Tensorflow
public static Tensor log(Tensor x, string name = null)
=> gen_math_ops.log(x, name);

public static Tensor equal(Tensor x, Tensor y, string name = null)
=> gen_math_ops.equal(x, y, name);

public static Tensor multiply(Tensor x, Tensor y)
=> gen_math_ops.mul(x, y);



+ 14
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -128,6 +128,20 @@ namespace Tensorflow
return _op.outputs[0];
}
/// <summary>
/// Returns the truth value of (x == y) element-wise.
/// </summary>
/// <param name="x"></param>
/// <param name="y"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor equal(Tensor x, Tensor y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Equal", name, args: new { x, y });
return _op.outputs[0];
}
public static Tensor mul(Tensor x, Tensor y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });


+ 3
- 0
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -43,6 +43,9 @@ namespace Tensorflow
case NDArray value:
result = value;
break;
case float fVal:
result = fVal;
break;
default:
break;
}


+ 6
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -168,7 +168,12 @@ namespace Tensorflow
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
/// <param name="session">The `Session` to be used to evaluate this tensor.</param>
/// <returns></returns>
public NDArray eval(FeedItem[] feed_dict = null, Session session = null)
public NDArray eval(params FeedItem[] feed_dict)
{
return ops._eval_using_default_session(this, feed_dict, graph);
}

public NDArray eval(Session session, FeedItem[] feed_dict = null)
{
return ops._eval_using_default_session(this, feed_dict, graph, session);
}


+ 15
- 7
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -17,17 +17,14 @@ namespace TensorFlowNET.Examples
public class LogisticRegression : Python, IExample
{
private float learning_rate = 0.01f;
private int training_epochs = 25;
private int training_epochs = 5;
private int batch_size = 100;
private int display_step = 1;

public void Run()
{
PrepareData();
}
var mnist = PrepareData();

private void PrepareData()
{
// tf Graph Input
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes
@@ -50,12 +47,12 @@ namespace TensorFlowNET.Examples

with(tf.Session(), sess =>
{
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
// Run the initializer
sess.run(init);

// Training cycle
foreach(var epoch in range(training_epochs))
foreach (var epoch in range(training_epochs))
{
var avg_cost = 0.0f;
var total_batch = mnist.train.num_examples / batch_size;
@@ -81,7 +78,18 @@ namespace TensorFlowNET.Examples
print("Optimization Finished!");

// Test model
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
// Calculate accuracy
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels));
print($"Accuracy: {acc.ToString("F4")}");
});
}

private Datasets PrepareData()
{
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
return mnist;
}
}
}

Loading…
Cancel
Save