From 119f0c5e3a1ff8117fea2eb843ccb28079299376 Mon Sep 17 00:00:00 2001 From: Esther2013 Date: Fri, 4 Jan 2019 23:23:45 -0600 Subject: [PATCH] add FinishOperation to OperationDescription add GetNodeDef to Operation --- .../Graphs/Graph.Operation.cs | 14 +--------- .../Operations/Operation.cs | 11 ++++++++ .../Operations/OperationDescription.cs | 5 ++++ .../Operations/c_api.ops.cs | 2 +- .../CApiColocationTest.cs | 14 +++++----- test/TensorFlowNET.UnitTest/GraphTest.cs | 26 +++++++++---------- test/TensorFlowNET.UnitTest/c_test_util.cs | 12 --------- 7 files changed, 37 insertions(+), 47 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index ca726d5b..e2ff80e2 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -20,19 +20,7 @@ namespace Tensorflow public OperationDescription NewOperation(string opType, string opName) { - OperationDescription desc = c_api.TF_NewOperation(_handle, opType, opName); - return desc; - - /*c_api.TF_SetAttrTensor(desc, "value", tensor, Status); - - Status.Check(); - - c_api.TF_SetAttrType(desc, "dtype", tensor.dtype); - - var op = c_api.TF_FinishOperation(desc, Status); - Status.Check(); - - return op;*/ + return c_api.TF_NewOperation(_handle, opType, opName); } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 0d3d8a93..5d1f4f83 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -145,6 +145,17 @@ namespace Tensorflow return ret; } + public NodeDef GetNodeDef() + { + using (var s = new Status()) + using (var buffer = new Buffer()) + { + c_api.TF_OperationToNodeDef(_handle, buffer, s); + s.Check(); + return NodeDef.Parser.ParseFrom(buffer); + } + } + public static implicit operator Operation(IntPtr handle) => new Operation(handle); public static implicit operator IntPtr(Operation op) => op._handle; diff --git a/src/TensorFlowNET.Core/Operations/OperationDescription.cs b/src/TensorFlowNET.Core/Operations/OperationDescription.cs index 2c3f1100..38190c17 100644 --- a/src/TensorFlowNET.Core/Operations/OperationDescription.cs +++ b/src/TensorFlowNET.Core/Operations/OperationDescription.cs @@ -18,6 +18,11 @@ namespace Tensorflow c_api.TF_AddInputList(_handle, inputs, inputs.Length); } + public Operation FinishOperation(Status status) + { + return c_api.TF_FinishOperation(_handle, status); + } + public static implicit operator OperationDescription(IntPtr handle) { return new OperationDescription(handle); diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 7bbd3088..96be50d2 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -232,7 +232,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, string[] values, uint[] lengths, int num_values); + public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, IntPtr[] values, uint[] lengths, int num_values); [DllImport(TensorFlowLibName)] public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); diff --git a/test/TensorFlowNET.UnitTest/CApiColocationTest.cs b/test/TensorFlowNET.UnitTest/CApiColocationTest.cs index 8b90c669..df936024 100644 --- a/test/TensorFlowNET.UnitTest/CApiColocationTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiColocationTest.cs @@ -30,34 +30,32 @@ namespace TensorFlowNET.UnitTest s_.Check(); constant_ = c_test_util.ScalarConst(10, graph_, s_); s_.Check(); - desc_ = c_api.TF_NewOperation(graph_, "AddN", "add"); - s_.Check(); + desc_ = graph_.NewOperation("AddN", "add"); TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; desc_.AddInputList(inputs); - s_.Check(); } private void SetViaStringList(OperationDescription desc, string[] list) { - string[] list_ptrs = new string[list.Length]; - uint[] list_lens = new uint[list.Length]; + var list_ptrs = new IntPtr[list.Length]; + var list_lens = new uint[list.Length]; StringVectorToArrays(list, list_ptrs, list_lens); c_api.TF_SetAttrStringList(desc, "_class", list_ptrs, list_lens, list.Length); } - private void StringVectorToArrays(string[] v, string[] ptrs, uint[] lens) + private void StringVectorToArrays(string[] v, IntPtr[] ptrs, uint[] lens) { for (int i = 0; i < v.Length; ++i) { - ptrs[i] = v[i];// Marshal.StringToHGlobalAnsi(v[i]); + ptrs[i] = Marshal.StringToHGlobalAnsi(v[i]); lens[i] = (uint)v[i].Length; } } private void FinishAndVerify(OperationDescription desc, string[] expected) { - Operation op = c_api.TF_FinishOperation(desc_, s_); + var op = desc_.FinishOperation(s_); ASSERT_EQ(TF_Code.TF_OK, s_.Code); VerifyCollocation(op, expected); } diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 19ece5a5..57963660 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -130,7 +130,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(TF_Code.TF_OK, s.Code); // Serialize to NodeDef. - var node_def = c_test_util.GetNodeDef(neg); + var node_def = neg.GetNodeDef(); // Validate NodeDef is what we expect. ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); @@ -145,13 +145,13 @@ namespace TensorFlowNET.UnitTest // Look up some nodes by name. Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); EXPECT_EQ(neg, neg2); - var node_def2 = c_test_util.GetNodeDef(neg2); + var node_def2 = neg2.GetNodeDef(); EXPECT_EQ(node_def.ToString(), node_def2.ToString()); Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); EXPECT_EQ(feed, feed2); - node_def = c_test_util.GetNodeDef(feed); - node_def2 = c_test_util.GetNodeDef(feed2); + node_def = feed.GetNodeDef(); + node_def2 = feed2.GetNodeDef(); EXPECT_EQ(node_def.ToString(), node_def2.ToString()); // Test iterating through the nodes of a graph. @@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest uint pos = 0; Operation oper; - while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) + while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) { if (oper.Equals(feed)) { @@ -186,7 +186,7 @@ namespace TensorFlowNET.UnitTest } else { - node_def = c_test_util.GetNodeDef(oper); + node_def = oper.GetNodeDef(); Assert.Fail($"Unexpected Node: {node_def.ToString()}"); } } @@ -256,7 +256,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(0, neg.GetControlInputs().Length); EXPECT_EQ(0, neg.NumControlOutputs); EXPECT_EQ(0, neg.GetControlOutputs().Length); - + // Import it again, with an input mapping, return outputs, and a return // operation, into the same graph. c_api.TF_DeleteImportGraphDefOptions(opts); @@ -270,7 +270,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); EXPECT_EQ(TF_Code.TF_OK, s.Code); - + Operation scalar2 = graph.OperationByName("imported2/scalar"); Operation feed2 = graph.OperationByName("imported2/feed"); Operation neg2 = graph.OperationByName("imported2/neg"); @@ -287,7 +287,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(0, return_outputs[0].index); EXPECT_EQ(scalar, return_outputs[1].oper); // remapped EXPECT_EQ(0, return_outputs[1].index); - + // Check return operation var return_opers = graph.ReturnOperations(results); ASSERT_EQ(1, return_opers.Length); @@ -302,26 +302,26 @@ namespace TensorFlowNET.UnitTest c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); EXPECT_EQ(TF_Code.TF_OK, s.Code); - + var scalar3 = graph.OperationByName("imported3/scalar"); var feed3 = graph.OperationByName("imported3/feed"); var neg3 = graph.OperationByName("imported3/neg"); ASSERT_TRUE(scalar3 != IntPtr.Zero); ASSERT_TRUE(feed3 != IntPtr.Zero); ASSERT_TRUE(neg3 != IntPtr.Zero); - + // Check that newly-imported scalar and feed have control deps (neg3 will // inherit them from input) var control_inputs = scalar3.GetControlInputs(); ASSERT_EQ(2, scalar3.NumControlInputs); EXPECT_EQ(feed, control_inputs[0]); EXPECT_EQ(feed2, control_inputs[1]); - + control_inputs = feed3.GetControlInputs(); ASSERT_EQ(2, feed3.NumControlInputs); EXPECT_EQ(feed, control_inputs[0]); EXPECT_EQ(feed2, control_inputs[1]); - + // Export to a graph def so we can import a graph with control dependencies graph_def.Dispose(); graph_def = new Buffer(); diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 07d760c9..edf3b379 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -51,18 +51,6 @@ namespace TensorFlowNET.UnitTest return def; } - public static NodeDef GetNodeDef(Operation oper) - { - var s = new Status(); - var buffer = new Buffer(); - c_api.TF_OperationToNodeDef(oper, buffer, s); - s.Check(); - var ret = NodeDef.Parser.ParseFrom(buffer); - buffer.Dispose(); - s.Dispose(); - return ret; - } - public static bool IsAddN(NodeDef node_def, int n) { if (node_def.Op != "AddN" || node_def.Name != "add" ||