add FinishOperation to OperationDescriptiontags/v0.1.0-Tensor
| @@ -20,19 +20,7 @@ namespace Tensorflow | |||
| public OperationDescription NewOperation(string opType, string opName) | |||
| { | |||
| OperationDescription desc = c_api.TF_NewOperation(_handle, opType, opName); | |||
| return desc; | |||
| /*c_api.TF_SetAttrTensor(desc, "value", tensor, Status); | |||
| Status.Check(); | |||
| c_api.TF_SetAttrType(desc, "dtype", tensor.dtype); | |||
| var op = c_api.TF_FinishOperation(desc, Status); | |||
| Status.Check(); | |||
| return op;*/ | |||
| return c_api.TF_NewOperation(_handle, opType, opName); | |||
| } | |||
| } | |||
| } | |||
| @@ -145,6 +145,17 @@ namespace Tensorflow | |||
| return ret; | |||
| } | |||
| public NodeDef GetNodeDef() | |||
| { | |||
| using (var s = new Status()) | |||
| using (var buffer = new Buffer()) | |||
| { | |||
| c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||
| s.Check(); | |||
| return NodeDef.Parser.ParseFrom(buffer); | |||
| } | |||
| } | |||
| public static implicit operator Operation(IntPtr handle) => new Operation(handle); | |||
| public static implicit operator IntPtr(Operation op) => op._handle; | |||
| @@ -18,6 +18,11 @@ namespace Tensorflow | |||
| c_api.TF_AddInputList(_handle, inputs, inputs.Length); | |||
| } | |||
| public Operation FinishOperation(Status status) | |||
| { | |||
| return c_api.TF_FinishOperation(_handle, status); | |||
| } | |||
| public static implicit operator OperationDescription(IntPtr handle) | |||
| { | |||
| return new OperationDescription(handle); | |||
| @@ -232,7 +232,7 @@ namespace Tensorflow | |||
| /// <param name="lengths"></param> | |||
| /// <param name="num_values"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, string[] values, uint[] lengths, int num_values); | |||
| public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, IntPtr[] values, uint[] lengths, int num_values); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); | |||
| @@ -30,34 +30,32 @@ namespace TensorFlowNET.UnitTest | |||
| s_.Check(); | |||
| constant_ = c_test_util.ScalarConst(10, graph_, s_); | |||
| s_.Check(); | |||
| desc_ = c_api.TF_NewOperation(graph_, "AddN", "add"); | |||
| s_.Check(); | |||
| desc_ = graph_.NewOperation("AddN", "add"); | |||
| TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; | |||
| desc_.AddInputList(inputs); | |||
| s_.Check(); | |||
| } | |||
| private void SetViaStringList(OperationDescription desc, string[] list) | |||
| { | |||
| string[] list_ptrs = new string[list.Length]; | |||
| uint[] list_lens = new uint[list.Length]; | |||
| var list_ptrs = new IntPtr[list.Length]; | |||
| var list_lens = new uint[list.Length]; | |||
| StringVectorToArrays(list, list_ptrs, list_lens); | |||
| c_api.TF_SetAttrStringList(desc, "_class", list_ptrs, list_lens, list.Length); | |||
| } | |||
| private void StringVectorToArrays(string[] v, string[] ptrs, uint[] lens) | |||
| private void StringVectorToArrays(string[] v, IntPtr[] ptrs, uint[] lens) | |||
| { | |||
| for (int i = 0; i < v.Length; ++i) | |||
| { | |||
| ptrs[i] = v[i];// Marshal.StringToHGlobalAnsi(v[i]); | |||
| ptrs[i] = Marshal.StringToHGlobalAnsi(v[i]); | |||
| lens[i] = (uint)v[i].Length; | |||
| } | |||
| } | |||
| private void FinishAndVerify(OperationDescription desc, string[] expected) | |||
| { | |||
| Operation op = c_api.TF_FinishOperation(desc_, s_); | |||
| var op = desc_.FinishOperation(s_); | |||
| ASSERT_EQ(TF_Code.TF_OK, s_.Code); | |||
| VerifyCollocation(op, expected); | |||
| } | |||
| @@ -130,7 +130,7 @@ namespace TensorFlowNET.UnitTest | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Serialize to NodeDef. | |||
| var node_def = c_test_util.GetNodeDef(neg); | |||
| var node_def = neg.GetNodeDef(); | |||
| // Validate NodeDef is what we expect. | |||
| ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); | |||
| @@ -145,13 +145,13 @@ namespace TensorFlowNET.UnitTest | |||
| // Look up some nodes by name. | |||
| Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | |||
| EXPECT_EQ(neg, neg2); | |||
| var node_def2 = c_test_util.GetNodeDef(neg2); | |||
| var node_def2 = neg2.GetNodeDef(); | |||
| EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||
| Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | |||
| EXPECT_EQ(feed, feed2); | |||
| node_def = c_test_util.GetNodeDef(feed); | |||
| node_def2 = c_test_util.GetNodeDef(feed2); | |||
| node_def = feed.GetNodeDef(); | |||
| node_def2 = feed2.GetNodeDef(); | |||
| EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||
| // Test iterating through the nodes of a graph. | |||
| @@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest | |||
| uint pos = 0; | |||
| Operation oper; | |||
| while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) | |||
| while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) | |||
| { | |||
| if (oper.Equals(feed)) | |||
| { | |||
| @@ -186,7 +186,7 @@ namespace TensorFlowNET.UnitTest | |||
| } | |||
| else | |||
| { | |||
| node_def = c_test_util.GetNodeDef(oper); | |||
| node_def = oper.GetNodeDef(); | |||
| Assert.Fail($"Unexpected Node: {node_def.ToString()}"); | |||
| } | |||
| } | |||
| @@ -256,7 +256,7 @@ namespace TensorFlowNET.UnitTest | |||
| EXPECT_EQ(0, neg.GetControlInputs().Length); | |||
| EXPECT_EQ(0, neg.NumControlOutputs); | |||
| EXPECT_EQ(0, neg.GetControlOutputs().Length); | |||
| // Import it again, with an input mapping, return outputs, and a return | |||
| // operation, into the same graph. | |||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||
| @@ -270,7 +270,7 @@ namespace TensorFlowNET.UnitTest | |||
| EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||
| var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| Operation scalar2 = graph.OperationByName("imported2/scalar"); | |||
| Operation feed2 = graph.OperationByName("imported2/feed"); | |||
| Operation neg2 = graph.OperationByName("imported2/neg"); | |||
| @@ -287,7 +287,7 @@ namespace TensorFlowNET.UnitTest | |||
| EXPECT_EQ(0, return_outputs[0].index); | |||
| EXPECT_EQ(scalar, return_outputs[1].oper); // remapped | |||
| EXPECT_EQ(0, return_outputs[1].index); | |||
| // Check return operation | |||
| var return_opers = graph.ReturnOperations(results); | |||
| ASSERT_EQ(1, return_opers.Length); | |||
| @@ -302,26 +302,26 @@ namespace TensorFlowNET.UnitTest | |||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| var scalar3 = graph.OperationByName("imported3/scalar"); | |||
| var feed3 = graph.OperationByName("imported3/feed"); | |||
| var neg3 = graph.OperationByName("imported3/neg"); | |||
| ASSERT_TRUE(scalar3 != IntPtr.Zero); | |||
| ASSERT_TRUE(feed3 != IntPtr.Zero); | |||
| ASSERT_TRUE(neg3 != IntPtr.Zero); | |||
| // Check that newly-imported scalar and feed have control deps (neg3 will | |||
| // inherit them from input) | |||
| var control_inputs = scalar3.GetControlInputs(); | |||
| ASSERT_EQ(2, scalar3.NumControlInputs); | |||
| EXPECT_EQ(feed, control_inputs[0]); | |||
| EXPECT_EQ(feed2, control_inputs[1]); | |||
| control_inputs = feed3.GetControlInputs(); | |||
| ASSERT_EQ(2, feed3.NumControlInputs); | |||
| EXPECT_EQ(feed, control_inputs[0]); | |||
| EXPECT_EQ(feed2, control_inputs[1]); | |||
| // Export to a graph def so we can import a graph with control dependencies | |||
| graph_def.Dispose(); | |||
| graph_def = new Buffer(); | |||
| @@ -51,18 +51,6 @@ namespace TensorFlowNET.UnitTest | |||
| return def; | |||
| } | |||
| public static NodeDef GetNodeDef(Operation oper) | |||
| { | |||
| var s = new Status(); | |||
| var buffer = new Buffer(); | |||
| c_api.TF_OperationToNodeDef(oper, buffer, s); | |||
| s.Check(); | |||
| var ret = NodeDef.Parser.ParseFrom(buffer); | |||
| buffer.Dispose(); | |||
| s.Dispose(); | |||
| return ret; | |||
| } | |||
| public static bool IsAddN(NodeDef node_def, int n) | |||
| { | |||
| if (node_def.Op != "AddN" || node_def.Name != "add" || | |||