| @@ -9,8 +9,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T | |||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | ||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{DA680126-DA60-4CE3-9094-72C355C081D3}" | |||||
| EndProject | |||||
| Global | Global | ||||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
| Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
| @@ -29,10 +27,6 @@ Global | |||||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
| {DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| EndGlobalSection | EndGlobalSection | ||||
| GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
| HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
| @@ -31,6 +31,25 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | ||||
| /// <summary> | |||||
| /// Iterate through the operations of a graph. | |||||
| /// </summary> | |||||
| /// <param name="graph"></param> | |||||
| /// <param name="pos"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TF_GraphNextOperation(IntPtr graph, ref uint pos); | |||||
| /// <summary> | |||||
| /// Returns the operation in the graph with `oper_name`. Returns nullptr if | |||||
| /// no operation found. | |||||
| /// </summary> | |||||
| /// <param name="graph"></param> | |||||
| /// <param name="oper_name"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TF_GraphOperationByName(IntPtr graph, string oper_name); | |||||
| /// <summary> | /// <summary> | ||||
| /// Sets the shape of the Tensor referenced by `output` in `graph` to | /// Sets the shape of the Tensor referenced by `output` in `graph` to | ||||
| /// the shape described by `dims` and `num_dims`. | /// the shape described by `dims` and `num_dims`. | ||||
| @@ -116,6 +116,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| case IntPtr val: | case IntPtr val: | ||||
| return val == _handle; | return val == _handle; | ||||
| case Operation val: | |||||
| return val._handle == _handle; | |||||
| } | } | ||||
| return base.Equals(obj); | return base.Equals(obj); | ||||
| @@ -0,0 +1,26 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class OperationDescription | |||||
| { | |||||
| private IntPtr _handle; | |||||
| public OperationDescription(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| } | |||||
| public static implicit operator OperationDescription(IntPtr handle) | |||||
| { | |||||
| return new OperationDescription(handle); | |||||
| } | |||||
| public static implicit operator IntPtr(OperationDescription desc) | |||||
| { | |||||
| return desc._handle; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -9,6 +9,7 @@ namespace Tensorflow | |||||
| public struct TF_OperationDescription | public struct TF_OperationDescription | ||||
| { | { | ||||
| public IntPtr node_builder; | public IntPtr node_builder; | ||||
| //public TF_Graph graph; | |||||
| public IntPtr graph; | |||||
| public IntPtr colocation_constraints; | |||||
| } | } | ||||
| } | } | ||||
| @@ -33,11 +33,7 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Google.Protobuf" Version="3.6.1" /> | <PackageReference Include="Google.Protobuf" Version="3.6.1" /> | ||||
| <PackageReference Include="NumSharp" Version="0.6.2" /> | |||||
| </ItemGroup> | |||||
| <ItemGroup> | |||||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
| <PackageReference Include="NumSharp" Version="0.6.3" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -22,6 +22,7 @@ namespace Tensorflow | |||||
| /// int32_t => int | /// int32_t => int | ||||
| /// int64_t* => long[] | /// int64_t* => long[] | ||||
| /// size_t* => unlong[] | /// size_t* => unlong[] | ||||
| /// size_t* => ref uint | |||||
| /// void* => IntPtr | /// void* => IntPtr | ||||
| /// string => IntPtr c_api.StringPiece(IntPtr) | /// string => IntPtr c_api.StringPiece(IntPtr) | ||||
| /// </summary> | /// </summary> | ||||
| @@ -6,11 +6,10 @@ | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="NumSharp" Version="0.6.2" /> | |||||
| <PackageReference Include="NumSharp" Version="0.6.3" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -104,21 +104,98 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.IsFalse(found_placeholder); | Assert.IsFalse(found_placeholder); | ||||
| found_placeholder = true; | found_placeholder = true; | ||||
| } | } | ||||
| /*else if (IsScalarConst(n, 3)) | |||||
| else if (c_test_util.IsScalarConst(n, 3)) | |||||
| { | { | ||||
| Assert.IsFalse(found_scalar_const); | Assert.IsFalse(found_scalar_const); | ||||
| found_scalar_const = true; | found_scalar_const = true; | ||||
| } | } | ||||
| else if (IsAddN(n, 2)) | |||||
| else if (c_test_util.IsAddN(n, 2)) | |||||
| { | { | ||||
| Assert.IsFalse(found_add); | Assert.IsFalse(found_add); | ||||
| found_add = true; | found_add = true; | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); | |||||
| }*/ | |||||
| Assert.Fail($"Unexpected NodeDef: {n}"); | |||||
| } | |||||
| } | } | ||||
| Assert.IsTrue(found_placeholder); | |||||
| Assert.IsTrue(found_scalar_const); | |||||
| Assert.IsTrue(found_add); | |||||
| // Add another oper to the graph. | |||||
| var neg = c_test_util.Neg(add, graph, s); | |||||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||||
| // Serialize to NodeDef. | |||||
| var node_def = c_test_util.GetNodeDef(neg); | |||||
| // Validate NodeDef is what we expect. | |||||
| Assert.IsTrue(c_test_util.IsNeg(node_def, "add")); | |||||
| // Serialize to GraphDef. | |||||
| var graph_def2 = c_test_util.GetGraphDef(graph); | |||||
| // Compare with first GraphDef + added NodeDef. | |||||
| graph_def.Node.Add(node_def); | |||||
| Assert.AreEqual(graph_def.ToString(), graph_def2.ToString()); | |||||
| // Look up some nodes by name. | |||||
| Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | |||||
| Assert.AreEqual(neg, neg2); | |||||
| var node_def2 = c_test_util.GetNodeDef(neg2); | |||||
| Assert.AreEqual(node_def.ToString(), node_def2.ToString()); | |||||
| Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | |||||
| Assert.AreEqual(feed, feed2); | |||||
| node_def = c_test_util.GetNodeDef(feed); | |||||
| node_def2 = c_test_util.GetNodeDef(feed2); | |||||
| Assert.AreEqual(node_def.ToString(), node_def2.ToString()); | |||||
| // Test iterating through the nodes of a graph. | |||||
| found_placeholder = false; | |||||
| found_scalar_const = false; | |||||
| found_add = false; | |||||
| bool found_neg = false; | |||||
| uint pos = 0; | |||||
| Operation oper; | |||||
| while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) | |||||
| { | |||||
| if (oper.Equals(feed)) | |||||
| { | |||||
| Assert.IsFalse(found_placeholder); | |||||
| found_placeholder = true; | |||||
| } | |||||
| else if (oper.Equals(three)) | |||||
| { | |||||
| Assert.IsFalse(found_scalar_const); | |||||
| found_scalar_const = true; | |||||
| } | |||||
| else if (oper.Equals(add)) | |||||
| { | |||||
| Assert.IsFalse(found_add); | |||||
| found_add = true; | |||||
| } | |||||
| else if (oper.Equals(neg)) | |||||
| { | |||||
| Assert.IsFalse(found_neg); | |||||
| found_neg = true; | |||||
| } | |||||
| else | |||||
| { | |||||
| node_def = c_test_util.GetNodeDef(oper); | |||||
| Assert.Fail($"Unexpected Node: {node_def.ToString()}"); | |||||
| } | |||||
| } | |||||
| Assert.IsTrue(found_placeholder); | |||||
| Assert.IsTrue(found_scalar_const); | |||||
| Assert.IsTrue(found_add); | |||||
| Assert.IsTrue(found_neg); | |||||
| graph.Dispose(); | |||||
| s.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,7 +23,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var s = new Status(); | var s = new Status(); | ||||
| s.SetStatus(TF_Code.TF_CANCELLED, "cancel"); | s.SetStatus(TF_Code.TF_CANCELLED, "cancel"); | ||||
| Assert.AreEqual(s.Code, TF_Code.TF_CANCELLED); | Assert.AreEqual(s.Code, TF_Code.TF_CANCELLED); | ||||
| // Assert.AreEqual(s.Message, "cancel"); | |||||
| Assert.AreEqual(s.Message, "cancel"); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -19,11 +19,10 @@ | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | ||||
| <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
| <PackageReference Include="NumSharp" Version="0.6.2" /> | |||||
| <PackageReference Include="NumSharp" Version="0.6.3" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -50,16 +50,66 @@ namespace TensorFlowNET.UnitTest | |||||
| var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | c_api.TF_GraphToGraphDef(graph, buffer, s); | ||||
| s.Check(); | s.Check(); | ||||
| return GraphDef.Parser.ParseFrom(buffer); | |||||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||||
| buffer.Dispose(); | |||||
| s.Dispose(); | |||||
| return def; | |||||
| } | } | ||||
| public static bool GetNodeDef(Operation oper, ref NodeDef node_def) | |||||
| public static NodeDef GetNodeDef(Operation oper) | |||||
| { | { | ||||
| var s = new Status(); | var s = new Status(); | ||||
| var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
| c_api.TF_OperationToNodeDef(oper, buffer, s); | c_api.TF_OperationToNodeDef(oper, buffer, s); | ||||
| s.Check(); | |||||
| var ret = NodeDef.Parser.ParseFrom(buffer); | |||||
| buffer.Dispose(); | |||||
| s.Dispose(); | |||||
| return ret; | |||||
| } | |||||
| return s.Code == TF_Code.TF_OK; | |||||
| public static bool IsAddN(NodeDef node_def, int n) | |||||
| { | |||||
| if (node_def.Op != "AddN" || node_def.Name != "add" || | |||||
| node_def.Input.Count != n) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| bool found_t = false; | |||||
| bool found_n = false; | |||||
| foreach (var attr in node_def.Attr) | |||||
| { | |||||
| if (attr.Key == "T") | |||||
| { | |||||
| if (attr.Value.Type == DataType.DtInt32) | |||||
| { | |||||
| found_t = true; | |||||
| } | |||||
| else | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| else if (attr.Key == "N") | |||||
| { | |||||
| if (attr.Value.I == n) | |||||
| { | |||||
| found_n = true; | |||||
| } | |||||
| else | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| return found_t && found_n; | |||||
| } | |||||
| public static bool IsNeg(NodeDef node_def, string input) | |||||
| { | |||||
| return node_def.Op == "Neg" && node_def.Name == "neg" && | |||||
| node_def.Input.Count == 1 && node_def.Input[0] == input; | |||||
| } | } | ||||
| public static bool IsPlaceholder(NodeDef node_def) | public static bool IsPlaceholder(NodeDef node_def) | ||||
| @@ -93,6 +143,59 @@ namespace TensorFlowNET.UnitTest | |||||
| return found_dtype && found_shape; | return found_dtype && found_shape; | ||||
| } | } | ||||
| public static bool IsScalarConst(NodeDef node_def, int v) | |||||
| { | |||||
| if (node_def.Op != "Const" || node_def.Name != "scalar") | |||||
| { | |||||
| return false; | |||||
| } | |||||
| bool found_dtype = false; | |||||
| bool found_value = false; | |||||
| foreach (var attr in node_def.Attr) { | |||||
| if (attr.Key == "dtype") | |||||
| { | |||||
| if (attr.Value.Type == DataType.DtInt32) | |||||
| { | |||||
| found_dtype = true; | |||||
| } | |||||
| else | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| else if (attr.Key == "value") | |||||
| { | |||||
| if (attr.Value.Tensor != null && | |||||
| attr.Value.Tensor.IntVal.Count == 1 && | |||||
| attr.Value.Tensor.IntVal[0] == v) | |||||
| { | |||||
| found_value = true; | |||||
| } | |||||
| else | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| return found_dtype && found_value; | |||||
| } | |||||
| public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") | |||||
| { | |||||
| return NegHelper(n, graph, s, name); | |||||
| } | |||||
| public static Operation NegHelper(Operation n, Graph graph, Status s, string name) | |||||
| { | |||||
| OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | |||||
| var neg_input = new TF_Output(n, 0); | |||||
| c_api.TF_AddInput(desc, neg_input); | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| return op; | |||||
| } | |||||
| public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | ||||
| { | { | ||||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | ||||