diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index d40c2cc0..b11f2889 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -98,6 +98,12 @@ namespace Tensorflow yield return (t1[i], t2[i]); } + public static IEnumerable<(T1, T2)> zip(NDArray t1, NDArray t2) + { + for (int i = 0; i < t1.size; i++) + yield return (t1.Data(i), t2.Data(i)); + } + public static IEnumerable<(int, T)> enumerate(IList values) { for (int i = 0; i < values.Count; i++) diff --git a/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs b/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs index 3a559d16..11cb726a 100644 --- a/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs +++ b/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs @@ -14,14 +14,22 @@ namespace TensorFlowNET.Examples { public void Run() { - // t/f.nn.moments() + np.array(1.0f, 1.0f); + // var X = np.array(np.array(1.0f, 1.0f), np.array(2.0f, 2.0f), np.array(1.0f, -1.0f), np.array(2.0f, -2.0f), np.array(-1.0f, -1.0f), np.array(-1.0f, 1.0f),); + // var X = np.array(new float[][] { new float[] { 1.0f, 1.0f}, new float[] { 2.0f, 2.0f }, new float[] { -1.0f, -1.0f }, new float[] { -2.0f, -2.0f }, new float[] { 1.0f, -1.0f }, new float[] { 2.0f, -2.0f }, }); + var X = np.array(new float[][] { new float[] { 1.0f, 1.0f }, new float[] { 2.0f, 2.0f }, new float[] { -1.0f, -1.0f }, new float[] { -2.0f, -2.0f }, new float[] { 1.0f, -1.0f }, new float[] { 2.0f, -2.0f }, }); + var y = np.array(0,0,1,1,2,2); + + fit(X, y); + // Create a regular grid and classify each point + } public void fit(NDArray X, NDArray y) { NDArray unique_y = y.unique(); - Dictionary> dic = new Dictionary>(); + Dictionary> dic = new Dictionary>(); // Init uy in dic foreach (int uy in unique_y.Data()) { @@ -30,19 +38,19 @@ namespace TensorFlowNET.Examples // Separate training points by class // Shape : nb_classes * nb_samples * nb_features int maxCount = 0; - foreach (var (x, t) in zip(X.Data(), y.Data())) + for (int i = 0; i < y.size; i++) { - int curClass = (y[t, 0] as NDArray).Data().First(); + long curClass = (long)y[i]; List l = dic[curClass]; - l.Add(x); + l.Add(X[i] as NDArray); if (l.Count > maxCount) { maxCount = l.Count; } - dic.Add(curClass, l); + dic[curClass] = l; } - NDArray points_by_class = np.zeros(dic.Count,maxCount,X.shape[1]); - foreach (KeyValuePair> kv in dic) + NDArray points_by_class = np.zeros(new int[] { dic.Count, maxCount, X.shape[1] }); + foreach (KeyValuePair> kv in dic) { var cls = kv.Value.ToArray(); for (int i = 0; i < dic.Count; i++)