| @@ -0,0 +1,21 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Graph | |||
| { | |||
| public Buffer ToGraphDef(Status s) | |||
| { | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
| s.Check(); | |||
| // var def = GraphDef.Parser.ParseFrom(buffer); | |||
| // buffer.Dispose(); | |||
| return buffer; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,28 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Graph | |||
| { | |||
| public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s) | |||
| { | |||
| var num_return_outputs = opts.NumReturnOutputs; | |||
| var return_outputs = new TF_Output[num_return_outputs]; | |||
| TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(Marshal.SizeOf<TF_Output>() * 2); | |||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| { | |||
| var handle = return_output_handle + i * Marshal.SizeOf<TF_Output>(); | |||
| return_outputs[i] = new TF_Output((*handle).oper, (*handle).index); | |||
| } | |||
| return return_outputs; | |||
| } | |||
| } | |||
| } | |||
| @@ -13,7 +13,7 @@ namespace Tensorflow | |||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
| /// https://www.tensorflow.org/guide/graphs | |||
| /// </summary> | |||
| public class Graph : IDisposable | |||
| public partial class Graph : IDisposable | |||
| { | |||
| private IntPtr _handle; | |||
| private Dictionary<int, Operation> _nodes_by_id; | |||
| @@ -211,18 +211,6 @@ namespace Tensorflow | |||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
| } | |||
| public GraphDef ToGraphDef() | |||
| { | |||
| var s = new Status(); | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
| s.Check(); | |||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||
| buffer.Dispose(); | |||
| s.Dispose(); | |||
| return def; | |||
| } | |||
| public void Dispose() | |||
| { | |||
| c_api.TF_DeleteGraph(_handle); | |||
| @@ -0,0 +1,35 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class ImportGraphDefOptions : IDisposable | |||
| { | |||
| private IntPtr _handle; | |||
| public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||
| public ImportGraphDefOptions() | |||
| { | |||
| _handle = c_api.TF_NewImportGraphDefOptions(); | |||
| } | |||
| public ImportGraphDefOptions(IntPtr handle) | |||
| { | |||
| _handle = handle; | |||
| } | |||
| public void AddReturnOutput(string name, int index) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| c_api.TF_DeleteImportGraphDefOptions(_handle); | |||
| } | |||
| public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | |||
| public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle); | |||
| } | |||
| } | |||
| @@ -45,6 +45,24 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||
| /// <summary> | |||
| /// Import the graph serialized in `graph_def` into `graph`. | |||
| /// Convenience function for when only return outputs are needed. | |||
| /// | |||
| /// `num_return_outputs` must be the number of return outputs added (i.e. the | |||
| /// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If | |||
| /// `num_return_outputs` is non-zero, `return_outputs` must be of length | |||
| /// `num_return_outputs`. Otherwise it can be null. | |||
| /// </summary> | |||
| /// <param name="graph">TF_Graph* graph</param> | |||
| /// <param name="graph_def">const TF_Buffer*</param> | |||
| /// <param name="options">const TF_ImportGraphDefOptions*</param> | |||
| /// <param name="return_outputs">TF_Output*</param> | |||
| /// <param name="num_return_outputs">int</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status); | |||
| /// <summary> | |||
| /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | |||
| /// a bad status on error. Otherwise, returns a populated | |||
| @@ -357,10 +357,55 @@ namespace TensorFlowNET.UnitTest | |||
| s.Dispose(); | |||
| } | |||
| /// <summary> | |||
| /// Port from c_api_test.cc | |||
| /// `TEST(CAPI, ImportGraphDef_WithReturnOutputs)` | |||
| /// </summary> | |||
| [TestMethod] | |||
| public void c_api_ImportGraphDef_WithReturnOutputs() | |||
| { | |||
| var s = new Status(); | |||
| var graph = new Graph(); | |||
| // Create a graph with two nodes: x and 3 | |||
| var feed = c_test_util.Placeholder(graph, s); | |||
| EXPECT_EQ(feed, graph.OperationByName("feed")); | |||
| var scalar = c_test_util.ScalarConst(3, graph, s); | |||
| EXPECT_EQ(scalar, graph.OperationByName("scalar")); | |||
| var neg = c_test_util.Neg(scalar, graph, s); | |||
| EXPECT_EQ(neg, graph.OperationByName("neg")); | |||
| // Export to a GraphDef. | |||
| var graph_def = graph.ToGraphDef(s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Import it in a fresh graph with return outputs. | |||
| graph.Dispose(); | |||
| graph = new Graph(); | |||
| var opts = new ImportGraphDefOptions(); | |||
| opts.AddReturnOutput("feed", 0); | |||
| opts.AddReturnOutput("scalar", 0); | |||
| EXPECT_EQ(2, opts.NumReturnOutputs); | |||
| var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| scalar = graph.OperationByName("scalar"); | |||
| feed = graph.OperationByName("feed"); | |||
| neg = graph.OperationByName("neg"); | |||
| ASSERT_TRUE(scalar != IntPtr.Zero); | |||
| ASSERT_TRUE(feed != IntPtr.Zero); | |||
| ASSERT_TRUE(neg != IntPtr.Zero); | |||
| // Check return outputs | |||
| EXPECT_EQ(feed, return_outputs[0].oper); | |||
| EXPECT_EQ(0, return_outputs[0].index); | |||
| EXPECT_EQ(scalar, return_outputs[1].oper); | |||
| EXPECT_EQ(0, return_outputs[1].index); | |||
| opts.Dispose(); | |||
| graph_def.Dispose(); | |||
| graph.Dispose(); | |||
| s.Dispose(); | |||
| } | |||
| } | |||
| } | |||