diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 3b2df95d..019bad25 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -9,8 +9,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{DA680126-DA60-4CE3-9094-72C355C081D3}" -EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -29,10 +27,6 @@ Global {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU - {DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.Build.0 = Debug|Any CPU - {DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.ActiveCfg = Release|Any CPU - {DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index b2e9947b..9a836d10 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -31,6 +31,25 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); + /// + /// Iterate through the operations of a graph. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_GraphNextOperation(IntPtr graph, ref uint pos); + + /// + /// Returns the operation in the graph with `oper_name`. Returns nullptr if + /// no operation found. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_GraphOperationByName(IntPtr graph, string oper_name); + /// /// Sets the shape of the Tensor referenced by `output` in `graph` to /// the shape described by `dims` and `num_dims`. diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index c8a2933f..1bd440b3 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -116,6 +116,8 @@ namespace Tensorflow { case IntPtr val: return val == _handle; + case Operation val: + return val._handle == _handle; } return base.Equals(obj); diff --git a/src/TensorFlowNET.Core/Operations/OperationDescription.cs b/src/TensorFlowNET.Core/Operations/OperationDescription.cs new file mode 100644 index 00000000..b49952cf --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/OperationDescription.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class OperationDescription + { + private IntPtr _handle; + + public OperationDescription(IntPtr handle) + { + _handle = handle; + } + + public static implicit operator OperationDescription(IntPtr handle) + { + return new OperationDescription(handle); + } + + public static implicit operator IntPtr(OperationDescription desc) + { + return desc._handle; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs b/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs index 9f04289b..878d9201 100644 --- a/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs +++ b/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs @@ -9,6 +9,7 @@ namespace Tensorflow public struct TF_OperationDescription { public IntPtr node_builder; - //public TF_Graph graph; + public IntPtr graph; + public IntPtr colocation_constraints; } } diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index c35f0aa2..5f35afcb 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -33,11 +33,7 @@ - - - - - + diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index 0e0316a1..77cc01af 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -22,6 +22,7 @@ namespace Tensorflow /// int32_t => int /// int64_t* => long[] /// size_t* => unlong[] + /// size_t* => ref uint /// void* => IntPtr /// string => IntPtr c_api.StringPiece(IntPtr) /// diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 56518593..116a0b90 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -6,11 +6,10 @@ - + - diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 3b2fd37c..e97644df 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -104,21 +104,98 @@ namespace TensorFlowNET.UnitTest Assert.IsFalse(found_placeholder); found_placeholder = true; } - /*else if (IsScalarConst(n, 3)) + else if (c_test_util.IsScalarConst(n, 3)) { Assert.IsFalse(found_scalar_const); found_scalar_const = true; } - else if (IsAddN(n, 2)) + else if (c_test_util.IsAddN(n, 2)) { Assert.IsFalse(found_add); found_add = true; } else { - ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); - }*/ + Assert.Fail($"Unexpected NodeDef: {n}"); + } } + Assert.IsTrue(found_placeholder); + Assert.IsTrue(found_scalar_const); + Assert.IsTrue(found_add); + + // Add another oper to the graph. + var neg = c_test_util.Neg(add, graph, s); + Assert.AreEqual(TF_Code.TF_OK, s.Code); + + // Serialize to NodeDef. + var node_def = c_test_util.GetNodeDef(neg); + + // Validate NodeDef is what we expect. + Assert.IsTrue(c_test_util.IsNeg(node_def, "add")); + + // Serialize to GraphDef. + var graph_def2 = c_test_util.GetGraphDef(graph); + + // Compare with first GraphDef + added NodeDef. + graph_def.Node.Add(node_def); + Assert.AreEqual(graph_def.ToString(), graph_def2.ToString()); + + // Look up some nodes by name. + Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); + Assert.AreEqual(neg, neg2); + var node_def2 = c_test_util.GetNodeDef(neg2); + Assert.AreEqual(node_def.ToString(), node_def2.ToString()); + + Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); + Assert.AreEqual(feed, feed2); + node_def = c_test_util.GetNodeDef(feed); + node_def2 = c_test_util.GetNodeDef(feed2); + Assert.AreEqual(node_def.ToString(), node_def2.ToString()); + + // Test iterating through the nodes of a graph. + found_placeholder = false; + found_scalar_const = false; + found_add = false; + bool found_neg = false; + uint pos = 0; + Operation oper; + + while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) + { + if (oper.Equals(feed)) + { + Assert.IsFalse(found_placeholder); + found_placeholder = true; + } + else if (oper.Equals(three)) + { + Assert.IsFalse(found_scalar_const); + found_scalar_const = true; + } + else if (oper.Equals(add)) + { + Assert.IsFalse(found_add); + found_add = true; + } + else if (oper.Equals(neg)) + { + Assert.IsFalse(found_neg); + found_neg = true; + } + else + { + node_def = c_test_util.GetNodeDef(oper); + Assert.Fail($"Unexpected Node: {node_def.ToString()}"); + } + } + + Assert.IsTrue(found_placeholder); + Assert.IsTrue(found_scalar_const); + Assert.IsTrue(found_add); + Assert.IsTrue(found_neg); + + graph.Dispose(); + s.Dispose(); } } } diff --git a/test/TensorFlowNET.UnitTest/StatusTest.cs b/test/TensorFlowNET.UnitTest/StatusTest.cs index 8e1baede..7283e7a0 100644 --- a/test/TensorFlowNET.UnitTest/StatusTest.cs +++ b/test/TensorFlowNET.UnitTest/StatusTest.cs @@ -23,7 +23,7 @@ namespace TensorFlowNET.UnitTest var s = new Status(); s.SetStatus(TF_Code.TF_CANCELLED, "cancel"); Assert.AreEqual(s.Code, TF_Code.TF_CANCELLED); - // Assert.AreEqual(s.Message, "cancel"); + Assert.AreEqual(s.Message, "cancel"); } [TestMethod] diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index e93f4713..802d4c0c 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -19,11 +19,10 @@ - + - diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 079ecee5..128e50df 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -50,16 +50,66 @@ namespace TensorFlowNET.UnitTest var buffer = new Buffer(); c_api.TF_GraphToGraphDef(graph, buffer, s); s.Check(); - return GraphDef.Parser.ParseFrom(buffer); + var def = GraphDef.Parser.ParseFrom(buffer); + buffer.Dispose(); + s.Dispose(); + return def; } - public static bool GetNodeDef(Operation oper, ref NodeDef node_def) + public static NodeDef GetNodeDef(Operation oper) { var s = new Status(); var buffer = new Buffer(); c_api.TF_OperationToNodeDef(oper, buffer, s); + s.Check(); + var ret = NodeDef.Parser.ParseFrom(buffer); + buffer.Dispose(); + s.Dispose(); + return ret; + } - return s.Code == TF_Code.TF_OK; + public static bool IsAddN(NodeDef node_def, int n) + { + if (node_def.Op != "AddN" || node_def.Name != "add" || + node_def.Input.Count != n) + { + return false; + } + bool found_t = false; + bool found_n = false; + foreach (var attr in node_def.Attr) + { + if (attr.Key == "T") + { + if (attr.Value.Type == DataType.DtInt32) + { + found_t = true; + } + else + { + return false; + } + } + else if (attr.Key == "N") + { + if (attr.Value.I == n) + { + found_n = true; + } + else + { + return false; + } + } + } + + return found_t && found_n; + } + + public static bool IsNeg(NodeDef node_def, string input) + { + return node_def.Op == "Neg" && node_def.Name == "neg" && + node_def.Input.Count == 1 && node_def.Input[0] == input; } public static bool IsPlaceholder(NodeDef node_def) @@ -93,6 +143,59 @@ namespace TensorFlowNET.UnitTest return found_dtype && found_shape; } + public static bool IsScalarConst(NodeDef node_def, int v) + { + if (node_def.Op != "Const" || node_def.Name != "scalar") + { + return false; + } + bool found_dtype = false; + bool found_value = false; + foreach (var attr in node_def.Attr) { + if (attr.Key == "dtype") + { + if (attr.Value.Type == DataType.DtInt32) + { + found_dtype = true; + } + else + { + return false; + } + } + else if (attr.Key == "value") + { + if (attr.Value.Tensor != null && + attr.Value.Tensor.IntVal.Count == 1 && + attr.Value.Tensor.IntVal[0] == v) + { + found_value = true; + } + else + { + return false; + } + } + } + return found_dtype && found_value; + } + + public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") + { + return NegHelper(n, graph, s, name); + } + + public static Operation NegHelper(Operation n, Graph graph, Status s, string name) + { + OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); + var neg_input = new TF_Output(n, 0); + c_api.TF_AddInput(desc, neg_input); + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); + + return op; + } + public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) { var desc = c_api.TF_NewOperation(graph, "Placeholder", name);