| @@ -73,6 +73,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow | |||||
| * [Image Recognition](test/TensorFlowNET.Examples/ImageRecognition.cs) | * [Image Recognition](test/TensorFlowNET.Examples/ImageRecognition.cs) | ||||
| * [Linear Regression](test/TensorFlowNET.Examples/LinearRegression.cs) | * [Linear Regression](test/TensorFlowNET.Examples/LinearRegression.cs) | ||||
| * [Logistic Regression](test/TensorFlowNET.Examples/LogisticRegression.cs) | * [Logistic Regression](test/TensorFlowNET.Examples/LogisticRegression.cs) | ||||
| * [Nearest Neighbor](test/TensorFlowNET.Examples/NearestNeighbor.cs) | |||||
| * [Text Classification](test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs) | * [Text Classification](test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs) | ||||
| * [CNN Text Classification](test/TensorFlowNET.Examples/CnnTextClassification.cs) | * [CNN Text Classification](test/TensorFlowNET.Examples/CnnTextClassification.cs) | ||||
| * [Naive Bayes Classification](test/TensorFlowNET.Examples/NaiveBayesClassifier.cs) | * [Naive Bayes Classification](test/TensorFlowNET.Examples/NaiveBayesClassifier.cs) | ||||
| @@ -0,0 +1,3 @@ | |||||
| # Chapter. Nearest Neighbor | |||||
| The nearest neighbour algorithm was one of the first algorithms used to solve the travelling salesman problem. In it, the salesman starts at a random city and repeatedly visits the nearest city until all have been visited. It quickly yields a short tour, but usually not the optimal one. | |||||
| @@ -27,4 +27,5 @@ Welcome to TensorFlow.NET's documentation! | |||||
| EagerMode | EagerMode | ||||
| LinearRegression | LinearRegression | ||||
| LogisticRegression | LogisticRegression | ||||
| NearestNeighbor | |||||
| ImageRecognition | ImageRecognition | ||||
| @@ -6,9 +6,18 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static partial class tf | public static partial class tf | ||||
| { | { | ||||
| public static Tensor abs(Tensor x, string name = null) | |||||
| => math_ops.abs(x, name); | |||||
| public static Tensor add(Tensor a, Tensor b) | public static Tensor add(Tensor a, Tensor b) | ||||
| => gen_math_ops.add(a, b); | => gen_math_ops.add(a, b); | ||||
| public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) | |||||
| => gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name); | |||||
| public static Tensor arg_min(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) | |||||
| => gen_math_ops.arg_min(input, dimension, output_type: output_type, name: name); | |||||
| public static Tensor sub(Tensor a, Tensor b) | public static Tensor sub(Tensor a, Tensor b) | ||||
| => gen_math_ops.sub(a, b); | => gen_math_ops.sub(a, b); | ||||
| @@ -27,6 +36,9 @@ namespace Tensorflow | |||||
| public static Tensor multiply(Tensor x, Tensor y) | public static Tensor multiply(Tensor x, Tensor y) | ||||
| => gen_math_ops.mul(x, y); | => gen_math_ops.mul(x, y); | ||||
| public static Tensor negative(Tensor x, string name = null) | |||||
| => gen_math_ops.neg(x, name); | |||||
| public static Tensor divide<T>(Tensor x, T[] y, string name = null) where T : struct | public static Tensor divide<T>(Tensor x, T[] y, string name = null) where T : struct | ||||
| => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); | => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); | ||||
| @@ -355,7 +355,7 @@ namespace Tensorflow | |||||
| return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); | return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); | ||||
| } | } | ||||
| public object get_collection(string name, string scope = "") | |||||
| public object get_collection(string name, string scope = null) | |||||
| { | { | ||||
| return _collections.ContainsKey(name) ? _collections[name] : null; | return _collections.ContainsKey(name) ? _collections[name] : null; | ||||
| } | } | ||||
| @@ -9,6 +9,30 @@ namespace Tensorflow | |||||
| public static class gen_math_ops | public static class gen_math_ops | ||||
| { | { | ||||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
| /// <summary> | |||||
| /// Returns the index with the largest value across dimensions of a tensor. | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="dimension"></param> | |||||
| /// <param name="output_type"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) | |||||
| => _op_def_lib._apply_op_helper("ArgMax", name, args: new { input, dimension, output_type }).outputs[0]; | |||||
| /// <summary> | |||||
| /// Returns the index with the smallest value across dimensions of a tensor. | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="dimension"></param> | |||||
| /// <param name="output_type"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor arg_min(Tensor input, int dimension, TF_DataType output_type= TF_DataType.TF_INT64, string name= null) | |||||
| =>_op_def_lib._apply_op_helper("ArgMin", name, args: new { input, dimension, output_type }).outputs[0]; | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes the mean of elements across dimensions of a tensor. | /// Computes the mean of elements across dimensions of a tensor. | ||||
| /// Reduces `input` along the dimensions given in `axis`. Unless /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in /// `axis`. If `keep_dims` is true, the reduced dimensions are retained with length 1. | /// Reduces `input` along the dimensions given in `axis`. Unless /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in /// `axis`. If `keep_dims` is true, the reduced dimensions are retained with length 1. | ||||
| @@ -207,6 +231,13 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor _abs(Tensor x, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Abs", name, new { x }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor _max<Tx, Ty>(Tx input, Ty axis, bool keep_dims=false, string name = null) | public static Tensor _max<Tx, Ty>(Tx input, Ty axis, bool keep_dims=false, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }); | var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }); | ||||
| @@ -249,20 +280,5 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Returns the index with the largest value across dimensions of a tensor. | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="dimension"></param> | |||||
| /// <param name="output_type"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("ArgMax", name, new { input, dimension, output_type }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -6,9 +6,36 @@ using Tensorflow.Framework; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /// <summary> | |||||
| /// python\ops\math_ops.py | |||||
| /// </summary> | |||||
| public class math_ops : Python | public class math_ops : Python | ||||
| { | { | ||||
| public static Tensor add(Tensor x, Tensor y, string name = null) => gen_math_ops.add(x, y, name); | |||||
| public static Tensor abs(Tensor x, string name = null) | |||||
| { | |||||
| return with(ops.name_scope(name, "Abs", new { x }), scope => | |||||
| { | |||||
| x = ops.convert_to_tensor(x, name: "x"); | |||||
| if (x.dtype.is_complex()) | |||||
| throw new NotImplementedException("math_ops.abs for dtype.is_complex"); | |||||
| //return gen_math_ops.complex_abs(x, Tout: x.dtype.real_dtype, name: name); | |||||
| return gen_math_ops._abs(x, name: name); | |||||
| }); | |||||
| } | |||||
| public static Tensor add(Tensor x, Tensor y, string name = null) | |||||
| => gen_math_ops.add(x, y, name); | |||||
| public static Tensor add(Tensor x, string name = null) | |||||
| { | |||||
| return with(ops.name_scope(name, "Abs", new { x }), scope => | |||||
| { | |||||
| name = scope; | |||||
| x = ops.convert_to_tensor(x, name: "x"); | |||||
| return gen_math_ops._abs(x, name: name); | |||||
| }); | |||||
| } | |||||
| public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) | public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) | ||||
| { | { | ||||
| @@ -222,6 +222,12 @@ namespace Tensorflow | |||||
| ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); | ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); | ||||
| nd = np.array(ints).reshape(ndims); | nd = np.array(ints).reshape(ndims); | ||||
| break; | break; | ||||
| case TF_DataType.TF_INT64: | |||||
| var longs = new long[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(longs).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | case TF_DataType.TF_FLOAT: | ||||
| var floats = new float[tensor.size]; | var floats = new float[tensor.size]; | ||||
| for (ulong i = 0; i < tensor.size; i++) | for (ulong i = 0; i < tensor.size; i++) | ||||
| @@ -65,6 +65,9 @@ namespace Tensorflow | |||||
| case "Int32": | case "Int32": | ||||
| full_values.Add(value.Data<int>(0)); | full_values.Add(value.Data<int>(0)); | ||||
| break; | break; | ||||
| case "Int64": | |||||
| full_values.Add(value.Data<long>(0)); | |||||
| break; | |||||
| case "Single": | case "Single": | ||||
| full_values.Add(value.Data<float>(0)); | full_values.Add(value.Data<float>(0)); | ||||
| break; | break; | ||||
| @@ -78,7 +81,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| full_values.Add(value[np.arange(1)]); | |||||
| full_values.Add(value[np.arange(0, value.shape[0])]); | |||||
| } | } | ||||
| } | } | ||||
| i += 1; | i += 1; | ||||
| @@ -4,7 +4,7 @@ | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <Version>0.5.0</Version> | |||||
| <Version>0.5.1</Version> | |||||
| <Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| @@ -16,11 +16,13 @@ | |||||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | ||||
| <Description>Google's TensorFlow binding in .NET Standard. | <Description>Google's TensorFlow binding in .NET Standard. | ||||
| Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
| <AssemblyVersion>0.5.0.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Add Logistic Regression to do MNIST. | |||||
| Add a lot of APIs to build neural networks model</PackageReleaseNotes> | |||||
| <AssemblyVersion>0.5.1.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Changes since v0.5: | |||||
| Added Nearest Neighbor. | |||||
| Add a lot of APIs to build neural networks model. | |||||
| Bug fix.</PackageReleaseNotes> | |||||
| <LangVersion>7.2</LangVersion> | <LangVersion>7.2</LangVersion> | ||||
| <FileVersion>0.5.0.0</FileVersion> | |||||
| <FileVersion>0.5.1.0</FileVersion> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
| @@ -44,7 +46,7 @@ Add a lot of APIs to build neural networks model</PackageReleaseNotes> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Google.Protobuf" Version="3.7.0" /> | <PackageReference Include="Google.Protobuf" Version="3.7.0" /> | ||||
| <PackageReference Include="NumSharp" Version="0.8.1" /> | |||||
| <PackageReference Include="NumSharp" Version="0.8.2" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -47,11 +47,11 @@ namespace Tensorflow | |||||
| /// special tokens filters by prefix. | /// special tokens filters by prefix. | ||||
| /// </param> | /// </param> | ||||
| /// <returns>A list of `Variable` objects.</returns> | /// <returns>A list of `Variable` objects.</returns> | ||||
| public static List<RefVariable> global_variables(string scope = "") | |||||
| public static List<RefVariable> global_variables(string scope = null) | |||||
| { | { | ||||
| var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | ||||
| return result as List<RefVariable>; | |||||
| return result == null ? new List<RefVariable>() : result as List<RefVariable>; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -62,7 +62,10 @@ namespace Tensorflow | |||||
| /// <returns>An Op that run the initializers of all the specified variables.</returns> | /// <returns>An Op that run the initializers of all the specified variables.</returns> | ||||
| public static Operation variables_initializer(RefVariable[] var_list, string name = "init") | public static Operation variables_initializer(RefVariable[] var_list, string name = "init") | ||||
| { | { | ||||
| return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray(), name); | |||||
| if (var_list.Length > 0) | |||||
| return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray(), name); | |||||
| else | |||||
| return gen_control_flow_ops.no_op(name: name); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -41,7 +41,7 @@ namespace Tensorflow | |||||
| /// list contains the values in the order under which they were | /// list contains the values in the order under which they were | ||||
| /// collected. | /// collected. | ||||
| /// </returns> | /// </returns> | ||||
| public static object get_collection(string key, string scope = "") | |||||
| public static object get_collection(string key, string scope = null) | |||||
| { | { | ||||
| return get_default_graph().get_collection(key, scope); | return get_default_graph().get_collection(key, scope); | ||||
| } | } | ||||
| @@ -12,7 +12,7 @@ namespace TensorFlowNET.Examples | |||||
| { | { | ||||
| public class ImageRecognition : Python, IExample | public class ImageRecognition : Python, IExample | ||||
| { | { | ||||
| public int Priority => 5; | |||||
| public int Priority => 6; | |||||
| public bool Enabled => true; | public bool Enabled => true; | ||||
| public string Name => "Image Recognition"; | public string Name => "Image Recognition"; | ||||
| @@ -98,7 +98,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); | |||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); | |||||
| } | } | ||||
| public void SaveModel(Session sess) | public void SaveModel(Session sess) | ||||
| @@ -0,0 +1,70 @@ | |||||
| using NumSharp.Core; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| using TensorFlowNET.Examples.Utility; | |||||
| namespace TensorFlowNET.Examples | |||||
| { | |||||
| /// <summary> | |||||
| /// A nearest neighbor learning algorithm example | |||||
| /// This example is using the MNIST database of handwritten digits | |||||
| /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/nearest_neighbor.py | |||||
| /// </summary> | |||||
| public class NearestNeighbor : Python, IExample | |||||
| { | |||||
| public int Priority => 5; | |||||
| public bool Enabled => true; | |||||
| public string Name => "Nearest Neighbor"; | |||||
| Datasets mnist; | |||||
| NDArray Xtr, Ytr, Xte, Yte; | |||||
| public bool Run() | |||||
| { | |||||
| // tf Graph Input | |||||
| var xtr = tf.placeholder(tf.float32, new TensorShape(-1, 784)); | |||||
| var xte = tf.placeholder(tf.float32, new TensorShape(784)); | |||||
| // Nearest Neighbor calculation using L1 Distance | |||||
| // Calculate L1 Distance | |||||
| var distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices: 1); | |||||
| // Prediction: Get min distance index (Nearest neighbor) | |||||
| var pred = tf.arg_min(distance, 0); | |||||
| float accuracy = 0f; | |||||
| // Initialize the variables (i.e. assign their default value) | |||||
| var init = tf.global_variables_initializer(); | |||||
| with(tf.Session(), sess => | |||||
| { | |||||
| // Run the initializer | |||||
| sess.run(init); | |||||
| PrepareData(); | |||||
| foreach(int i in range(Xte.shape[0])) | |||||
| { | |||||
| // Get nearest neighbor | |||||
| long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i])); | |||||
| // Get nearest neighbor class label and compare it to its true label | |||||
| print($"Test {i} Prediction: {np.argmax(Ytr[nn_index])} True Class: {np.argmax(Yte[i] as NDArray)}"); | |||||
| // Calculate accuracy | |||||
| if (np.argmax(Ytr[nn_index]) == np.argmax(Yte[i] as NDArray)) | |||||
| accuracy += 1f/ Xte.shape[0]; | |||||
| } | |||||
| print($"Accuracy: {accuracy}"); | |||||
| }); | |||||
| return accuracy > 0.9; | |||||
| } | |||||
| public void PrepareData() | |||||
| { | |||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); | |||||
| // In this example, we limit mnist data | |||||
| (Xtr, Ytr) = mnist.train.next_batch(5000); // 5000 for training (nn candidates) | |||||
| (Xte, Yte) = mnist.test.next_batch(200); // 200 for testing | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -32,11 +32,11 @@ namespace TensorFlowNET.Examples | |||||
| { | { | ||||
| if (example.Enabled) | if (example.Enabled) | ||||
| if (example.Run()) | if (example.Run()) | ||||
| success.Add($"{example.Priority} {example.Name}"); | |||||
| success.Add($"Example {example.Priority}: {example.Name}"); | |||||
| else | else | ||||
| errors.Add($"{example.Priority} {example.Name}"); | |||||
| errors.Add($"Example {example.Priority}: {example.Name}"); | |||||
| else | else | ||||
| disabled.Add($"{example.Priority} {example.Name}"); | |||||
| disabled.Add($"Example {example.Priority}: {example.Name}"); | |||||
| } | } | ||||
| catch (Exception ex) | catch (Exception ex) | ||||
| { | { | ||||
| @@ -46,9 +46,9 @@ namespace TensorFlowNET.Examples | |||||
| Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); | Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); | ||||
| } | } | ||||
| success.ForEach(x => Console.WriteLine($"{x} example is OK!", Color.Green)); | |||||
| disabled.ForEach(x => Console.WriteLine($"{x} example is Disabled!", Color.Tan)); | |||||
| errors.ForEach(x => Console.WriteLine($"{x} example is Failed!", Color.Red)); | |||||
| success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green)); | |||||
| disabled.ForEach(x => Console.WriteLine($"{x} is Disabled!", Color.Tan)); | |||||
| errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red)); | |||||
| Console.ReadLine(); | Console.ReadLine(); | ||||
| } | } | ||||
| @@ -8,7 +8,7 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Colorful.Console" Version="1.2.9" /> | <PackageReference Include="Colorful.Console" Version="1.2.9" /> | ||||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.1" /> | <PackageReference Include="Newtonsoft.Json" Version="12.0.1" /> | ||||
| <PackageReference Include="NumSharp" Version="0.8.1" /> | |||||
| <PackageReference Include="NumSharp" Version="0.8.2" /> | |||||
| <PackageReference Include="SharpZipLib" Version="1.1.0" /> | <PackageReference Include="SharpZipLib" Version="1.1.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -11,7 +11,7 @@ namespace TensorFlowNET.Examples | |||||
| { | { | ||||
| public class TextClassificationWithMovieReviews : Python, IExample | public class TextClassificationWithMovieReviews : Python, IExample | ||||
| { | { | ||||
| public int Priority => 6; | |||||
| public int Priority => 7; | |||||
| public bool Enabled => false; | public bool Enabled => false; | ||||
| public string Name => "Movie Reviews"; | public string Name => "Movie Reviews"; | ||||
| @@ -19,8 +19,8 @@ | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" /> | ||||
| <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
| <PackageReference Include="NumSharp" Version="0.8.1" /> | |||||
| <PackageReference Include="TensorFlow.NET" Version="0.4.2" /> | |||||
| <PackageReference Include="NumSharp" Version="0.8.2" /> | |||||
| <PackageReference Include="TensorFlow.NET" Version="0.5.0" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||