| @@ -0,0 +1,86 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Clustering | |||
| { | |||
/// <summary>
/// Builds the graph pieces for k-means clustering.
/// Only variable creation is wired up so far: <see cref="training_graph"/> creates
/// the variables and then throws <see cref="NotImplementedException"/>.
/// </summary>
public class KMeans : Python
{
    // Name given to the cluster-centers variable.
    public const string CLUSTERS_VAR_NAME = "clusters";

    // Distance metric identifiers accepted by the constructor.
    public const string SQUARED_EUCLIDEAN_DISTANCE = "squared_euclidean";
    public const string COSINE_DISTANCE = "cosine";

    // Initialization strategy identifiers (not consumed by this class yet).
    public const string RANDOM_INIT = "random";
    public const string KMEANS_PLUS_PLUS_INIT = "kmeans_plus_plus";
    public const string KMC2_INIT = "kmc2";

    // Configuration captured by the constructor.
    private readonly Tensor[] _inputs;
    private readonly int _num_clusters;
    private readonly IInitializer _initial_clusters;
    private readonly string _distance_metric;
    private readonly bool _use_mini_batch;
    private readonly int _mini_batch_steps_per_iteration;
    private readonly int _random_seed;
    private readonly int _kmeans_plus_plus_num_retries;
    private readonly int _kmc2_chain_length;

    /// <summary>
    /// Stores the clustering configuration. No graph operations are created here.
    /// </summary>
    /// <param name="inputs">Input tensor; stored wrapped in a one-element array.</param>
    /// <param name="num_clusters">Number of clusters.</param>
    /// <param name="initial_clusters">Initializer for the clusters; may be null.</param>
    /// <param name="distance_metric">Distance metric identifier, e.g. <see cref="SQUARED_EUCLIDEAN_DISTANCE"/>.</param>
    /// <param name="use_mini_batch">Whether mini-batch mode is requested.</param>
    /// <param name="mini_batch_steps_per_iteration">Mini-batch steps per iteration; values above 1 are not supported yet in mini-batch mode.</param>
    /// <param name="random_seed">Random seed (stored only).</param>
    /// <param name="kmeans_plus_plus_num_retries">Stored only; not used by the visible code.</param>
    /// <param name="kmc2_chain_length">Stored only; not used by the visible code.</param>
    public KMeans(Tensor inputs,
        int num_clusters,
        IInitializer initial_clusters = null,
        string distance_metric = SQUARED_EUCLIDEAN_DISTANCE,
        bool use_mini_batch = false,
        int mini_batch_steps_per_iteration = 1,
        int random_seed = 0,
        int kmeans_plus_plus_num_retries = 2,
        int kmc2_chain_length = 200)
    {
        _inputs = new[] { inputs };
        _num_clusters = num_clusters;
        _initial_clusters = initial_clusters;
        _distance_metric = distance_metric;
        _use_mini_batch = use_mini_batch;
        _mini_batch_steps_per_iteration = mini_batch_steps_per_iteration;
        _random_seed = random_seed;
        _kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries;
        _kmc2_chain_length = kmc2_chain_length;
    }

    /// <summary>
    /// Starts building the training graph: converts the cluster count to a tensor
    /// and creates the clustering variables.
    /// </summary>
    /// <returns>Nothing yet; the method always throws after creating the variables.</returns>
    /// <exception cref="NotImplementedException">Always thrown; the rest of the graph is not ported.</exception>
    public object training_graph()
    {
        // The stored inputs and initial clusters are not consumed yet.
        var clusterCount = ops.convert_to_tensor(_num_clusters);
        _create_variables(clusterCount);

        throw new NotImplementedException("KMeans training_graph");
    }

    /// <summary>
    /// Creates the variables used by the clustering graph.
    /// </summary>
    /// <param name="num_clusters">Cluster count as a tensor; used as the shape of the per-cluster counts.</param>
    /// <returns>
    /// Five slots in order: cluster centers, initialized flag, cluster counts
    /// (null unless mini-batch mode is on), updated cluster centers (the same
    /// variable as the cluster centers), and the update-in-steps variable (always null here).
    /// </returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when mini-batch mode is combined with more than one step per iteration.
    /// </exception>
    private RefVariable[] _create_variables(Tensor num_clusters)
    {
        // The centers start from an empty float constant, so shape validation is disabled.
        var emptyCenters = constant_op.constant(new float[0], dtype: TF_DataType.TF_FLOAT);
        var centers = tf.Variable(emptyCenters, name: CLUSTERS_VAR_NAME, validate_shape: false);
        var initializedFlag = tf.Variable(false, dtype: TF_DataType.TF_BOOL, name: "initialized");

        if (_use_mini_batch && _mini_batch_steps_per_iteration > 1)
        {
            throw new NotImplementedException("KMeans._create_variables");
        }

        // Per-cluster counts exist only in mini-batch mode.
        RefVariable counts = null;
        if (_use_mini_batch)
        {
            counts = tf.Variable(array_ops.ones(new Tensor[] { num_clusters }, dtype: TF_DataType.TF_INT64));
        }

        return new RefVariable[] { centers, initializedFlag, counts, centers, null };
    }
}
| } | |||
| @@ -161,6 +161,18 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
/// <summary>
/// Creates a tensor filled with ones whose shape is given as an array of tensors.
/// </summary>
/// <param name="shape">Shape components; converted to a single int32 tensor.</param>
/// <param name="dtype">Element type of the result; reduced to its base dtype first.</param>
/// <param name="name">Optional operation name; "ones" is the default scope name.</param>
/// <returns>The output of the fill operation.</returns>
public static Tensor ones(Tensor[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
    dtype = dtype.as_base_dtype();

    return with(ops.name_scope(name, "ones", new { shape }), scope =>
    {
        // The resolved scope becomes the name of the fill operation.
        name = scope;

        var dims = ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32);
        var one = constant_op.constant(1, dtype: dtype);

        return gen_array_ops.fill(dims, one, name: name);
    });
}
| public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||
| { | |||
| dtype = dtype.as_base_dtype(); | |||
| @@ -4,7 +4,7 @@ | |||
| <TargetFramework>netstandard2.0</TargetFramework> | |||
| <AssemblyName>TensorFlow.NET</AssemblyName> | |||
| <RootNamespace>Tensorflow</RootNamespace> | |||
| <Version>0.5.1</Version> | |||
| <Version>0.6.0</Version> | |||
| <Authors>Haiping Chen</Authors> | |||
| <Company>SciSharp STACK</Company> | |||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
| @@ -16,13 +16,14 @@ | |||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | |||
| <Description>Google's TensorFlow binding in .NET Standard. | |||
| Docs: https://tensorflownet.readthedocs.io</Description> | |||
| <AssemblyVersion>0.5.1.0</AssemblyVersion> | |||
| <AssemblyVersion>0.6.0.0</AssemblyVersion> | |||
| <PackageReleaseNotes>Changes since v0.5: | |||
| Added K-means Clustering. | |||
| Added Nearest Neighbor. | |||
| Add a lot of APIs to build neural networks model. | |||
| Added a lot of APIs to build neural network models. | |||
| Bug fix.</PackageReleaseNotes> | |||
| <LangVersion>7.2</LangVersion> | |||
| <FileVersion>0.5.1.0</FileVersion> | |||
| <FileVersion>0.6.0.0</FileVersion> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| @@ -31,6 +31,9 @@ namespace Tensorflow | |||
| switch (type.Name) | |||
| { | |||
| case "Boolean": | |||
| dtype = TF_DataType.TF_BOOL; | |||
| break; | |||
| case "Int32": | |||
| dtype = TF_DataType.TF_INT32; | |||
| break; | |||
| @@ -229,6 +229,7 @@ namespace Tensorflow | |||
| switch (nparray.dtype.Name) | |||
| { | |||
| case "Bool": | |||
| case "Boolean": | |||
| tensor_proto.BoolVal.AddRange(proto_values.Data<bool>()); | |||
| break; | |||
| case "Int32": | |||
| @@ -156,6 +156,7 @@ namespace Tensorflow | |||
| { | |||
| return new RefVariable(initial_value, | |||
| trainable: trainable.Value, | |||
| validate_shape: validate_shape, | |||
| name: name, | |||
| dtype: dtype); | |||
| } | |||
| @@ -22,11 +22,13 @@ namespace Tensorflow | |||
/// <summary>
/// Creates a graph variable from <paramref name="data"/> by forwarding every
/// argument to <c>variable_scope.default_variable_creator</c>.
/// </summary>
/// <typeparam name="T">Type of the initial value.</typeparam>
/// <param name="data">Initial value handed to the variable creator.</param>
/// <param name="trainable">Forwarded as the creator's <c>trainable</c> argument.</param>
/// <param name="validate_shape">Forwarded as the creator's <c>validate_shape</c> argument.</param>
/// <param name="name">Optional variable name.</param>
/// <param name="dtype">
/// Requested element type. The default is <c>DtInvalid</c>; presumably the creator
/// then derives the type from <paramref name="data"/> (confirm against the creator).
/// </param>
/// <returns>The variable produced by the default variable creator.</returns>
public static RefVariable Variable<T>(T data,
    bool trainable = true,
    bool validate_shape = true,
    string name = null,
    TF_DataType dtype = TF_DataType.DtInvalid)
{
    // Forward the caller's dtype. DtInvalid used to be hard-coded here,
    // which silently discarded any explicitly requested type.
    return Tensorflow.variable_scope.default_variable_creator(data,
        trainable: trainable,
        validate_shape: validate_shape,
        name: name,
        dtype: dtype);
}
| @@ -0,0 +1,52 @@ | |||
| using NumSharp.Core; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using Tensorflow.Clustering; | |||
| using TensorFlowNET.Examples.Utility; | |||
| namespace TensorFlowNET.Examples | |||
| { | |||
/// <summary>
/// K-Means example built with TensorFlow.NET, intended to classify
/// handwritten digit images.
/// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/kmeans.py
/// </summary>
public class KMeansClustering : Python, IExample
{
    public int Priority
    {
        get { return 7; }
    }

    public bool Enabled
    {
        get { return true; }
    }

    public string Name
    {
        get { return "K-means Clustering"; }
    }

    // Data loaded by PrepareData.
    private Datasets mnist;
    private NDArray full_data_x;

    private int num_steps = 50;      // Total steps to train
    private int batch_size = 1024;   // The number of samples per batch
    private int k = 25;              // The number of clusters
    private int num_classes = 10;    // The 10 digits
    private int num_features = 784;  // Each image is 28x28 pixels

    /// <summary>
    /// Creates the input placeholders and the K-Means object, then asks it
    /// for its training graph.
    /// </summary>
    /// <returns>Always false in the current state of the example.</returns>
    public bool Run()
    {
        // Input images
        var imageShape = new TensorShape(-1, num_features);
        var X = tf.placeholder(tf.float32, shape: imageShape);

        // Labels (for assigning a label to a centroid and testing)
        var labelShape = new TensorShape(-1, num_classes);
        var Y = tf.placeholder(tf.float32, shape: labelShape);

        // K-Means parameters: cosine distance, mini-batch mode
        var model = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true);

        // Build KMeans graph
        var graph = model.training_graph();

        return false;
    }

    /// <summary>
    /// Reads the MNIST data set with one-hot labels and keeps the training images.
    /// </summary>
    public void PrepareData()
    {
        mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
        full_data_x = mnist.train.images;
    }
}
| } | |||