You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

NearestNeighbor.cs 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  6. using TensorFlowNET.Examples.Utility;
  7. namespace TensorFlowNET.Examples
  8. {
  /// <summary>
  /// Nearest-neighbor classification example on the MNIST database of handwritten digits.
  /// Every test sample receives the label of the training sample whose L1 distance
  /// to it is smallest.
  /// See: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/nearest_neighbor.py
  /// </summary>
  public class NearestNeighbor : Python, IExample
  {
      /// <summary>Priority value reported for this example (constant 5).</summary>
      public int Priority
      {
          get { return 5; }
      }

      /// <summary>Whether this example is enabled; defaults to <c>true</c>.</summary>
      public bool Enabled { get; set; } = true;

      /// <summary>Display name of this example.</summary>
      public string Name
      {
          get { return "Nearest Neighbor"; }
      }

      // Data set handle and the train/test batches; all are populated by PrepareData().
      Datasets mnist;
      NDArray Xtr, Ytr, Xte, Yte;

      /// <summary>
      /// Number of training samples requested as nearest-neighbor candidates.
      /// The same value is also passed as <c>validation_size</c> when the data set is read.
      /// </summary>
      public int DataSize = 5000;

      /// <summary>Number of test samples requested for evaluation.</summary>
      public int TestBatchSize = 200;

      /// <summary>
      /// Builds the L1 nearest-neighbor graph, loads the data through
      /// <see cref="PrepareData"/>, classifies every test sample and prints the
      /// per-sample result followed by the overall accuracy.
      /// </summary>
      /// <returns><c>true</c> when the measured accuracy is greater than 0.9.</returns>
      public bool Run()
      {
          // Graph inputs: the candidate set (any number of rows, 784 columns)
          // and a single query vector of 784 values.
          var trainInput = tf.placeholder(tf.float32, new TensorShape(-1, 784));
          var testInput = tf.placeholder(tf.float32, new TensorShape(784));

          // L1 distance: sum over axis 1 of |train + (-test)|, built step by step.
          var negatedTest = tf.negative(testInput);
          var difference = tf.add(trainInput, negatedTest);
          var absoluteDifference = tf.abs(difference);
          var l1Distance = tf.reduce_sum(absoluteDifference, reduction_indices: 1);

          // Prediction: index of the smallest distance, i.e. the nearest neighbor.
          var nearestOp = tf.arg_min(l1Distance, 0);

          // Initializer op for the graph's global variables.
          var initializer = tf.global_variables_initializer();

          float accuracy = 0f;

          with(tf.Session(), session =>
          {
              // Run the initializer before loading any data.
              session.run(initializer);

              PrepareData();

              var testCount = Xte.shape[0];
              for (var i = 0; i < testCount; i++)
              {
                  // Ask the graph for the training row closest to test sample i.
                  long nearestRaw = session.run(nearestOp, new FeedItem(trainInput, Xtr), new FeedItem(testInput, Xte[i]));
                  int nearest = (int)nearestRaw;

                  // Compare the neighbor's label with the true label of the test sample;
                  // both values are computed once and reused for logging and scoring.
                  var predictedClass = np.argmax(Ytr[(NDArray)nearest]);
                  var trueClass = np.argmax(Yte[i] as NDArray);
                  print($"Test {i} Prediction: {predictedClass} True Class: {trueClass}");

                  // Each correct prediction contributes 1/testCount to the accuracy.
                  if (predictedClass == trueClass)
                  {
                      accuracy += 1f / testCount;
                  }
              }

              print($"Accuracy: {accuracy}");
          });

          return accuracy > 0.9;
      }

      /// <summary>
      /// Reads the MNIST data set (one-hot labels) and draws one training batch of
      /// <see cref="DataSize"/> samples and one test batch of
      /// <see cref="TestBatchSize"/> samples into the instance fields.
      /// </summary>
      public void PrepareData()
      {
          mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize);

          // Only a limited slice of MNIST is used in this example.
          // Training batch: the nearest-neighbor candidates (5000 by default).
          (Xtr, Ytr) = mnist.train.next_batch(DataSize);
          // Test batch: the samples to classify (200 by default).
          (Xte, Yte) = mnist.test.next_batch(TestBatchSize);
      }
  }
  64. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。