From 2d52f6065049b886c56ad4b8375ffa29d00e3b20 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 27 Mar 2019 01:23:36 -0500 Subject: [PATCH] add K-means, in progress. add validate_shape for variable. --- src/TensorFlowNET.Core/Clustering/KMeans.cs | 86 +++++++++++++++++++ .../Operations/array_ops.py.cs | 12 +++ .../TensorFlowNET.Core.csproj | 9 +- src/TensorFlowNET.Core/Tensors/dtypes.cs | 3 + src/TensorFlowNET.Core/Tensors/tensor_util.cs | 1 + .../Variables/variable_scope.py.cs | 1 + src/TensorFlowNET.Core/tf.cs | 2 + .../KMeansClustering.cs | 52 +++++++++++ 8 files changed, 162 insertions(+), 4 deletions(-) create mode 100644 src/TensorFlowNET.Core/Clustering/KMeans.cs create mode 100644 test/TensorFlowNET.Examples/KMeansClustering.cs diff --git a/src/TensorFlowNET.Core/Clustering/KMeans.cs b/src/TensorFlowNET.Core/Clustering/KMeans.cs new file mode 100644 index 00000000..ac945cb2 --- /dev/null +++ b/src/TensorFlowNET.Core/Clustering/KMeans.cs @@ -0,0 +1,86 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Clustering +{ + /// + /// Creates the graph for k-means clustering. + /// + public class KMeans : Python + { + public const string CLUSTERS_VAR_NAME = "clusters"; + + public const string SQUARED_EUCLIDEAN_DISTANCE = "squared_euclidean"; + public const string COSINE_DISTANCE = "cosine"; + public const string RANDOM_INIT = "random"; + public const string KMEANS_PLUS_PLUS_INIT = "kmeans_plus_plus"; + public const string KMC2_INIT = "kmc2"; + + Tensor[] _inputs; + int _num_clusters; + IInitializer _initial_clusters; + string _distance_metric; + bool _use_mini_batch; + int _mini_batch_steps_per_iteration; + int _random_seed; + int _kmeans_plus_plus_num_retries; + int _kmc2_chain_length; + + public KMeans(Tensor inputs, + int num_clusters, + IInitializer initial_clusters = null, + string distance_metric = SQUARED_EUCLIDEAN_DISTANCE, + bool use_mini_batch = false, + int mini_batch_steps_per_iteration = 1, + int random_seed = 0, + int kmeans_plus_plus_num_retries = 2, + int kmc2_chain_length = 200) + { + _inputs = new Tensor[] { inputs }; + _num_clusters = num_clusters; + _initial_clusters = initial_clusters; + _distance_metric = distance_metric; + _use_mini_batch = use_mini_batch; + _mini_batch_steps_per_iteration = mini_batch_steps_per_iteration; + _random_seed = random_seed; + _kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries; + _kmc2_chain_length = kmc2_chain_length; + } + + public object training_graph() + { + var initial_clusters = _initial_clusters; + var num_clusters = ops.convert_to_tensor(_num_clusters); + var inputs = _inputs; + _create_variables(num_clusters); + + throw new NotImplementedException("KMeans training_graph"); + } + + private RefVariable[] _create_variables(Tensor num_clusters) + { + var init_value = constant_op.constant(new float[0], dtype: TF_DataType.TF_FLOAT); + var cluster_centers = tf.Variable(init_value, name: CLUSTERS_VAR_NAME, validate_shape: false); + var cluster_centers_initialized = tf.Variable(false, dtype: TF_DataType.TF_BOOL, name: "initialized"); + RefVariable update_in_steps = null; + if (_use_mini_batch && _mini_batch_steps_per_iteration > 1) + throw new NotImplementedException("KMeans._create_variables"); + else + { + var cluster_centers_updated = cluster_centers; + var cluster_counts = _use_mini_batch ? + tf.Variable(array_ops.ones(new Tensor[] { num_clusters }, dtype: TF_DataType.TF_INT64)) : + null; + return new RefVariable[] + { + cluster_centers, + cluster_centers_initialized, + cluster_counts, + cluster_centers_updated, + update_in_steps + }; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index e6b3671c..0f2172e9 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -161,6 +161,18 @@ namespace Tensorflow }); } + public static Tensor ones(Tensor[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + return with(ops.name_scope(name, "ones", new { shape }), scope => + { + name = scope; + var shape1 = ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32); + var output = gen_array_ops.fill(shape1, constant_op.constant(1, dtype: dtype), name: name); + return output; + }); + } + public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) { dtype = dtype.as_base_dtype(); diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 7e53ca9e..d691fc3a 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -4,7 +4,7 @@ netstandard2.0 TensorFlow.NET Tensorflow - 0.5.1 + 0.6.0 Haiping Chen SciSharp STACK true @@ -16,13 +16,14 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.5.1.0 + 0.6.0.0 Changes since v0.5: +Added K-means Clustering. Added Nearest Neighbor. -Add a lot of APIs to build neural networks model. +Added a lot of APIs to build neural networks model. Bug fix. 7.2 - 0.5.1.0 + 0.6.0.0 diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index ebce4e21..d461cc32 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -31,6 +31,9 @@ namespace Tensorflow switch (type.Name) { + case "Boolean": + dtype = TF_DataType.TF_BOOL; + break; case "Int32": dtype = TF_DataType.TF_INT32; break; diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 44b55259..73bc6006 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -229,6 +229,7 @@ namespace Tensorflow switch (nparray.dtype.Name) { case "Bool": + case "Boolean": tensor_proto.BoolVal.AddRange(proto_values.Data()); break; case "Int32": diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index 1a15a6ac..beb5e703 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -156,6 +156,7 @@ namespace Tensorflow { return new RefVariable(initial_value, trainable: trainable.Value, + validate_shape: validate_shape, name: name, dtype: dtype); } diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 349ae39e..65551e7b 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -22,11 +22,13 @@ namespace Tensorflow public static RefVariable Variable(T data, bool trainable = true, + bool validate_shape = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid) { return Tensorflow.variable_scope.default_variable_creator(data, trainable: trainable, + validate_shape: validate_shape, name: name, dtype: TF_DataType.DtInvalid); } diff --git a/test/TensorFlowNET.Examples/KMeansClustering.cs b/test/TensorFlowNET.Examples/KMeansClustering.cs new file mode 100644 index 00000000..1305fef5 --- /dev/null +++ b/test/TensorFlowNET.Examples/KMeansClustering.cs @@ -0,0 +1,52 @@ +using NumSharp.Core; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using Tensorflow.Clustering; +using TensorFlowNET.Examples.Utility; + +namespace TensorFlowNET.Examples +{ + /// + /// Implement K-Means algorithm with TensorFlow.NET, and apply it to classify + /// handwritten digit images. + /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/kmeans.py + /// + public class KMeansClustering : Python, IExample + { + public int Priority => 7; + public bool Enabled => true; + public string Name => "K-means Clustering"; + + Datasets mnist; + NDArray full_data_x; + int num_steps = 50; // Total steps to train + int batch_size = 1024; // The number of samples per batch + int k = 25; // The number of clusters + int num_classes = 10; // The 10 digits + int num_features = 784; // Each image is 28x28 pixels + + public bool Run() + { + // Input images + var X = tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features)); + // Labels (for assigning a label to a centroid and testing) + var Y = tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes)); + + // K-Means Parameters + var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true); + + // Build KMeans graph + var training_graph = kmeans.training_graph(); + + return false; + } + + public void PrepareData() + { + mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); + full_data_x = mnist.train.images; + } + } +}