diff --git a/README.md b/README.md index 2c593a42..3ac68eed 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # TensorFlow.NET TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. -Here is a simple test +TensorFlow.NET (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. + [![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community) [![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) [![codecov](https://codecov.io/gh/SciSharp/NumSharp/branch/master/graph/badge.svg)](https://codecov.io/gh/SciSharp/NumSharp) @@ -8,7 +9,7 @@ Here is a simple test [![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) [![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US) -TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). +TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). ![tensors_flowing](docs/assets/tensors_flowing.gif) @@ -24,14 +25,14 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr ### How to use -Install TensorFlow.NET through NuGet. +Install TF.NET through NuGet. ```sh PM> Install-Package TensorFlow.NET ``` If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflowlib) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows. -Import tensorflow.net. +Import TF.NET. ```cs using Tensorflow; diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index 12a4591b..ef066abe 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -17,10 +17,10 @@ namespace Tensorflow if (fetch.GetType().IsArray) return new _ListFetchMapper(fetches); else - return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0]); + return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0]); } - public virtual NDArray build_results(List values) + public virtual NDArray build_results(List values) { var type = values[0].GetType(); var nd = new NDArray(type, values.Count); @@ -31,16 +31,12 @@ namespace Tensorflow nd.SetData(values.Select(x => (float)x).ToArray()); break; case "NDArray": - // nd.SetData(values.ToArray()); - //NDArray[] arr = new NDArray[values.Count]; - //for (int i=0; i (NDArray)x).ToArray(); nd = new NDArray(arr); break; default: break; } - return nd; } diff --git a/test/TensorFlowNET.Examples/KMeansClustering.cs b/test/TensorFlowNET.Examples/KMeansClustering.cs index d18a1153..d9e2de47 100644 --- a/test/TensorFlowNET.Examples/KMeansClustering.cs +++ b/test/TensorFlowNET.Examples/KMeansClustering.cs @@ -33,19 +33,54 @@ namespace TensorFlowNET.Examples public bool Run() { + PrepareData(); + + var graph = tf.Graph().as_default(); + tf.train.import_meta_graph("kmeans.meta"); // Input images - var X = tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features)); + var X = graph.get_operation_by_name("Placeholder").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features)); // Labels (for assigning a label to a centroid and testing) - var Y = tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes)); + var Y = graph.get_operation_by_name("Placeholder_1").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes)); // K-Means Parameters - var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true); + //var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true); // Build KMeans graph - var training_graph = kmeans.training_graph(); - + //var training_graph = kmeans.training_graph(); + + var init_vars = tf.global_variables_initializer(); + Tensor init_op = graph.get_operation_by_name("cond/Merge"); + var train_op = graph.get_operation_by_name("group_deps"); + Tensor avg_distance = graph.get_operation_by_name("Mean"); + Tensor cluster_idx = graph.get_operation_by_name("Squeeze_1"); + + with(tf.Session(graph), sess => + { + sess.run(init_vars, new FeedItem(X, full_data_x)); + sess.run(init_op, new FeedItem(X, full_data_x)); + + // Training + NDArray result = null; + foreach(var i in range(1, num_steps + 1)) + { + result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x)); + if (i % 2 == 0 || i == 1) + print($"Step {i}, Avg Distance: {result[1]}"); + } + + var idx = result[2]; + + // Assign a label to each centroid + // Count total number of labels per centroid, using the label of each training + // sample to their closest centroid (given by 'idx') + var counts = np.zeros(k, num_classes); + foreach (var i in range(idx.len)) + counts[idx[i]] += mnist.train.labels[i]; + + }); + return false; } diff --git a/test/TensorFlowNET.Examples/ObjectDetection.cs b/test/TensorFlowNET.Examples/ObjectDetection.cs index 8c96c45a..5e76d4b7 100644 --- a/test/TensorFlowNET.Examples/ObjectDetection.cs +++ b/test/TensorFlowNET.Examples/ObjectDetection.cs @@ -50,13 +50,9 @@ namespace TensorFlowNET.Examples with(tf.Session(graph), sess => { var results = sess.run(outTensorArr, new FeedItem(imgTensor, imgArr)); - - //NDArray scores = results.Array.GetValue(2) as NDArray; - - //floatscores.Data(); + NDArray[] resultArr = results.Data(); - - //float[] scores = resultArr[2].Data(); + buildOutputImage(resultArr); }); diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 28abce32..729a129d 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow; using Buffer = Tensorflow.Buffer; @@ -20,6 +21,13 @@ namespace TensorFlowNET.UnitTest var handle = c_api.TF_GetAllOpList(); var buffer = new Buffer(handle); var op_list = OpList.Parser.ParseFrom(buffer); + + var _registered_ops = new Dictionary(); + foreach (var op_def in op_list.Op) + _registered_ops[op_def.Name] = op_def; + + // r1.14 added NN op + var op = _registered_ops.FirstOrDefault(x => x.Key == "NearestNeighbors"); Assert.IsTrue(op_list.Op.Count > 1000); }