| @@ -179,31 +179,50 @@ namespace Tensorflow | |||
| TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| { | |||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(return_output_handle + (Marshal.SizeOf<TF_Output>() * i)); | |||
| var handle = return_output_handle + (Marshal.SizeOf<TF_Output>() * i); | |||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||
| } | |||
| return return_outputs; | |||
| } | |||
| public Operation[] ReturnOperations(IntPtr results) | |||
| public unsafe Operation[] ReturnOperations(IntPtr results) | |||
| { | |||
| IntPtr return_oper_handle = IntPtr.Zero; | |||
| TF_Operation return_oper_handle = new TF_Operation(); | |||
| int num_return_opers = 0; | |||
| c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_opers, ref return_oper_handle); | |||
| c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | |||
| Operation[] return_opers = new Operation[num_return_opers]; | |||
| for (int i = 0; i < num_return_opers; i++) | |||
| { | |||
| // return_opers[i] = Marshal.PtrToStructure<TF_Output>(return_oper_handle + (Marshal.SizeOf<TF_Output>() * i)); | |||
| var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i; | |||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||
| } | |||
| return return_opers; | |||
| } | |||
| public Operation OperationByName(string operName) | |||
| { | |||
| return c_api.TF_GraphOperationByName(_handle, operName); | |||
| } | |||
| public Operation[] get_operations() | |||
| { | |||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
| } | |||
| public GraphDef ToGraphDef() | |||
| { | |||
| var s = new Status(); | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
| s.Check(); | |||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||
| buffer.Dispose(); | |||
| s.Dispose(); | |||
| return def; | |||
| } | |||
| public void Dispose() | |||
| { | |||
| c_api.TF_DeleteGraph(_handle); | |||
| @@ -0,0 +1,17 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct TF_ImportGraphDefResults | |||
| { | |||
| public IntPtr return_tensors; | |||
| public IntPtr return_nodes; | |||
| public IntPtr missing_unused_key_names; | |||
| public IntPtr missing_unused_key_indexes; | |||
| public IntPtr missing_unused_key_names_data; | |||
| } | |||
| } | |||
| @@ -22,6 +22,13 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_DeleteImportGraphDefOptions(IntPtr opts); | |||
| /// <summary> | |||
| /// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). | |||
| /// </summary> | |||
| /// <param name="results"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_DeleteImportGraphDefResults(IntPtr results); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | |||
| @@ -91,9 +98,9 @@ namespace Tensorflow | |||
| /// Write out a serialized representation of `graph` (as a GraphDef protocol | |||
| /// message) to `output_graph_def` (allocated by TF_NewBuffer()). | |||
| /// </summary> | |||
| /// <param name="graph"></param> | |||
| /// <param name="output_graph_def"></param> | |||
| /// <param name="status"></param> | |||
| /// <param name="graph">TF_Graph*</param> | |||
| /// <param name="output_graph_def">TF_Buffer*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, IntPtr status); | |||
| @@ -110,6 +117,15 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); | |||
| /// <summary> | |||
| /// Cause the imported graph to have a control dependency on `oper`. `oper` | |||
| /// should exist in the graph being imported into. | |||
| /// </summary> | |||
| /// <param name="opts"></param> | |||
| /// <param name="oper"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsAddControlDependency(IntPtr opts, IntPtr oper); | |||
| /// <summary> | |||
| /// Set any imported nodes with input `src_name:src_index` to have that input | |||
| /// replaced with `dst`. `src_name` refers to a node in the graph to be imported, | |||
| @@ -163,6 +179,18 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts); | |||
| /// <summary> | |||
| /// Set any imported nodes with control input `src_name` to have that input | |||
| /// replaced with `dst`. `src_name` refers to a node in the graph to be imported, | |||
| /// `dst` references an operation already existing in the graph being imported | |||
| /// into. `src_name` is copied and has no lifetime requirements. | |||
| /// </summary> | |||
| /// <param name="opts">TF_ImportGraphDefOptions*</param> | |||
| /// <param name="src_name">const char*</param> | |||
| /// <param name="dst">TF_Operation*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsRemapControlDependency(IntPtr opts, string src_name, IntPtr dst); | |||
| /// <summary> | |||
| /// Set the prefix to be prepended to the names of nodes in `graph_def` that will | |||
| /// be imported into `graph`. `prefix` is copied and has no lifetime | |||
| @@ -182,7 +210,7 @@ namespace Tensorflow | |||
| /// <param name="num_opers">int*</param> | |||
| /// <param name="opers">TF_Operation***</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref IntPtr opers); | |||
| public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref TF_Operation opers); | |||
| /// <summary> | |||
| /// Fetches the return outputs requested via | |||
| @@ -18,13 +18,16 @@ namespace Tensorflow | |||
| public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||
| public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||
| public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
| public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||
| public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); | |||
| public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); | |||
| public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); | |||
| public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); | |||
| public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); | |||
| public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||
| public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | |||
| public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | |||
| { | |||
| @@ -42,21 +45,41 @@ namespace Tensorflow | |||
| public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||
| public Operation[] ControlInputs(int max_control_inputs) | |||
| public unsafe Operation[] GetControlInputs() | |||
| { | |||
| var control_inputs = new Operation[NumControlInputs]; | |||
| var control_input_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs); | |||
| c_api.TF_OperationGetControlInputs(_handle, control_input_handles, max_control_inputs); | |||
| if(NumControlInputs > 0) | |||
| { | |||
| IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>()); | |||
| c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); | |||
| for (int i = 0; i < NumControlInputs; i++) | |||
| { | |||
| var handle = control_input_handle + Marshal.SizeOf<IntPtr>() * i; | |||
| control_inputs[i] = new Operation(*(IntPtr*)handle); | |||
| } | |||
| } | |||
| return control_inputs; | |||
| } | |||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||
| public Operation[] ControlOutputs(int max_control_outputs) | |||
| public unsafe Operation[] GetControlOutputs() | |||
| { | |||
| var control_outputs = new Operation[NumControlOutputs]; | |||
| var control_output_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | |||
| c_api.TF_OperationGetControlInputs(_handle, control_output_handles, max_control_outputs); | |||
| if(NumControlOutputs > 0) | |||
| { | |||
| IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>()); | |||
| c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); | |||
| for (int i = 0; i < NumControlInputs; i++) | |||
| { | |||
| var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i; | |||
| control_outputs[i] = new Operation(*(IntPtr*)handle); | |||
| } | |||
| } | |||
| return control_outputs; | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct TF_Operation | |||
| { | |||
| public IntPtr node; | |||
| } | |||
| } | |||
| @@ -16,7 +16,7 @@ namespace Tensorflow | |||
| /// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) | |||
| /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | |||
| /// struct => struct (TF_Output output) => (TF_Output output) | |||
| /// struct* => struct (TF_Output* output) => (TF_Output[] output) | |||
| /// struct* => struct[] (TF_Output* output) => (TF_Output[] output) | |||
| /// struct* => struct* for ref | |||
| /// const char* => string | |||
| /// int32_t => int | |||
| @@ -228,9 +228,9 @@ namespace TensorFlowNET.UnitTest | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| Operation scalar = c_api.TF_GraphOperationByName(graph, "imported/scalar"); | |||
| Operation feed = c_api.TF_GraphOperationByName(graph, "imported/feed"); | |||
| Operation neg = c_api.TF_GraphOperationByName(graph, "imported/neg"); | |||
| Operation scalar = graph.OperationByName("imported/scalar"); | |||
| Operation feed = graph.OperationByName("imported/feed"); | |||
| Operation neg = graph.OperationByName("imported/neg"); | |||
| // Test basic structure of the imported graph. | |||
| EXPECT_EQ(0, scalar.NumInputs); | |||
| @@ -243,19 +243,19 @@ namespace TensorFlowNET.UnitTest | |||
| // Test that we can't see control edges involving the source and sink nodes. | |||
| EXPECT_EQ(0, scalar.NumControlInputs); | |||
| EXPECT_EQ(0, scalar.ControlInputs(100).Length); | |||
| EXPECT_EQ(0, scalar.GetControlInputs().Length); | |||
| EXPECT_EQ(0, scalar.NumControlOutputs); | |||
| EXPECT_EQ(0, scalar.ControlOutputs(100).Length); | |||
| EXPECT_EQ(0, scalar.GetControlOutputs().Length); | |||
| EXPECT_EQ(0, feed.NumControlInputs); | |||
| EXPECT_EQ(0, feed.ControlInputs(100).Length); | |||
| EXPECT_EQ(0, feed.GetControlInputs().Length); | |||
| EXPECT_EQ(0, feed.NumControlOutputs); | |||
| EXPECT_EQ(0, feed.ControlOutputs(100).Length); | |||
| EXPECT_EQ(0, feed.GetControlOutputs().Length); | |||
| EXPECT_EQ(0, neg.NumControlInputs); | |||
| EXPECT_EQ(0, neg.ControlInputs(100).Length); | |||
| EXPECT_EQ(0, neg.GetControlInputs().Length); | |||
| EXPECT_EQ(0, neg.NumControlOutputs); | |||
| EXPECT_EQ(0, neg.ControlOutputs(100).Length); | |||
| EXPECT_EQ(0, neg.GetControlOutputs().Length); | |||
| // Import it again, with an input mapping, return outputs, and a return | |||
| // operation, into the same graph. | |||
| @@ -271,9 +271,9 @@ namespace TensorFlowNET.UnitTest | |||
| var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| Operation scalar2 = c_api.TF_GraphOperationByName(graph, "imported2/scalar"); | |||
| Operation feed2 = c_api.TF_GraphOperationByName(graph, "imported2/feed"); | |||
| Operation neg2 = c_api.TF_GraphOperationByName(graph, "imported2/neg"); | |||
| Operation scalar2 = graph.OperationByName("imported2/scalar"); | |||
| Operation feed2 = graph.OperationByName("imported2/feed"); | |||
| Operation neg2 = graph.OperationByName("imported2/neg"); | |||
| // Check input mapping | |||
| neg_input = neg.Input(0); | |||
| @@ -289,8 +289,72 @@ namespace TensorFlowNET.UnitTest | |||
| EXPECT_EQ(0, return_outputs[1].index); | |||
| // Check return operation | |||
| var num_return_opers = graph.ReturnOperations(results); | |||
| ASSERT_EQ(1, num_return_opers); | |||
| var return_opers = graph.ReturnOperations(results); | |||
| ASSERT_EQ(1, return_opers.Length); | |||
| EXPECT_EQ(scalar2, return_opers[0]); // not remapped | |||
| c_api.TF_DeleteImportGraphDefResults(results); | |||
| // Import again, with control dependencies, into the same graph. | |||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||
| opts = c_api.TF_NewImportGraphDefOptions(); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); | |||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); | |||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| var scalar3 = graph.OperationByName("imported3/scalar"); | |||
| var feed3 = graph.OperationByName("imported3/feed"); | |||
| var neg3 = graph.OperationByName("imported3/neg"); | |||
| ASSERT_TRUE(scalar3 != IntPtr.Zero); | |||
| ASSERT_TRUE(feed3 != IntPtr.Zero); | |||
| ASSERT_TRUE(neg3 != IntPtr.Zero); | |||
| // Check that newly-imported scalar and feed have control deps (neg3 will | |||
| // inherit them from input) | |||
| var control_inputs = scalar3.GetControlInputs(); | |||
| ASSERT_EQ(2, scalar3.NumControlInputs); | |||
| EXPECT_EQ(feed, control_inputs[0]); | |||
| EXPECT_EQ(feed2, control_inputs[1]); | |||
| control_inputs = feed3.GetControlInputs(); | |||
| ASSERT_EQ(2, feed3.NumControlInputs); | |||
| EXPECT_EQ(feed, control_inputs[0]); | |||
| EXPECT_EQ(feed2, control_inputs[1]); | |||
| // Export to a graph def so we can import a graph with control dependencies | |||
| graph_def.Dispose(); | |||
| graph_def = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Import again, with remapped control dependency, into the same graph | |||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||
| opts = c_api.TF_NewImportGraphDefOptions(); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); | |||
| c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| var scalar4 = graph.OperationByName("imported4/imported3/scalar"); | |||
| var feed4 = graph.OperationByName("imported4/imported2/feed"); | |||
| // Check that imported `imported3/scalar` has remapped control dep from | |||
| // original graph and imported control dep | |||
| control_inputs = scalar4.GetControlInputs(); | |||
| ASSERT_EQ(2, scalar4.NumControlInputs); | |||
| EXPECT_EQ(feed, control_inputs[0]); | |||
| EXPECT_EQ(feed4, control_inputs[1]); | |||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||
| c_api.TF_DeleteBuffer(graph_def); | |||
| // Can add nodes to the imported graph without trouble. | |||
| c_test_util.Add(feed, scalar, graph, s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| //graph.Dispose(); | |||
| s.Dispose(); | |||
| } | |||
| } | |||
| } | |||