| @@ -0,0 +1,86 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Clustering | |||||
| { | |||||
| /// <summary> | |||||
| /// Creates the graph for k-means clustering. | |||||
| /// </summary> | |||||
| public class KMeans : Python | |||||
| { | |||||
| public const string CLUSTERS_VAR_NAME = "clusters"; | |||||
| public const string SQUARED_EUCLIDEAN_DISTANCE = "squared_euclidean"; | |||||
| public const string COSINE_DISTANCE = "cosine"; | |||||
| public const string RANDOM_INIT = "random"; | |||||
| public const string KMEANS_PLUS_PLUS_INIT = "kmeans_plus_plus"; | |||||
| public const string KMC2_INIT = "kmc2"; | |||||
| Tensor[] _inputs; | |||||
| int _num_clusters; | |||||
| IInitializer _initial_clusters; | |||||
| string _distance_metric; | |||||
| bool _use_mini_batch; | |||||
| int _mini_batch_steps_per_iteration; | |||||
| int _random_seed; | |||||
| int _kmeans_plus_plus_num_retries; | |||||
| int _kmc2_chain_length; | |||||
| public KMeans(Tensor inputs, | |||||
| int num_clusters, | |||||
| IInitializer initial_clusters = null, | |||||
| string distance_metric = SQUARED_EUCLIDEAN_DISTANCE, | |||||
| bool use_mini_batch = false, | |||||
| int mini_batch_steps_per_iteration = 1, | |||||
| int random_seed = 0, | |||||
| int kmeans_plus_plus_num_retries = 2, | |||||
| int kmc2_chain_length = 200) | |||||
| { | |||||
| _inputs = new Tensor[] { inputs }; | |||||
| _num_clusters = num_clusters; | |||||
| _initial_clusters = initial_clusters; | |||||
| _distance_metric = distance_metric; | |||||
| _use_mini_batch = use_mini_batch; | |||||
| _mini_batch_steps_per_iteration = mini_batch_steps_per_iteration; | |||||
| _random_seed = random_seed; | |||||
| _kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries; | |||||
| _kmc2_chain_length = kmc2_chain_length; | |||||
| } | |||||
| public object training_graph() | |||||
| { | |||||
| var initial_clusters = _initial_clusters; | |||||
| var num_clusters = ops.convert_to_tensor(_num_clusters); | |||||
| var inputs = _inputs; | |||||
| _create_variables(num_clusters); | |||||
| throw new NotImplementedException("KMeans training_graph"); | |||||
| } | |||||
| private RefVariable[] _create_variables(Tensor num_clusters) | |||||
| { | |||||
| var init_value = constant_op.constant(new float[0], dtype: TF_DataType.TF_FLOAT); | |||||
| var cluster_centers = tf.Variable(init_value, name: CLUSTERS_VAR_NAME, validate_shape: false); | |||||
| var cluster_centers_initialized = tf.Variable(false, dtype: TF_DataType.TF_BOOL, name: "initialized"); | |||||
| RefVariable update_in_steps = null; | |||||
| if (_use_mini_batch && _mini_batch_steps_per_iteration > 1) | |||||
| throw new NotImplementedException("KMeans._create_variables"); | |||||
| else | |||||
| { | |||||
| var cluster_centers_updated = cluster_centers; | |||||
| var cluster_counts = _use_mini_batch ? | |||||
| tf.Variable(array_ops.ones(new Tensor[] { num_clusters }, dtype: TF_DataType.TF_INT64)) : | |||||
| null; | |||||
| return new RefVariable[] | |||||
| { | |||||
| cluster_centers, | |||||
| cluster_centers_initialized, | |||||
| cluster_counts, | |||||
| cluster_centers_updated, | |||||
| update_in_steps | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -161,6 +161,18 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| public static Tensor ones(Tensor[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||||
| { | |||||
| dtype = dtype.as_base_dtype(); | |||||
| return with(ops.name_scope(name, "ones", new { shape }), scope => | |||||
| { | |||||
| name = scope; | |||||
| var shape1 = ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32); | |||||
| var output = gen_array_ops.fill(shape1, constant_op.constant(1, dtype: dtype), name: name); | |||||
| return output; | |||||
| }); | |||||
| } | |||||
| public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | ||||
| { | { | ||||
| dtype = dtype.as_base_dtype(); | dtype = dtype.as_base_dtype(); | ||||
| @@ -4,7 +4,7 @@ | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <Version>0.5.1</Version> | |||||
| <Version>0.6.0</Version> | |||||
| <Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| @@ -16,13 +16,14 @@ | |||||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | ||||
| <Description>Google's TensorFlow binding in .NET Standard. | <Description>Google's TensorFlow binding in .NET Standard. | ||||
| Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
| <AssemblyVersion>0.5.1.0</AssemblyVersion> | |||||
| <AssemblyVersion>0.6.0.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Changes since v0.5: | <PackageReleaseNotes>Changes since v0.5: | ||||
| Added K-means Clustering. | |||||
| Added Nearest Neighbor. | Added Nearest Neighbor. | ||||
| Add a lot of APIs to build neural networks model. | |||||
| Added a lot of APIs to build neural networks model. | |||||
| Bug fix.</PackageReleaseNotes> | Bug fix.</PackageReleaseNotes> | ||||
| <LangVersion>7.2</LangVersion> | <LangVersion>7.2</LangVersion> | ||||
| <FileVersion>0.5.1.0</FileVersion> | |||||
| <FileVersion>0.6.0.0</FileVersion> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
| @@ -31,6 +31,9 @@ namespace Tensorflow | |||||
| switch (type.Name) | switch (type.Name) | ||||
| { | { | ||||
| case "Boolean": | |||||
| dtype = TF_DataType.TF_BOOL; | |||||
| break; | |||||
| case "Int32": | case "Int32": | ||||
| dtype = TF_DataType.TF_INT32; | dtype = TF_DataType.TF_INT32; | ||||
| break; | break; | ||||
| @@ -229,6 +229,7 @@ namespace Tensorflow | |||||
| switch (nparray.dtype.Name) | switch (nparray.dtype.Name) | ||||
| { | { | ||||
| case "Bool": | case "Bool": | ||||
| case "Boolean": | |||||
| tensor_proto.BoolVal.AddRange(proto_values.Data<bool>()); | tensor_proto.BoolVal.AddRange(proto_values.Data<bool>()); | ||||
| break; | break; | ||||
| case "Int32": | case "Int32": | ||||
| @@ -156,6 +156,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| return new RefVariable(initial_value, | return new RefVariable(initial_value, | ||||
| trainable: trainable.Value, | trainable: trainable.Value, | ||||
| validate_shape: validate_shape, | |||||
| name: name, | name: name, | ||||
| dtype: dtype); | dtype: dtype); | ||||
| } | } | ||||
| @@ -22,11 +22,13 @@ namespace Tensorflow | |||||
| public static RefVariable Variable<T>(T data, | public static RefVariable Variable<T>(T data, | ||||
| bool trainable = true, | bool trainable = true, | ||||
| bool validate_shape = true, | |||||
| string name = null, | string name = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid) | TF_DataType dtype = TF_DataType.DtInvalid) | ||||
| { | { | ||||
| return Tensorflow.variable_scope.default_variable_creator(data, | return Tensorflow.variable_scope.default_variable_creator(data, | ||||
| trainable: trainable, | trainable: trainable, | ||||
| validate_shape: validate_shape, | |||||
| name: name, | name: name, | ||||
| dtype: TF_DataType.DtInvalid); | dtype: TF_DataType.DtInvalid); | ||||
| } | } | ||||
| @@ -0,0 +1,52 @@ | |||||
| using NumSharp.Core; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Clustering; | |||||
| using TensorFlowNET.Examples.Utility; | |||||
| namespace TensorFlowNET.Examples | |||||
| { | |||||
| /// <summary> | |||||
| /// Implement K-Means algorithm with TensorFlow.NET, and apply it to classify | |||||
| /// handwritten digit images. | |||||
| /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/kmeans.py | |||||
| /// </summary> | |||||
| public class KMeansClustering : Python, IExample | |||||
| { | |||||
| public int Priority => 7; | |||||
| public bool Enabled => true; | |||||
| public string Name => "K-means Clustering"; | |||||
| Datasets mnist; | |||||
| NDArray full_data_x; | |||||
| int num_steps = 50; // Total steps to train | |||||
| int batch_size = 1024; // The number of samples per batch | |||||
| int k = 25; // The number of clusters | |||||
| int num_classes = 10; // The 10 digits | |||||
| int num_features = 784; // Each image is 28x28 pixels | |||||
| public bool Run() | |||||
| { | |||||
| // Input images | |||||
| var X = tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features)); | |||||
| // Labels (for assigning a label to a centroid and testing) | |||||
| var Y = tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes)); | |||||
| // K-Means Parameters | |||||
| var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true); | |||||
| // Build KMeans graph | |||||
| var training_graph = kmeans.training_graph(); | |||||
| return false; | |||||
| } | |||||
| public void PrepareData() | |||||
| { | |||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); | |||||
| full_data_x = mnist.train.images; | |||||
| } | |||||
| } | |||||
| } | |||||