diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 1448d9ae..7dc34df9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -26,6 +26,14 @@ namespace Tensorflow partition_strategy: partition_strategy, name: name); + public static Tensor embedding_lookup(Tensor @params, + Tensor ids, + string partition_strategy = "mod", + string name = null) => embedding_ops._embedding_lookup_and_transform(new Tensor[] { @params }, + ids, + partition_strategy: partition_strategy, + name: name); + public static IActivation relu() => new relu(); public static Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index a1538b9c..f4c28fba 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -333,10 +333,15 @@ namespace Tensorflow else { var rank = common_shapes.rank(x); - if (rank != null) + + // we rely on Range and Rank to do the right thing at run-time. + if (rank == -1) return range(0, array_ops.rank(x)); + + if (rank.HasValue && rank.Value > -1) { return constant_op.constant(np.arange(rank.Value), TF_DataType.TF_INT32); } + return range(0, rank, 1); } } diff --git a/test/TensorFlowNET.Examples/KMeansClustering.cs b/test/TensorFlowNET.Examples/KMeansClustering.cs index 1e04ac6e..b020a856 100644 --- a/test/TensorFlowNET.Examples/KMeansClustering.cs +++ b/test/TensorFlowNET.Examples/KMeansClustering.cs @@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples Datasets mnist; NDArray full_data_x; - int num_steps = 10; // Total steps to train + int num_steps = 20; // Total steps to train int k = 25; // The number of clusters int num_classes = 10; // The 10 digits int num_features = 784; // Each image is 28x28 pixels @@ -42,9 +42,9 @@ namespace TensorFlowNET.Examples tf.train.import_meta_graph("graph/kmeans.meta"); // Input images - var X = graph.get_operation_by_name("Placeholder").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features)); + Tensor X = graph.get_operation_by_name("Placeholder"); // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features)); // Labels (for assigning a label to a centroid and testing) - var Y = graph.get_operation_by_name("Placeholder_1").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes)); + Tensor Y = graph.get_operation_by_name("Placeholder_1"); // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes)); // K-Means Parameters //var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true); @@ -57,6 +57,7 @@ namespace TensorFlowNET.Examples var train_op = graph.get_operation_by_name("group_deps"); Tensor avg_distance = graph.get_operation_by_name("Mean"); Tensor cluster_idx = graph.get_operation_by_name("Squeeze_1"); + NDArray result = null; with(tf.Session(graph), sess => { @@ -64,19 +65,16 @@ namespace TensorFlowNET.Examples sess.run(init_op, new FeedItem(X, full_data_x)); // Training - NDArray result = null; var sw = new Stopwatch(); foreach (var i in range(1, num_steps + 1)) { - sw.Start(); + sw.Restart(); result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x)); sw.Stop(); - if (i % 5 == 0 || i == 1) + if (i % 4 == 0 || i == 1) print($"Step {i}, Avg Distance: {result[1]} Elapse: {sw.ElapsedMilliseconds}ms"); - - sw.Reset(); } var idx = result[2].Data(); @@ -102,9 +100,20 @@ namespace TensorFlowNET.Examples // Evaluation ops // Lookup: centroid_id -> label + var cluster_label = tf.nn.embedding_lookup(labels_map, cluster_idx); + + // Compute accuracy + var correct_prediction = tf.equal(cluster_label, tf.cast(tf.argmax(Y, 1), tf.int32)); + var cast = tf.cast(correct_prediction, tf.float32); + var accuracy_op = tf.reduce_mean(cast); + + // Test Model + var (test_x, test_y) = (mnist.test.images, mnist.test.labels); + result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y)); + print($"Test Accuracy: {result}"); }); - return false; + return (float)result > 0.70; } public void PrepareData() diff --git a/test/TensorFlowNET.Examples/NearestNeighbor.cs b/test/TensorFlowNET.Examples/NearestNeighbor.cs index f811b5da..2735e514 100644 --- a/test/TensorFlowNET.Examples/NearestNeighbor.cs +++ b/test/TensorFlowNET.Examples/NearestNeighbor.cs @@ -51,7 +51,10 @@ namespace TensorFlowNET.Examples long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i])); // Get nearest neighbor class label and compare it to its true label int index = (int)nn_index; - print($"Test {i} Prediction: {np.argmax(Ytr[index])} True Class: {np.argmax(Yte[i])}"); + + if (i % 10 == 0 || i == 0) + print($"Test {i} Prediction: {np.argmax(Ytr[index])} True Class: {np.argmax(Yte[i])}"); + // Calculate accuracy if ((int)np.argmax(Ytr[index]) == (int)np.argmax(Yte[i])) accuracy += 1f/ Xte.shape[0]; @@ -60,7 +63,7 @@ namespace TensorFlowNET.Examples print($"Accuracy: {accuracy}"); }); - return accuracy > 0.9; + return accuracy > 0.8; } public void PrepareData() diff --git a/test/TensorFlowNET.Examples/ObjectDetection.cs b/test/TensorFlowNET.Examples/ObjectDetection.cs index ae0e87c6..b6b6826c 100644 --- a/test/TensorFlowNET.Examples/ObjectDetection.cs +++ b/test/TensorFlowNET.Examples/ObjectDetection.cs @@ -16,7 +16,7 @@ namespace TensorFlowNET.Examples public class ObjectDetection : Python, IExample { public int Priority => 11; - public bool Enabled { get; set; } = true; + public bool Enabled { get; set; } = false; public string Name => "Object Detection"; public float MIN_SCORE = 0.5f; diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs index 69fd1592..fd1e64c8 100644 --- a/test/TensorFlowNET.Examples/Program.cs +++ b/test/TensorFlowNET.Examples/Program.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Drawing; using System.Linq; using System.Reflection; @@ -21,6 +22,7 @@ namespace TensorFlowNET.Examples .OrderBy(x => x.Priority) .ToArray(); + var sw = new Stopwatch(); foreach (IExample example in examples) { if (args.Length > 0 && !args.Contains(example.Name)) @@ -28,21 +30,31 @@ namespace TensorFlowNET.Examples Console.WriteLine($"{DateTime.UtcNow} Starting {example.Name}", Color.White); + try { if (example.Enabled) - if (example.Run()) - success.Add($"Example {example.Priority}: {example.Name}"); + { + sw.Restart(); + bool isSuccess = example.Run(); + sw.Stop(); + + if (isSuccess) + success.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s"); else - errors.Add($"Example {example.Priority}: {example.Name}"); + errors.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s"); + } else - disabled.Add($"Example {example.Priority}: {example.Name}"); + { + disabled.Add($"Example {example.Priority}: {example.Name} in {sw.ElapsedMilliseconds}ms"); + } } catch (Exception ex) { errors.Add($"Example {example.Priority}: {example.Name}"); Console.WriteLine(ex); } + Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); } diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs index df71b54e..c0b4eb1f 100644 --- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs +++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs @@ -45,7 +45,7 @@ namespace TensorFlowNET.ExamplesTests public void KMeansClustering() { tf.Graph().as_default(); - new KMeansClustering() { Enabled = false, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run(); + new KMeansClustering() { Enabled = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run(); } [TestMethod]