diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index 3b2df95d..019bad25 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -9,8 +9,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{DA680126-DA60-4CE3-9094-72C355C081D3}"
-EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -29,10 +27,6 @@ Global
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU
- {DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
index b2e9947b..9a836d10 100644
--- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
@@ -31,6 +31,25 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status);
+ ///
+ /// Iterate through the operations of a graph.
+ ///
+ ///
+ ///
+ ///
+ [DllImport(TensorFlowLibName)]
+ public static extern IntPtr TF_GraphNextOperation(IntPtr graph, ref uint pos);
+
+ ///
+ /// Returns the operation in the graph with `oper_name`. Returns nullptr if
+ /// no operation found.
+ ///
+ ///
+ ///
+ ///
+ [DllImport(TensorFlowLibName)]
+ public static extern IntPtr TF_GraphOperationByName(IntPtr graph, string oper_name);
+
///
/// Sets the shape of the Tensor referenced by `output` in `graph` to
/// the shape described by `dims` and `num_dims`.
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index c8a2933f..1bd440b3 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -116,6 +116,8 @@ namespace Tensorflow
{
case IntPtr val:
return val == _handle;
+ case Operation val:
+ return val._handle == _handle;
}
return base.Equals(obj);
diff --git a/src/TensorFlowNET.Core/Operations/OperationDescription.cs b/src/TensorFlowNET.Core/Operations/OperationDescription.cs
new file mode 100644
index 00000000..b49952cf
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/OperationDescription.cs
@@ -0,0 +1,26 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class OperationDescription
+ {
+ private IntPtr _handle;
+
+ public OperationDescription(IntPtr handle)
+ {
+ _handle = handle;
+ }
+
+ public static implicit operator OperationDescription(IntPtr handle)
+ {
+ return new OperationDescription(handle);
+ }
+
+ public static implicit operator IntPtr(OperationDescription desc)
+ {
+ return desc._handle;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs b/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs
index 9f04289b..878d9201 100644
--- a/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs
+++ b/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs
@@ -9,6 +9,7 @@ namespace Tensorflow
public struct TF_OperationDescription
{
public IntPtr node_builder;
- //public TF_Graph graph;
+ public IntPtr graph;
+ public IntPtr colocation_constraints;
}
}
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index c35f0aa2..5f35afcb 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -33,11 +33,7 @@
-
-
-
-
-
+
diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs
index 0e0316a1..77cc01af 100644
--- a/src/TensorFlowNET.Core/c_api.cs
+++ b/src/TensorFlowNET.Core/c_api.cs
@@ -22,6 +22,7 @@ namespace Tensorflow
/// int32_t => int
/// int64_t* => long[]
/// size_t* => unlong[]
+ /// size_t* => ref uint
/// void* => IntPtr
/// string => IntPtr c_api.StringPiece(IntPtr)
///
diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
index 56518593..116a0b90 100644
--- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
+++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
@@ -6,11 +6,10 @@
-
+
-
diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs
index 3b2fd37c..e97644df 100644
--- a/test/TensorFlowNET.UnitTest/GraphTest.cs
+++ b/test/TensorFlowNET.UnitTest/GraphTest.cs
@@ -104,21 +104,98 @@ namespace TensorFlowNET.UnitTest
Assert.IsFalse(found_placeholder);
found_placeholder = true;
}
- /*else if (IsScalarConst(n, 3))
+ else if (c_test_util.IsScalarConst(n, 3))
{
Assert.IsFalse(found_scalar_const);
found_scalar_const = true;
}
- else if (IsAddN(n, 2))
+ else if (c_test_util.IsAddN(n, 2))
{
Assert.IsFalse(found_add);
found_add = true;
}
else
{
- ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n);
- }*/
+ Assert.Fail($"Unexpected NodeDef: {n}");
+ }
}
+ Assert.IsTrue(found_placeholder);
+ Assert.IsTrue(found_scalar_const);
+ Assert.IsTrue(found_add);
+
+ // Add another oper to the graph.
+ var neg = c_test_util.Neg(add, graph, s);
+ Assert.AreEqual(TF_Code.TF_OK, s.Code);
+
+ // Serialize to NodeDef.
+ var node_def = c_test_util.GetNodeDef(neg);
+
+ // Validate NodeDef is what we expect.
+ Assert.IsTrue(c_test_util.IsNeg(node_def, "add"));
+
+ // Serialize to GraphDef.
+ var graph_def2 = c_test_util.GetGraphDef(graph);
+
+ // Compare with first GraphDef + added NodeDef.
+ graph_def.Node.Add(node_def);
+ Assert.AreEqual(graph_def.ToString(), graph_def2.ToString());
+
+ // Look up some nodes by name.
+ Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
+ Assert.AreEqual(neg, neg2);
+ var node_def2 = c_test_util.GetNodeDef(neg2);
+ Assert.AreEqual(node_def.ToString(), node_def2.ToString());
+
+ Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
+ Assert.AreEqual(feed, feed2);
+ node_def = c_test_util.GetNodeDef(feed);
+ node_def2 = c_test_util.GetNodeDef(feed2);
+ Assert.AreEqual(node_def.ToString(), node_def2.ToString());
+
+ // Test iterating through the nodes of a graph.
+ found_placeholder = false;
+ found_scalar_const = false;
+ found_add = false;
+ bool found_neg = false;
+ uint pos = 0;
+ Operation oper;
+
+ while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
+ {
+ if (oper.Equals(feed))
+ {
+ Assert.IsFalse(found_placeholder);
+ found_placeholder = true;
+ }
+ else if (oper.Equals(three))
+ {
+ Assert.IsFalse(found_scalar_const);
+ found_scalar_const = true;
+ }
+ else if (oper.Equals(add))
+ {
+ Assert.IsFalse(found_add);
+ found_add = true;
+ }
+ else if (oper.Equals(neg))
+ {
+ Assert.IsFalse(found_neg);
+ found_neg = true;
+ }
+ else
+ {
+ node_def = c_test_util.GetNodeDef(oper);
+ Assert.Fail($"Unexpected Node: {node_def.ToString()}");
+ }
+ }
+
+ Assert.IsTrue(found_placeholder);
+ Assert.IsTrue(found_scalar_const);
+ Assert.IsTrue(found_add);
+ Assert.IsTrue(found_neg);
+
+ graph.Dispose();
+ s.Dispose();
}
}
}
diff --git a/test/TensorFlowNET.UnitTest/StatusTest.cs b/test/TensorFlowNET.UnitTest/StatusTest.cs
index 8e1baede..7283e7a0 100644
--- a/test/TensorFlowNET.UnitTest/StatusTest.cs
+++ b/test/TensorFlowNET.UnitTest/StatusTest.cs
@@ -23,7 +23,7 @@ namespace TensorFlowNET.UnitTest
var s = new Status();
s.SetStatus(TF_Code.TF_CANCELLED, "cancel");
Assert.AreEqual(s.Code, TF_Code.TF_CANCELLED);
- // Assert.AreEqual(s.Message, "cancel");
+ Assert.AreEqual(s.Message, "cancel");
}
[TestMethod]
diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
index e93f4713..802d4c0c 100644
--- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
@@ -19,11 +19,10 @@
-
+
-
diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs
index 079ecee5..128e50df 100644
--- a/test/TensorFlowNET.UnitTest/c_test_util.cs
+++ b/test/TensorFlowNET.UnitTest/c_test_util.cs
@@ -50,16 +50,66 @@ namespace TensorFlowNET.UnitTest
var buffer = new Buffer();
c_api.TF_GraphToGraphDef(graph, buffer, s);
s.Check();
- return GraphDef.Parser.ParseFrom(buffer);
+ var def = GraphDef.Parser.ParseFrom(buffer);
+ buffer.Dispose();
+ s.Dispose();
+ return def;
}
- public static bool GetNodeDef(Operation oper, ref NodeDef node_def)
+ public static NodeDef GetNodeDef(Operation oper)
{
var s = new Status();
var buffer = new Buffer();
c_api.TF_OperationToNodeDef(oper, buffer, s);
+ s.Check();
+ var ret = NodeDef.Parser.ParseFrom(buffer);
+ buffer.Dispose();
+ s.Dispose();
+ return ret;
+ }
- return s.Code == TF_Code.TF_OK;
+ public static bool IsAddN(NodeDef node_def, int n)
+ {
+ if (node_def.Op != "AddN" || node_def.Name != "add" ||
+ node_def.Input.Count != n)
+ {
+ return false;
+ }
+ bool found_t = false;
+ bool found_n = false;
+ foreach (var attr in node_def.Attr)
+ {
+ if (attr.Key == "T")
+ {
+ if (attr.Value.Type == DataType.DtInt32)
+ {
+ found_t = true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ else if (attr.Key == "N")
+ {
+ if (attr.Value.I == n)
+ {
+ found_n = true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ }
+
+ return found_t && found_n;
+ }
+
+ public static bool IsNeg(NodeDef node_def, string input)
+ {
+ return node_def.Op == "Neg" && node_def.Name == "neg" &&
+ node_def.Input.Count == 1 && node_def.Input[0] == input;
}
public static bool IsPlaceholder(NodeDef node_def)
@@ -93,6 +143,59 @@ namespace TensorFlowNET.UnitTest
return found_dtype && found_shape;
}
+ public static bool IsScalarConst(NodeDef node_def, int v)
+ {
+ if (node_def.Op != "Const" || node_def.Name != "scalar")
+ {
+ return false;
+ }
+ bool found_dtype = false;
+ bool found_value = false;
+ foreach (var attr in node_def.Attr) {
+ if (attr.Key == "dtype")
+ {
+ if (attr.Value.Type == DataType.DtInt32)
+ {
+ found_dtype = true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ else if (attr.Key == "value")
+ {
+ if (attr.Value.Tensor != null &&
+ attr.Value.Tensor.IntVal.Count == 1 &&
+ attr.Value.Tensor.IntVal[0] == v)
+ {
+ found_value = true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ }
+ return found_dtype && found_value;
+ }
+
+ public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg")
+ {
+ return NegHelper(n, graph, s, name);
+ }
+
+ public static Operation NegHelper(Operation n, Graph graph, Status s, string name)
+ {
+ OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name);
+ var neg_input = new TF_Output(n, 0);
+ c_api.TF_AddInput(desc, neg_input);
+ var op = c_api.TF_FinishOperation(desc, s);
+ s.Check();
+
+ return op;
+ }
+
public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op)
{
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);