You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

KMeansClustering.cs 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  6. using Tensorflow.Clustering;
  7. using TensorFlowNET.Examples.Utility;
  8. namespace TensorFlowNET.Examples
  9. {
  10. /// <summary>
  11. /// Implement K-Means algorithm with TensorFlow.NET, and apply it to classify
  12. /// handwritten digit images.
  13. /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/kmeans.py
  14. /// </summary>
  15. public class KMeansClustering : Python, IExample
  16. {
  17. public int Priority => 8;
  18. public bool Enabled { get; set; } = true;
  19. public string Name => "K-means Clustering";
  20. public int? train_size = null;
  21. public int validation_size = 5000;
  22. public int? test_size = null;
  23. public int batch_size = 1024; // The number of samples per batch
  24. Datasets mnist;
  25. NDArray full_data_x;
  26. int num_steps = 50; // Total steps to train
  27. int k = 25; // The number of clusters
  28. int num_classes = 10; // The 10 digits
  29. int num_features = 784; // Each image is 28x28 pixels
  30. public bool Run()
  31. {
  32. PrepareData();
  33. var graph = tf.Graph().as_default();
  34. tf.train.import_meta_graph("kmeans.meta");
  35. // Input images
  36. var X = graph.get_operation_by_name("Placeholder").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features));
  37. // Labels (for assigning a label to a centroid and testing)
  38. var Y = graph.get_operation_by_name("Placeholder_1").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes));
  39. // K-Means Parameters
  40. //var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true);
  41. // Build KMeans graph
  42. //var training_graph = kmeans.training_graph();
  43. var init_vars = tf.global_variables_initializer();
  44. Tensor init_op = graph.get_operation_by_name("cond/Merge");
  45. var train_op = graph.get_operation_by_name("group_deps");
  46. Tensor avg_distance = graph.get_operation_by_name("Mean");
  47. Tensor cluster_idx = graph.get_operation_by_name("Squeeze_1");
  48. with(tf.Session(graph), sess =>
  49. {
  50. sess.run(init_vars, new FeedItem(X, full_data_x));
  51. sess.run(init_op, new FeedItem(X, full_data_x));
  52. // Training
  53. foreach(var i in range(1, num_steps + 1))
  54. {
  55. var result = sess.run(new Tensor[] { avg_distance, cluster_idx }, new FeedItem(X, full_data_x));
  56. }
  57. });
  58. return false;
  59. }
  60. public void PrepareData()
  61. {
  62. mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size);
  63. full_data_x = mnist.train.images;
  64. }
  65. }
  66. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。