diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
index ca726d5b..e2ff80e2 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
@@ -20,19 +20,7 @@ namespace Tensorflow
public OperationDescription NewOperation(string opType, string opName)
{
- OperationDescription desc = c_api.TF_NewOperation(_handle, opType, opName);
- return desc;
-
- /*c_api.TF_SetAttrTensor(desc, "value", tensor, Status);
-
- Status.Check();
-
- c_api.TF_SetAttrType(desc, "dtype", tensor.dtype);
-
- var op = c_api.TF_FinishOperation(desc, Status);
- Status.Check();
-
- return op;*/
+ return c_api.TF_NewOperation(_handle, opType, opName);
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index 0d3d8a93..5d1f4f83 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -145,6 +145,17 @@ namespace Tensorflow
return ret;
}
+ public NodeDef GetNodeDef()
+ {
+ using (var s = new Status())
+ using (var buffer = new Buffer())
+ {
+ c_api.TF_OperationToNodeDef(_handle, buffer, s);
+ s.Check();
+ return NodeDef.Parser.ParseFrom(buffer);
+ }
+ }
+
public static implicit operator Operation(IntPtr handle) => new Operation(handle);
public static implicit operator IntPtr(Operation op) => op._handle;
diff --git a/src/TensorFlowNET.Core/Operations/OperationDescription.cs b/src/TensorFlowNET.Core/Operations/OperationDescription.cs
index 2c3f1100..38190c17 100644
--- a/src/TensorFlowNET.Core/Operations/OperationDescription.cs
+++ b/src/TensorFlowNET.Core/Operations/OperationDescription.cs
@@ -18,6 +18,11 @@ namespace Tensorflow
c_api.TF_AddInputList(_handle, inputs, inputs.Length);
}
+ public Operation FinishOperation(Status status)
+ {
+ return c_api.TF_FinishOperation(_handle, status);
+ }
+
public static implicit operator OperationDescription(IntPtr handle)
{
return new OperationDescription(handle);
diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
index 7bbd3088..96be50d2 100644
--- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs
+++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
@@ -232,7 +232,7 @@ namespace Tensorflow
///
///
[DllImport(TensorFlowLibName)]
- public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, string[] values, uint[] lengths, int num_values);
+ public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, IntPtr[] values, uint[] lengths, int num_values);
[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status);
diff --git a/test/TensorFlowNET.UnitTest/CApiColocationTest.cs b/test/TensorFlowNET.UnitTest/CApiColocationTest.cs
index 8b90c669..df936024 100644
--- a/test/TensorFlowNET.UnitTest/CApiColocationTest.cs
+++ b/test/TensorFlowNET.UnitTest/CApiColocationTest.cs
@@ -30,34 +30,32 @@ namespace TensorFlowNET.UnitTest
s_.Check();
constant_ = c_test_util.ScalarConst(10, graph_, s_);
s_.Check();
- desc_ = c_api.TF_NewOperation(graph_, "AddN", "add");
- s_.Check();
+ desc_ = graph_.NewOperation("AddN", "add");
TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) };
desc_.AddInputList(inputs);
- s_.Check();
}
private void SetViaStringList(OperationDescription desc, string[] list)
{
- string[] list_ptrs = new string[list.Length];
- uint[] list_lens = new uint[list.Length];
+ var list_ptrs = new IntPtr[list.Length];
+ var list_lens = new uint[list.Length];
StringVectorToArrays(list, list_ptrs, list_lens);
c_api.TF_SetAttrStringList(desc, "_class", list_ptrs, list_lens, list.Length);
}
- private void StringVectorToArrays(string[] v, string[] ptrs, uint[] lens)
+ private void StringVectorToArrays(string[] v, IntPtr[] ptrs, uint[] lens)
{
for (int i = 0; i < v.Length; ++i)
{
- ptrs[i] = v[i];// Marshal.StringToHGlobalAnsi(v[i]);
+ ptrs[i] = Marshal.StringToHGlobalAnsi(v[i]);
lens[i] = (uint)v[i].Length;
}
}
private void FinishAndVerify(OperationDescription desc, string[] expected)
{
- Operation op = c_api.TF_FinishOperation(desc_, s_);
+ var op = desc_.FinishOperation(s_);
ASSERT_EQ(TF_Code.TF_OK, s_.Code);
VerifyCollocation(op, expected);
}
diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs
index 19ece5a5..57963660 100644
--- a/test/TensorFlowNET.UnitTest/GraphTest.cs
+++ b/test/TensorFlowNET.UnitTest/GraphTest.cs
@@ -130,7 +130,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(TF_Code.TF_OK, s.Code);
// Serialize to NodeDef.
- var node_def = c_test_util.GetNodeDef(neg);
+ var node_def = neg.GetNodeDef();
// Validate NodeDef is what we expect.
ASSERT_TRUE(c_test_util.IsNeg(node_def, "add"));
@@ -145,13 +145,13 @@ namespace TensorFlowNET.UnitTest
// Look up some nodes by name.
Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
EXPECT_EQ(neg, neg2);
- var node_def2 = c_test_util.GetNodeDef(neg2);
+ var node_def2 = neg2.GetNodeDef();
EXPECT_EQ(node_def.ToString(), node_def2.ToString());
Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
EXPECT_EQ(feed, feed2);
- node_def = c_test_util.GetNodeDef(feed);
- node_def2 = c_test_util.GetNodeDef(feed2);
+ node_def = feed.GetNodeDef();
+ node_def2 = feed2.GetNodeDef();
EXPECT_EQ(node_def.ToString(), node_def2.ToString());
// Test iterating through the nodes of a graph.
@@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest
uint pos = 0;
Operation oper;
- while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
+ while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
{
if (oper.Equals(feed))
{
@@ -186,7 +186,7 @@ namespace TensorFlowNET.UnitTest
}
else
{
- node_def = c_test_util.GetNodeDef(oper);
+ node_def = oper.GetNodeDef();
Assert.Fail($"Unexpected Node: {node_def.ToString()}");
}
}
@@ -256,7 +256,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(0, neg.GetControlInputs().Length);
EXPECT_EQ(0, neg.NumControlOutputs);
EXPECT_EQ(0, neg.GetControlOutputs().Length);
-
+
// Import it again, with an input mapping, return outputs, and a return
// operation, into the same graph.
c_api.TF_DeleteImportGraphDefOptions(opts);
@@ -270,7 +270,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
-
+
Operation scalar2 = graph.OperationByName("imported2/scalar");
Operation feed2 = graph.OperationByName("imported2/feed");
Operation neg2 = graph.OperationByName("imported2/neg");
@@ -287,7 +287,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(0, return_outputs[0].index);
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
EXPECT_EQ(0, return_outputs[1].index);
-
+
// Check return operation
var return_opers = graph.ReturnOperations(results);
ASSERT_EQ(1, return_opers.Length);
@@ -302,26 +302,26 @@ namespace TensorFlowNET.UnitTest
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
-
+
var scalar3 = graph.OperationByName("imported3/scalar");
var feed3 = graph.OperationByName("imported3/feed");
var neg3 = graph.OperationByName("imported3/neg");
ASSERT_TRUE(scalar3 != IntPtr.Zero);
ASSERT_TRUE(feed3 != IntPtr.Zero);
ASSERT_TRUE(neg3 != IntPtr.Zero);
-
+
// Check that newly-imported scalar and feed have control deps (neg3 will
// inherit them from input)
var control_inputs = scalar3.GetControlInputs();
ASSERT_EQ(2, scalar3.NumControlInputs);
EXPECT_EQ(feed, control_inputs[0]);
EXPECT_EQ(feed2, control_inputs[1]);
-
+
control_inputs = feed3.GetControlInputs();
ASSERT_EQ(2, feed3.NumControlInputs);
EXPECT_EQ(feed, control_inputs[0]);
EXPECT_EQ(feed2, control_inputs[1]);
-
+
// Export to a graph def so we can import a graph with control dependencies
graph_def.Dispose();
graph_def = new Buffer();
diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs
index 07d760c9..edf3b379 100644
--- a/test/TensorFlowNET.UnitTest/c_test_util.cs
+++ b/test/TensorFlowNET.UnitTest/c_test_util.cs
@@ -51,18 +51,6 @@ namespace TensorFlowNET.UnitTest
return def;
}
- public static NodeDef GetNodeDef(Operation oper)
- {
- var s = new Status();
- var buffer = new Buffer();
- c_api.TF_OperationToNodeDef(oper, buffer, s);
- s.Check();
- var ret = NodeDef.Parser.ParseFrom(buffer);
- buffer.Dispose();
- s.Dispose();
- return ret;
- }
-
public static bool IsAddN(NodeDef node_def, int n)
{
if (node_def.Op != "AddN" || node_def.Name != "add" ||