| @@ -1,5 +1,5 @@ | |||||
| # TensorFlow.NET | # TensorFlow.NET | ||||
| TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||||
| TensorFlow.NET (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||||
| [](https://gitter.im/sci-sharp/community) | [](https://gitter.im/sci-sharp/community) | ||||
| [](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) | [](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) | ||||
| @@ -8,7 +8,7 @@ TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.ten | |||||
| [](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) | [](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) | ||||
| [](https://996.icu/#/en_US) | [](https://996.icu/#/en_US) | ||||
| TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | |||||
| TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | |||||
|  |  | ||||
| @@ -24,14 +24,14 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr | |||||
| ### How to use | ### How to use | ||||
| Install TensorFlow.NET through NuGet. | |||||
| Install TF.NET through NuGet. | |||||
| ```sh | ```sh | ||||
| PM> Install-Package TensorFlow.NET | PM> Install-Package TensorFlow.NET | ||||
| ``` | ``` | ||||
| If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflowlib) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows. | If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflowlib) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows. | ||||
| Import tensorflow.net. | |||||
| Import TF.NET. | |||||
| ```cs | ```cs | ||||
| using Tensorflow; | using Tensorflow; | ||||
| @@ -10,9 +10,9 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class _ElementFetchMapper : _FetchMapper | public class _ElementFetchMapper : _FetchMapper | ||||
| { | { | ||||
| private Func<List<object>, object> _contraction_fn; | |||||
| private Func<List<NDArray>, object> _contraction_fn; | |||||
| public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | |||||
| public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn) | |||||
| { | { | ||||
| var g = ops.get_default_graph(); | var g = ops.get_default_graph(); | ||||
| ITensorOrOperation el = null; | ITensorOrOperation el = null; | ||||
| @@ -31,7 +31,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="values"></param> | /// <param name="values"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public override NDArray build_results(List<object> values) | |||||
| public override NDArray build_results(List<NDArray> values) | |||||
| { | { | ||||
| NDArray result = null; | NDArray result = null; | ||||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||||
| public NDArray build_results(BaseSession session, NDArray[] tensor_values) | public NDArray build_results(BaseSession session, NDArray[] tensor_values) | ||||
| { | { | ||||
| var full_values = new List<object>(); | |||||
| var full_values = new List<NDArray>(); | |||||
| if (_final_fetches.Count != tensor_values.Length) | if (_final_fetches.Count != tensor_values.Length) | ||||
| throw new InvalidOperationException("_final_fetches mismatch tensor_values"); | throw new InvalidOperationException("_final_fetches mismatch tensor_values"); | ||||
| @@ -17,21 +17,14 @@ namespace Tensorflow | |||||
| if (fetch.GetType().IsArray) | if (fetch.GetType().IsArray) | ||||
| return new _ListFetchMapper(fetches); | return new _ListFetchMapper(fetches); | ||||
| else | else | ||||
| return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => fetched_vals[0]); | |||||
| return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0]); | |||||
| } | } | ||||
| public virtual NDArray build_results(List<object> values) | |||||
| public virtual NDArray build_results(List<NDArray> values) | |||||
| { | { | ||||
| var type = values[0].GetType(); | var type = values[0].GetType(); | ||||
| var nd = new NDArray(type, values.Count); | var nd = new NDArray(type, values.Count); | ||||
| switch (type.Name) | |||||
| { | |||||
| case "Single": | |||||
| nd.SetData(values.Select(x => (float)x).ToArray()); | |||||
| break; | |||||
| } | |||||
| nd.SetData(values.ToArray()); | |||||
| return nd; | return nd; | ||||
| } | } | ||||
| @@ -62,10 +62,23 @@ namespace TensorFlowNET.Examples | |||||
| sess.run(init_op, new FeedItem(X, full_data_x)); | sess.run(init_op, new FeedItem(X, full_data_x)); | ||||
| // Training | // Training | ||||
| NDArray result = null; | |||||
| foreach(var i in range(1, num_steps + 1)) | foreach(var i in range(1, num_steps + 1)) | ||||
| { | { | ||||
| var result = sess.run(new Tensor[] { avg_distance, cluster_idx }, new FeedItem(X, full_data_x)); | |||||
| result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x)); | |||||
| if (i % 2 == 0 || i == 1) | |||||
| print($"Step {i}, Avg Distance: {result[1]}"); | |||||
| } | } | ||||
| var idx = result[2]; | |||||
| // Assign a label to each centroid | |||||
| // Count total number of labels per centroid, using the label of each training | |||||
| // sample to their closest centroid (given by 'idx') | |||||
| var counts = np.zeros(k, num_classes); | |||||
| foreach (var i in range(idx.len)) | |||||
| counts[idx[i]] += mnist.train.labels[i]; | |||||
| }); | }); | ||||
| return false; | return false; | ||||