From 8654b41c2a0cb4360c803d6be8d2db2974224276 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 24 Mar 2019 22:58:25 -0500 Subject: [PATCH] Add Nearest Neighbor release v0.5.1 --- README.md | 1 + docs/source/NearestNeighbor.md | 3 + docs/source/index.rst | 1 + src/TensorFlowNET.Core/APIs/tf.math.cs | 12 ++++ src/TensorFlowNET.Core/Graphs/Graph.cs | 2 +- .../Operations/gen_math_ops.cs | 46 ++++++++---- .../{math_ops.py.cs => math_ops.cs} | 29 +++++++- .../Sessions/BaseSession.cs | 6 ++ .../Sessions/_FetchHandler.cs | 5 +- .../TensorFlowNET.Core.csproj | 14 ++-- .../Variables/variables.py.cs | 9 ++- src/TensorFlowNET.Core/ops.py.cs | 2 +- .../ImageRecognition.cs | 2 +- .../LogisticRegression.cs | 2 +- .../TensorFlowNET.Examples/NearestNeighbor.cs | 70 +++++++++++++++++++ test/TensorFlowNET.Examples/Program.cs | 12 ++-- .../TensorFlowNET.Examples.csproj | 2 +- .../TextClassificationWithMovieReviews.cs | 2 +- .../TensorFlowNET.UnitTest.csproj | 4 +- 19 files changed, 184 insertions(+), 40 deletions(-) create mode 100644 docs/source/NearestNeighbor.md rename src/TensorFlowNET.Core/Operations/{math_ops.py.cs => math_ops.cs} (93%) create mode 100644 test/TensorFlowNET.Examples/NearestNeighbor.cs diff --git a/README.md b/README.md index 2900da7a..da0bf7f6 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow * [Image Recognition](test/TensorFlowNET.Examples/ImageRecognition.cs) * [Linear Regression](test/TensorFlowNET.Examples/LinearRegression.cs) * [Logistic Regression](test/TensorFlowNET.Examples/LogisticRegression.cs) +* [Nearest Neighbor](test/TensorFlowNET.Examples/NearestNeighbor.cs) * [Text Classification](test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs) * [CNN Text Classification](test/TensorFlowNET.Examples/CnnTextClassification.cs) * [Naive Bayes Classification](test/TensorFlowNET.Examples/NaiveBayesClassifier.cs) diff --git a/docs/source/NearestNeighbor.md b/docs/source/NearestNeighbor.md new file mode 100644 index 00000000..fa846e0c --- /dev/null +++ b/docs/source/NearestNeighbor.md @@ -0,0 +1,3 @@ +# Chapter. Nearest Neighbor + +The nearest neighbour algorithm was one of the first algorithms used to solve the travelling salesman problem. In it, the salesman starts at a random city and repeatedly visits the nearest city until all have been visited. It quickly yields a short tour, but usually not the optimal one. \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 51ed7727..e26f8378 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -27,4 +27,5 @@ Welcome to TensorFlow.NET's documentation! EagerMode LinearRegression LogisticRegression + NearestNeighbor ImageRecognition \ No newline at end of file diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index aa645931..3eba6032 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -6,9 +6,18 @@ namespace Tensorflow { public static partial class tf { + public static Tensor abs(Tensor x, string name = null) + => math_ops.abs(x, name); + public static Tensor add(Tensor a, Tensor b) => gen_math_ops.add(a, b); + public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) + => gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name); + + public static Tensor arg_min(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) + => gen_math_ops.arg_min(input, dimension, output_type: output_type, name: name); + public static Tensor sub(Tensor a, Tensor b) => gen_math_ops.sub(a, b); @@ -27,6 +36,9 @@ namespace Tensorflow public static Tensor multiply(Tensor x, Tensor y) => gen_math_ops.mul(x, y); + public static Tensor negative(Tensor x, string name = null) + => gen_math_ops.neg(x, name); + public static Tensor divide(Tensor x, T[] y, string name = null) where T : struct => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 71faa045..916c42a7 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -355,7 +355,7 @@ namespace Tensorflow return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); } - public object get_collection(string name, string scope = "") + public object get_collection(string name, string scope = null) { return _collections.ContainsKey(name) ? _collections[name] : null; } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 52388845..5a4efd32 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -9,6 +9,30 @@ namespace Tensorflow public static class gen_math_ops { public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + + /// + /// Returns the index with the largest value across dimensions of a tensor. + /// + /// + /// + /// + /// + /// + public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) + => _op_def_lib._apply_op_helper("ArgMax", name, args: new { input, dimension, output_type }).outputs[0]; + + /// + /// Returns the index with the smallest value across dimensions of a tensor. + /// + /// + /// + /// + /// + /// + public static Tensor arg_min(Tensor input, int dimension, TF_DataType output_type= TF_DataType.TF_INT64, string name= null) + =>_op_def_lib._apply_op_helper("ArgMin", name, args: new { input, dimension, output_type }).outputs[0]; + + /// /// Computes the mean of elements across dimensions of a tensor. /// Reduces `input` along the dimensions given in `axis`. Unless /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in /// `axis`. If `keep_dims` is true, the reduced dimensions are retained with length 1. @@ -207,6 +231,13 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor _abs(Tensor x, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Abs", name, new { x }); + + return _op.outputs[0]; + } + public static Tensor _max(Tx input, Ty axis, bool keep_dims=false, string name = null) { var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }); @@ -249,20 +280,5 @@ namespace Tensorflow return _op.outputs[0]; } - - /// - /// Returns the index with the largest value across dimensions of a tensor. - /// - /// - /// - /// - /// - /// - public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) - { - var _op = _op_def_lib._apply_op_helper("ArgMax", name, new { input, dimension, output_type }); - - return _op.outputs[0]; - } } } diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs similarity index 93% rename from src/TensorFlowNET.Core/Operations/math_ops.py.cs rename to src/TensorFlowNET.Core/Operations/math_ops.cs index 47dc2a81..8077a87a 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -6,9 +6,36 @@ using Tensorflow.Framework; namespace Tensorflow { + /// + /// python\ops\math_ops.py + /// public class math_ops : Python { - public static Tensor add(Tensor x, Tensor y, string name = null) => gen_math_ops.add(x, y, name); + public static Tensor abs(Tensor x, string name = null) + { + return with(ops.name_scope(name, "Abs", new { x }), scope => + { + x = ops.convert_to_tensor(x, name: "x"); + if (x.dtype.is_complex()) + throw new NotImplementedException("math_ops.abs for dtype.is_complex"); + //return gen_math_ops.complex_abs(x, Tout: x.dtype.real_dtype, name: name); + return gen_math_ops._abs(x, name: name); + }); + } + + public static Tensor add(Tensor x, Tensor y, string name = null) + => gen_math_ops.add(x, y, name); + + public static Tensor add(Tensor x, string name = null) + { + return with(ops.name_scope(name, "Abs", new { x }), scope => + { + name = scope; + x = ops.convert_to_tensor(x, name: "x"); + + return gen_math_ops._abs(x, name: name); + }); + } public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) { diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 5c84f34a..6d1977a8 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -222,6 +222,12 @@ namespace Tensorflow ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); nd = np.array(ints).reshape(ndims); break; + case TF_DataType.TF_INT64: + var longs = new long[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(longs).reshape(ndims); + break; case TF_DataType.TF_FLOAT: var floats = new float[tensor.size]; for (ulong i = 0; i < tensor.size; i++) diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index b101f4bf..20194c37 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -65,6 +65,9 @@ namespace Tensorflow case "Int32": full_values.Add(value.Data(0)); break; + case "Int64": + full_values.Add(value.Data(0)); + break; case "Single": full_values.Add(value.Data(0)); break; @@ -78,7 +81,7 @@ namespace Tensorflow } else { - full_values.Add(value[np.arange(1)]); + full_values.Add(value[np.arange(0, value.shape[0])]); } } i += 1; diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index faf799a8..7e53ca9e 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -4,7 +4,7 @@ netstandard2.0 TensorFlow.NET Tensorflow - 0.5.0 + 0.5.1 Haiping Chen SciSharp STACK true @@ -16,11 +16,13 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.5.0.0 - Add Logistic Regression to do MNIST. -Add a lot of APIs to build neural networks model + 0.5.1.0 + Changes since v0.5: +Added Nearest Neighbor. +Add a lot of APIs to build neural networks model. +Bug fix. 7.2 - 0.5.0.0 + 0.5.1.0 @@ -44,7 +46,7 @@ Add a lot of APIs to build neural networks model - + diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs index 5cde1359..4f11a7a8 100644 --- a/src/TensorFlowNET.Core/Variables/variables.py.cs +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -47,11 +47,11 @@ namespace Tensorflow /// special tokens filters by prefix. /// /// A list of `Variable` objects. - public static List global_variables(string scope = "") + public static List global_variables(string scope = null) { var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); - return result as List; + return result == null ? new List() : result as List; } /// @@ -62,7 +62,10 @@ namespace Tensorflow /// An Op that run the initializers of all the specified variables. public static Operation variables_initializer(RefVariable[] var_list, string name = "init") { - return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray(), name); + if (var_list.Length > 0) + return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray(), name); + else + return gen_control_flow_ops.no_op(name: name); } } } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 32ff4501..92e1bbb0 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -41,7 +41,7 @@ namespace Tensorflow /// list contains the values in the order under which they were /// collected. /// - public static object get_collection(string key, string scope = "") + public static object get_collection(string key, string scope = null) { return get_default_graph().get_collection(key, scope); } diff --git a/test/TensorFlowNET.Examples/ImageRecognition.cs b/test/TensorFlowNET.Examples/ImageRecognition.cs index 101ff836..9a5805eb 100644 --- a/test/TensorFlowNET.Examples/ImageRecognition.cs +++ b/test/TensorFlowNET.Examples/ImageRecognition.cs @@ -12,7 +12,7 @@ namespace TensorFlowNET.Examples { public class ImageRecognition : Python, IExample { - public int Priority => 5; + public int Priority => 6; public bool Enabled => true; public string Name => "Image Recognition"; diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs index 7b5925ef..44895932 100644 --- a/test/TensorFlowNET.Examples/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/LogisticRegression.cs @@ -98,7 +98,7 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); + mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); } public void SaveModel(Session sess) diff --git a/test/TensorFlowNET.Examples/NearestNeighbor.cs b/test/TensorFlowNET.Examples/NearestNeighbor.cs new file mode 100644 index 00000000..f8899315 --- /dev/null +++ b/test/TensorFlowNET.Examples/NearestNeighbor.cs @@ -0,0 +1,70 @@ +using NumSharp.Core; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using TensorFlowNET.Examples.Utility; + +namespace TensorFlowNET.Examples +{ + /// + /// A nearest neighbor learning algorithm example + /// This example is using the MNIST database of handwritten digits + /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/nearest_neighbor.py + /// + public class NearestNeighbor : Python, IExample + { + public int Priority => 5; + public bool Enabled => true; + public string Name => "Nearest Neighbor"; + Datasets mnist; + NDArray Xtr, Ytr, Xte, Yte; + + public bool Run() + { + // tf Graph Input + var xtr = tf.placeholder(tf.float32, new TensorShape(-1, 784)); + var xte = tf.placeholder(tf.float32, new TensorShape(784)); + + // Nearest Neighbor calculation using L1 Distance + // Calculate L1 Distance + var distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices: 1); + // Prediction: Get min distance index (Nearest neighbor) + var pred = tf.arg_min(distance, 0); + + float accuracy = 0f; + // Initialize the variables (i.e. assign their default value) + var init = tf.global_variables_initializer(); + with(tf.Session(), sess => + { + // Run the initializer + sess.run(init); + + PrepareData(); + + foreach(int i in range(Xte.shape[0])) + { + // Get nearest neighbor + long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i])); + // Get nearest neighbor class label and compare it to its true label + print($"Test {i} Prediction: {np.argmax(Ytr[nn_index])} True Class: {np.argmax(Yte[i] as NDArray)}"); + // Calculate accuracy + if (np.argmax(Ytr[nn_index]) == np.argmax(Yte[i] as NDArray)) + accuracy += 1f/ Xte.shape[0]; + } + + print($"Accuracy: {accuracy}"); + }); + + return accuracy > 0.9; + } + + public void PrepareData() + { + mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); + // In this example, we limit mnist data + (Xtr, Ytr) = mnist.train.next_batch(5000); // 5000 for training (nn candidates) + (Xte, Yte) = mnist.test.next_batch(200); // 200 for testing + } + } +} diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs index 64ec3449..44448caf 100644 --- a/test/TensorFlowNET.Examples/Program.cs +++ b/test/TensorFlowNET.Examples/Program.cs @@ -32,11 +32,11 @@ namespace TensorFlowNET.Examples { if (example.Enabled) if (example.Run()) - success.Add($"{example.Priority} {example.Name}"); + success.Add($"Example {example.Priority}: {example.Name}"); else - errors.Add($"{example.Priority} {example.Name}"); + errors.Add($"Example {example.Priority}: {example.Name}"); else - disabled.Add($"{example.Priority} {example.Name}"); + disabled.Add($"Example {example.Priority}: {example.Name}"); } catch (Exception ex) { @@ -46,9 +46,9 @@ namespace TensorFlowNET.Examples Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); } - success.ForEach(x => Console.WriteLine($"{x} example is OK!", Color.Green)); - disabled.ForEach(x => Console.WriteLine($"{x} example is Disabled!", Color.Tan)); - errors.ForEach(x => Console.WriteLine($"{x} example is Failed!", Color.Red)); + success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green)); + disabled.ForEach(x => Console.WriteLine($"{x} is Disabled!", Color.Tan)); + errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red)); Console.ReadLine(); } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 9d3ba5a8..59fc1fd0 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -8,7 +8,7 @@ - + diff --git a/test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs b/test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs index 9068b17a..a1e0fd74 100644 --- a/test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs +++ b/test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs @@ -11,7 +11,7 @@ namespace TensorFlowNET.Examples { public class TextClassificationWithMovieReviews : Python, IExample { - public int Priority => 6; + public int Priority => 7; public bool Enabled => false; public string Name => "Movie Reviews"; diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index d37356b2..af1f38a2 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -19,8 +19,8 @@ - - + +