Browse Source

Finished kmean model.

tags/v0.9
Oceania2018 6 years ago
parent
commit
e152782f6c
7 changed files with 55 additions and 18 deletions
  1. +8
    -0
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  2. +6
    -1
      src/TensorFlowNET.Core/Operations/math_ops.cs
  3. +18
    -9
      test/TensorFlowNET.Examples/KMeansClustering.cs
  4. +5
    -2
      test/TensorFlowNET.Examples/NearestNeighbor.cs
  5. +1
    -1
      test/TensorFlowNET.Examples/ObjectDetection.cs
  6. +16
    -4
      test/TensorFlowNET.Examples/Program.cs
  7. +1
    -1
      test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

+ 8
- 0
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -26,6 +26,14 @@ namespace Tensorflow
partition_strategy: partition_strategy,
name: name);

public static Tensor embedding_lookup(Tensor @params,
Tensor ids,
string partition_strategy = "mod",
string name = null) => embedding_ops._embedding_lookup_and_transform(new Tensor[] { @params },
ids,
partition_strategy: partition_strategy,
name: name);

public static IActivation relu() => new relu();

public static Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name);


+ 6
- 1
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -333,10 +333,15 @@ namespace Tensorflow
else
{
var rank = common_shapes.rank(x);
if (rank != null)

// we rely on Range and Rank to do the right thing at run-time.
if (rank == -1) return range(0, array_ops.rank(x));

if (rank.HasValue && rank.Value > -1)
{
return constant_op.constant(np.arange(rank.Value), TF_DataType.TF_INT32);
}

return range(0, rank, 1);
}
}


+ 18
- 9
test/TensorFlowNET.Examples/KMeansClustering.cs View File

@@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples

Datasets mnist;
NDArray full_data_x;
int num_steps = 10; // Total steps to train
int num_steps = 20; // Total steps to train
int k = 25; // The number of clusters
int num_classes = 10; // The 10 digits
int num_features = 784; // Each image is 28x28 pixels
@@ -42,9 +42,9 @@ namespace TensorFlowNET.Examples
tf.train.import_meta_graph("graph/kmeans.meta");

// Input images
var X = graph.get_operation_by_name("Placeholder").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features));
Tensor X = graph.get_operation_by_name("Placeholder"); // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features));
// Labels (for assigning a label to a centroid and testing)
var Y = graph.get_operation_by_name("Placeholder_1").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes));
Tensor Y = graph.get_operation_by_name("Placeholder_1"); // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes));

// K-Means Parameters
//var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true);
@@ -57,6 +57,7 @@ namespace TensorFlowNET.Examples
var train_op = graph.get_operation_by_name("group_deps");
Tensor avg_distance = graph.get_operation_by_name("Mean");
Tensor cluster_idx = graph.get_operation_by_name("Squeeze_1");
NDArray result = null;

with(tf.Session(graph), sess =>
{
@@ -64,19 +65,16 @@ namespace TensorFlowNET.Examples
sess.run(init_op, new FeedItem(X, full_data_x));

// Training
NDArray result = null;
var sw = new Stopwatch();

foreach (var i in range(1, num_steps + 1))
{
sw.Start();
sw.Restart();
result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x));
sw.Stop();

if (i % 5 == 0 || i == 1)
if (i % 4 == 0 || i == 1)
print($"Step {i}, Avg Distance: {result[1]} Elapse: {sw.ElapsedMilliseconds}ms");

sw.Reset();
}

var idx = result[2].Data<int>();
@@ -102,9 +100,20 @@ namespace TensorFlowNET.Examples

// Evaluation ops
// Lookup: centroid_id -> label
var cluster_label = tf.nn.embedding_lookup(labels_map, cluster_idx);

// Compute accuracy
var correct_prediction = tf.equal(cluster_label, tf.cast(tf.argmax(Y, 1), tf.int32));
var cast = tf.cast(correct_prediction, tf.float32);
var accuracy_op = tf.reduce_mean(cast);

// Test Model
var (test_x, test_y) = (mnist.test.images, mnist.test.labels);
result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y));
print($"Test Accuracy: {result}");
});

return false;
return (float)result > 0.70;
}

public void PrepareData()


+ 5
- 2
test/TensorFlowNET.Examples/NearestNeighbor.cs View File

@@ -51,7 +51,10 @@ namespace TensorFlowNET.Examples
long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i]));
// Get nearest neighbor class label and compare it to its true label
int index = (int)nn_index;
print($"Test {i} Prediction: {np.argmax(Ytr[index])} True Class: {np.argmax(Yte[i])}");

if (i % 10 == 0 || i == 0)
print($"Test {i} Prediction: {np.argmax(Ytr[index])} True Class: {np.argmax(Yte[i])}");

// Calculate accuracy
if ((int)np.argmax(Ytr[index]) == (int)np.argmax(Yte[i]))
accuracy += 1f/ Xte.shape[0];
@@ -60,7 +63,7 @@ namespace TensorFlowNET.Examples
print($"Accuracy: {accuracy}");
});

return accuracy > 0.9;
return accuracy > 0.8;
}

public void PrepareData()


+ 1
- 1
test/TensorFlowNET.Examples/ObjectDetection.cs View File

@@ -16,7 +16,7 @@ namespace TensorFlowNET.Examples
public class ObjectDetection : Python, IExample
{
public int Priority => 11;
public bool Enabled { get; set; } = true;
public bool Enabled { get; set; } = false;
public string Name => "Object Detection";
public float MIN_SCORE = 0.5f;



+ 16
- 4
test/TensorFlowNET.Examples/Program.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.Linq;
using System.Reflection;
@@ -21,6 +22,7 @@ namespace TensorFlowNET.Examples
.OrderBy(x => x.Priority)
.ToArray();

var sw = new Stopwatch();
foreach (IExample example in examples)
{
if (args.Length > 0 && !args.Contains(example.Name))
@@ -28,21 +30,31 @@ namespace TensorFlowNET.Examples

Console.WriteLine($"{DateTime.UtcNow} Starting {example.Name}", Color.White);

try
{
if (example.Enabled)
if (example.Run())
success.Add($"Example {example.Priority}: {example.Name}");
{
sw.Restart();
bool isSuccess = example.Run();
sw.Stop();

if (isSuccess)
success.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s");
else
errors.Add($"Example {example.Priority}: {example.Name}");
errors.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s");
}
else
disabled.Add($"Example {example.Priority}: {example.Name}");
{
disabled.Add($"Example {example.Priority}: {example.Name} in {sw.ElapsedMilliseconds}ms");
}
}
catch (Exception ex)
{
errors.Add($"Example {example.Priority}: {example.Name}");
Console.WriteLine(ex);
}

Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White);
}


+ 1
- 1
test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs View File

@@ -45,7 +45,7 @@ namespace TensorFlowNET.ExamplesTests
public void KMeansClustering()
{
tf.Graph().as_default();
new KMeansClustering() { Enabled = false, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
new KMeansClustering() { Enabled = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
}
[TestMethod]


Loading…
Cancel
Save