| @@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}" | |||
| EndProject | |||
| Global | |||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
| Debug|Any CPU = Debug|Any CPU | |||
| @@ -27,6 +29,10 @@ Global | |||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| EndGlobalSection | |||
| GlobalSection(SolutionProperties) = preSolution | |||
| HideSolutionNode = FALSE | |||
| @@ -40,30 +40,22 @@ namespace Tensorflow | |||
| } | |||
| public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||
| public virtual object run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||
| { | |||
| var result = _run(fetches, feed_dict); | |||
| return result; | |||
| } | |||
| private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||
| private unsafe object _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||
| { | |||
| var feed_dict_tensor = new Dictionary<Tensor, object>(); | |||
| var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | |||
| if (feed_dict != null) | |||
| { | |||
| NDArray np_val = null; | |||
| foreach (var feed in feed_dict) | |||
| { | |||
| switch (feed.Value) | |||
| { | |||
| case float value: | |||
| np_val = np.asarray(value); | |||
| break; | |||
| } | |||
| feed_dict_tensor[feed.Key] = np_val; | |||
| feed_dict_tensor[feed.Key] = feed.Value; | |||
| } | |||
| } | |||
| @@ -85,9 +77,9 @@ namespace Tensorflow | |||
| return fetch_handler.build_results(null, results); | |||
| } | |||
| private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, object> feed_dict) | |||
| private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||
| { | |||
| var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); | |||
| var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); | |||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
| return _call_tf_sessionrun(feeds, fetches); | |||
| @@ -133,12 +125,14 @@ namespace Tensorflow | |||
| case TF_DataType.TF_FLOAT: | |||
| result[i] = *(float*)c_api.TF_TensorData(output_values[i]); | |||
| break; | |||
| case TF_DataType.TF_INT16: | |||
| result[i] = *(short*)c_api.TF_TensorData(output_values[i]); | |||
| break; | |||
| case TF_DataType.TF_INT32: | |||
| result[i] = *(int*)c_api.TF_TensorData(output_values[i]); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("can't get output"); | |||
| break; | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using NumSharp.Core; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| @@ -15,7 +16,7 @@ namespace Tensorflow | |||
| private List<Tensor> _final_fetches = new List<Tensor>(); | |||
| private List<object> _targets = new List<object>(); | |||
| public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, object> feeds = null, object feed_handles = null) | |||
| public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||
| { | |||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | |||
| @@ -43,4 +43,8 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||
| <Content CopyToOutputDirectory="PreserveNewest" Include="./runtimes/win-x64/native/tensorflow.dll" Link="tensorflow.dll" Pack="true" PackagePath="runtimes/win-x64/native/tensorflow.dll" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -74,6 +74,9 @@ namespace Tensorflow | |||
| switch (nd.dtype.Name) | |||
| { | |||
| case "Int16": | |||
| Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size); | |||
| break; | |||
| case "Int32": | |||
| Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | |||
| break; | |||
| @@ -161,6 +164,8 @@ namespace Tensorflow | |||
| { | |||
| switch (type.Name) | |||
| { | |||
| case "Int16": | |||
| return TF_DataType.TF_INT16; | |||
| case "Int32": | |||
| return TF_DataType.TF_INT32; | |||
| case "Single": | |||
| @@ -169,9 +174,9 @@ namespace Tensorflow | |||
| return TF_DataType.TF_DOUBLE; | |||
| case "String": | |||
| return TF_DataType.TF_STRING; | |||
| default: | |||
| throw new NotImplementedException("ToTFDataType error"); | |||
| } | |||
| return TF_DataType.DtInvalid; | |||
| } | |||
| public void Dispose() | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using NumSharp.Core; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| @@ -43,11 +44,24 @@ namespace TensorFlowNET.Examples | |||
| // Launch the default graph. | |||
| using(sess = tf.Session()) | |||
| { | |||
| // var feed_dict = new Dictionary<string, > | |||
| var feed_dict = new Dictionary<Tensor, NDArray>(); | |||
| feed_dict.Add(a, (short)2); | |||
| feed_dict.Add(b, (short)3); | |||
| // Run every operation with variable input | |||
| // Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict: {a: 2, b: 3})}"); | |||
| // Console.WriteLine($"Multiplication with variables: {}"); | |||
| Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}"); | |||
| Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}"); | |||
| } | |||
| // ---------------- | |||
| // More in details: | |||
| // Matrix Multiplication from TensorFlow official tutorial | |||
| // Create a Constant op that produces a 1x2 matrix. The op is | |||
| // added as a node to the default graph. | |||
| // | |||
| // The value returned by the constructor represents the output | |||
| // of the Constant op. | |||
| } | |||
| } | |||
| } | |||
| @@ -10,6 +10,7 @@ | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | |||
| </ItemGroup> | |||
| @@ -1,4 +1,5 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using NumSharp.Core; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| @@ -31,7 +32,7 @@ namespace TensorFlowNET.UnitTest | |||
| using(var sess = tf.Session()) | |||
| { | |||
| var feed_dict = new Dictionary<Tensor, object>(); | |||
| var feed_dict = new Dictionary<Tensor, NDArray>(); | |||
| feed_dict.Add(a, 3.0f); | |||
| feed_dict.Add(b, 2.0f); | |||