| @@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T | |||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | ||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}" | |||||
| EndProject | |||||
| Global | Global | ||||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
| Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
| @@ -27,6 +29,10 @@ Global | |||||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
| {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| EndGlobalSection | EndGlobalSection | ||||
| GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
| HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
| @@ -40,30 +40,22 @@ namespace Tensorflow | |||||
| } | } | ||||
| public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| public virtual object run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||||
| { | { | ||||
| var result = _run(fetches, feed_dict); | var result = _run(fetches, feed_dict); | ||||
| return result; | return result; | ||||
| } | } | ||||
| private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| private unsafe object _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||||
| { | { | ||||
| var feed_dict_tensor = new Dictionary<Tensor, object>(); | |||||
| var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | |||||
| if (feed_dict != null) | if (feed_dict != null) | ||||
| { | { | ||||
| NDArray np_val = null; | |||||
| foreach (var feed in feed_dict) | foreach (var feed in feed_dict) | ||||
| { | { | ||||
| switch (feed.Value) | |||||
| { | |||||
| case float value: | |||||
| np_val = np.asarray(value); | |||||
| break; | |||||
| } | |||||
| feed_dict_tensor[feed.Key] = np_val; | |||||
| feed_dict_tensor[feed.Key] = feed.Value; | |||||
| } | } | ||||
| } | } | ||||
| @@ -85,9 +77,9 @@ namespace Tensorflow | |||||
| return fetch_handler.build_results(null, results); | return fetch_handler.build_results(null, results); | ||||
| } | } | ||||
| private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, object> feed_dict) | |||||
| private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||||
| { | { | ||||
| var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); | |||||
| var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | ||||
| return _call_tf_sessionrun(feeds, fetches); | return _call_tf_sessionrun(feeds, fetches); | ||||
| @@ -133,12 +125,14 @@ namespace Tensorflow | |||||
| case TF_DataType.TF_FLOAT: | case TF_DataType.TF_FLOAT: | ||||
| result[i] = *(float*)c_api.TF_TensorData(output_values[i]); | result[i] = *(float*)c_api.TF_TensorData(output_values[i]); | ||||
| break; | break; | ||||
| case TF_DataType.TF_INT16: | |||||
| result[i] = *(short*)c_api.TF_TensorData(output_values[i]); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | case TF_DataType.TF_INT32: | ||||
| result[i] = *(int*)c_api.TF_TensorData(output_values[i]); | result[i] = *(int*)c_api.TF_TensorData(output_values[i]); | ||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotImplementedException("can't get output"); | throw new NotImplementedException("can't get output"); | ||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using NumSharp.Core; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -15,7 +16,7 @@ namespace Tensorflow | |||||
| private List<Tensor> _final_fetches = new List<Tensor>(); | private List<Tensor> _final_fetches = new List<Tensor>(); | ||||
| private List<object> _targets = new List<object>(); | private List<object> _targets = new List<object>(); | ||||
| public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, object> feeds = null, object feed_handles = null) | |||||
| public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||||
| { | { | ||||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | _fetch_mapper = new _FetchMapper().for_fetch(fetches); | ||||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
| @@ -43,4 +43,8 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
| <Content CopyToOutputDirectory="PreserveNewest" Include="./runtimes/win-x64/native/tensorflow.dll" Link="tensorflow.dll" Pack="true" PackagePath="runtimes/win-x64/native/tensorflow.dll" /> | <Content CopyToOutputDirectory="PreserveNewest" Include="./runtimes/win-x64/native/tensorflow.dll" Link="tensorflow.dll" Pack="true" PackagePath="runtimes/win-x64/native/tensorflow.dll" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | |||||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
| </ItemGroup> | |||||
| </Project> | </Project> | ||||
| @@ -74,6 +74,9 @@ namespace Tensorflow | |||||
| switch (nd.dtype.Name) | switch (nd.dtype.Name) | ||||
| { | { | ||||
| case "Int16": | |||||
| Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Int32": | case "Int32": | ||||
| Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | ||||
| break; | break; | ||||
| @@ -161,6 +164,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| switch (type.Name) | switch (type.Name) | ||||
| { | { | ||||
| case "Int16": | |||||
| return TF_DataType.TF_INT16; | |||||
| case "Int32": | case "Int32": | ||||
| return TF_DataType.TF_INT32; | return TF_DataType.TF_INT32; | ||||
| case "Single": | case "Single": | ||||
| @@ -169,9 +174,9 @@ namespace Tensorflow | |||||
| return TF_DataType.TF_DOUBLE; | return TF_DataType.TF_DOUBLE; | ||||
| case "String": | case "String": | ||||
| return TF_DataType.TF_STRING; | return TF_DataType.TF_STRING; | ||||
| default: | |||||
| throw new NotImplementedException("ToTFDataType error"); | |||||
| } | } | ||||
| return TF_DataType.DtInvalid; | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using NumSharp.Core; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| @@ -43,11 +44,24 @@ namespace TensorFlowNET.Examples | |||||
| // Launch the default graph. | // Launch the default graph. | ||||
| using(sess = tf.Session()) | using(sess = tf.Session()) | ||||
| { | { | ||||
| // var feed_dict = new Dictionary<string, > | |||||
| var feed_dict = new Dictionary<Tensor, NDArray>(); | |||||
| feed_dict.Add(a, (short)2); | |||||
| feed_dict.Add(b, (short)3); | |||||
| // Run every operation with variable input | // Run every operation with variable input | ||||
| // Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict: {a: 2, b: 3})}"); | |||||
| // Console.WriteLine($"Multiplication with variables: {}"); | |||||
| Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}"); | |||||
| Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}"); | |||||
| } | } | ||||
| // ---------------- | |||||
| // More in details: | |||||
| // Matrix Multiplication from TensorFlow official tutorial | |||||
| // Create a Constant op that produces a 1x2 matrix. The op is | |||||
| // added as a node to the default graph. | |||||
| // | |||||
| // The value returned by the constructor represents the output | |||||
| // of the Constant op. | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,6 +10,7 @@ | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using NumSharp.Core; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -31,7 +32,7 @@ namespace TensorFlowNET.UnitTest | |||||
| using(var sess = tf.Session()) | using(var sess = tf.Session()) | ||||
| { | { | ||||
| var feed_dict = new Dictionary<Tensor, object>(); | |||||
| var feed_dict = new Dictionary<Tensor, NDArray>(); | |||||
| feed_dict.Add(a, 3.0f); | feed_dict.Add(a, 3.0f); | ||||
| feed_dict.Add(b, 2.0f); | feed_dict.Add(b, 2.0f); | ||||