Browse Source

print Accuracy of LogisticRegression

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
f462c5592f
5 changed files with 41 additions and 8 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +14
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  3. +3
    -0
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  4. +6
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  5. +15
    -7
      test/TensorFlowNET.Examples/LogisticRegression.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -21,6 +21,9 @@ namespace Tensorflow
public static Tensor log(Tensor x, string name = null) public static Tensor log(Tensor x, string name = null)
=> gen_math_ops.log(x, name); => gen_math_ops.log(x, name);


public static Tensor equal(Tensor x, Tensor y, string name = null)
=> gen_math_ops.equal(x, y, name);

public static Tensor multiply(Tensor x, Tensor y) public static Tensor multiply(Tensor x, Tensor y)
=> gen_math_ops.mul(x, y); => gen_math_ops.mul(x, y);




+ 14
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -128,6 +128,20 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }
/// <summary>
/// Returns the truth value of (x == y) element-wise.
/// </summary>
/// <param name="x"></param>
/// <param name="y"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor equal(Tensor x, Tensor y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Equal", name, args: new { x, y });
return _op.outputs[0];
}
public static Tensor mul(Tensor x, Tensor y, string name = null) public static Tensor mul(Tensor x, Tensor y, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });


+ 3
- 0
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -43,6 +43,9 @@ namespace Tensorflow
case NDArray value: case NDArray value:
result = value; result = value;
break; break;
case float fVal:
result = fVal;
break;
default: default:
break; break;
} }


+ 6
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -168,7 +168,12 @@ namespace Tensorflow
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
/// <param name="session">The `Session` to be used to evaluate this tensor.</param> /// <param name="session">The `Session` to be used to evaluate this tensor.</param>
/// <returns></returns> /// <returns></returns>
public NDArray eval(FeedItem[] feed_dict = null, Session session = null)
public NDArray eval(params FeedItem[] feed_dict)
{
return ops._eval_using_default_session(this, feed_dict, graph);
}

public NDArray eval(Session session, FeedItem[] feed_dict = null)
{ {
return ops._eval_using_default_session(this, feed_dict, graph, session); return ops._eval_using_default_session(this, feed_dict, graph, session);
} }


+ 15
- 7
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -17,17 +17,14 @@ namespace TensorFlowNET.Examples
public class LogisticRegression : Python, IExample public class LogisticRegression : Python, IExample
{ {
private float learning_rate = 0.01f; private float learning_rate = 0.01f;
private int training_epochs = 25;
private int training_epochs = 5;
private int batch_size = 100; private int batch_size = 100;
private int display_step = 1; private int display_step = 1;


public void Run() public void Run()
{ {
PrepareData();
}
var mnist = PrepareData();


private void PrepareData()
{
// tf Graph Input // tf Graph Input
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784 var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes
@@ -50,12 +47,12 @@ namespace TensorFlowNET.Examples


with(tf.Session(), sess => with(tf.Session(), sess =>
{ {
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
// Run the initializer // Run the initializer
sess.run(init); sess.run(init);


// Training cycle // Training cycle
foreach(var epoch in range(training_epochs))
foreach (var epoch in range(training_epochs))
{ {
var avg_cost = 0.0f; var avg_cost = 0.0f;
var total_batch = mnist.train.num_examples / batch_size; var total_batch = mnist.train.num_examples / batch_size;
@@ -81,7 +78,18 @@ namespace TensorFlowNET.Examples
print("Optimization Finished!"); print("Optimization Finished!");


// Test model // Test model
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
// Calculate accuracy
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels));
print($"Accuracy: {acc.ToString("F4")}");
}); });
} }

private Datasets PrepareData()
{
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
return mnist;
}
} }
} }

Loading…
Cancel
Save