You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

NaiveBayesClassifier.cs 3.0 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow;
  5. using NumSharp.Core;
  6. using System.Linq;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// Gaussian Naive Bayes classifier built on TensorFlow.NET, ported from
      /// https://github.com/nicolov/naive_bayes_tensorflow
      /// </summary>
      public class NaiveBayesClassifier : Python, IExample
      {
          /// <summary>
          /// Per-class, per-feature normal distribution estimated by <see cref="fit"/>;
          /// null until <see cref="fit"/> has been called.
          /// </summary>
          private Normal _dist;

          /// <summary>
          /// Fits the classifier on a small hard-coded 2-D data set (two points per class).
          /// </summary>
          public void Run()
          {
              var X = np.array<float>(new float[][]
              {
                  new float[] { 1.0f, 1.0f },
                  new float[] { 2.0f, 2.0f },
                  new float[] { -1.0f, -1.0f },
                  new float[] { -2.0f, -2.0f },
                  new float[] { 1.0f, -1.0f },
                  new float[] { 2.0f, -2.0f },
              });
              var y = np.array<int>(0, 0, 1, 1, 2, 2);
              fit(X, y);
              // TODO: create a regular grid and classify each point.
          }

          /// <summary>
          /// Estimates the mean and variance of every feature for every class and keeps
          /// the resulting normal distribution for later prediction.
          /// </summary>
          /// <param name="X">Training features, one row per sample (nb_samples x nb_features).</param>
          /// <param name="y">Integer class label of each row of <paramref name="X"/>.</param>
          /// <exception cref="ArgumentException">
          /// <paramref name="y"/> is empty, its length differs from the number of rows of
          /// <paramref name="X"/>, or the classes do not all have the same number of samples.
          /// </exception>
          public void fit(NDArray X, NDArray y)
          {
              int nbSamples = y.size;
              if (nbSamples == 0)
              {
                  throw new ArgumentException("At least one training sample is required.", nameof(y));
              }
              if (X.shape[0] != nbSamples)
              {
                  throw new ArgumentException("X and y must contain the same number of samples.", nameof(y));
              }
              int nbFeatures = X.shape[1];

              // Separate training points by class. SortedDictionary keeps the labels in
              // ascending order (like np.unique), so class index c below is the c-th
              // smallest label rather than the raw label value, which need not be 0..k-1.
              // Labels are read as int because y holds int data.
              var labels = y.Data<int>();
              var samplesByClass = new SortedDictionary<int, List<float[]>>();
              for (int i = 0; i < nbSamples; i++)
              {
                  if (!samplesByClass.TryGetValue(labels[i], out var rowsOfClass))
                  {
                      rowsOfClass = new List<float[]>();
                      samplesByClass.Add(labels[i], rowsOfClass);
                  }
                  var features = new float[nbFeatures];
                  for (int k = 0; k < nbFeatures; k++)
                  {
                      features[k] = (float)X[i, k];
                  }
                  rowsOfClass.Add(features);
              }

              // tf.nn.moments needs a dense nb_classes x nb_samples x nb_features tensor,
              // so every class must contribute the same number of samples. Zero-padding
              // the shorter classes would bias the estimated mean and variance.
              int samplesPerClass = samplesByClass.Values.First().Count;
              if (samplesByClass.Values.Any(rows => rows.Count != samplesPerClass))
              {
                  throw new ArgumentException("Every class must have the same number of training samples.", nameof(y));
              }

              var points = new float[samplesByClass.Count, samplesPerClass, nbFeatures];
              int c = 0;
              foreach (var classSamples in samplesByClass.Values)
              {
                  for (int i = 0; i < samplesPerClass; i++)
                  {
                      for (int k = 0; k < nbFeatures; k++)
                      {
                          points[c, i, k] = classSamples[i][k];
                      }
                  }
                  c++;
              }

              // Estimate mean and variance for each class / feature over the sample axis (1).
              // Result shape: nb_classes x nb_features.
              NDArray points_by_class = np.array<float>(points);
              var cons = tf.constant(points_by_class);
              var moments = tf.nn.moments(cons, new int[] { 1 });
              var mean = moments.Item1;
              var variance = moments.Item2;

              // One univariate normal per class/feature with the estimated mean and
              // standard deviation (sqrt of the variance); kept for predict().
              _dist = tf.distributions.Normal(mean, tf.sqrt(variance));
          }

          /// <summary>
          /// Classifies the rows of <paramref name="X"/>. Not implemented yet.
          /// </summary>
          /// <param name="X">Samples to classify, one row per sample.</param>
          /// <exception cref="InvalidOperationException"><see cref="fit"/> has not been called.</exception>
          /// <exception cref="NotImplementedException">Always thrown once the model is fitted.</exception>
          public void predict(NDArray X)
          {
              if (_dist == null)
              {
                  throw new InvalidOperationException("fit must be called before predict.");
              }
              // TODO: port the rest of predict; in the Python original it starts with
              // nb_classes, nb_features = map(int, self.dist.scale.shape)
              throw new NotImplementedException();
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。