| @@ -0,0 +1,21 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class Graph | |||||
| { | |||||
| public Buffer ToGraphDef(Status s) | |||||
| { | |||||
| var buffer = new Buffer(); | |||||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||||
| s.Check(); | |||||
| // var def = GraphDef.Parser.ParseFrom(buffer); | |||||
| // buffer.Dispose(); | |||||
| return buffer; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,28 @@ | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class Graph | |||||
| { | |||||
| public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s) | |||||
| { | |||||
| var num_return_outputs = opts.NumReturnOutputs; | |||||
| var return_outputs = new TF_Output[num_return_outputs]; | |||||
| TF_Output* return_output_handle = (TF_Output*)Marshal.AllocHGlobal(Marshal.SizeOf<TF_Output>() * 2); | |||||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | |||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| { | |||||
| var handle = return_output_handle + i * Marshal.SizeOf<TF_Output>(); | |||||
| return_outputs[i] = new TF_Output((*handle).oper, (*handle).index); | |||||
| } | |||||
| return return_outputs; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -13,7 +13,7 @@ namespace Tensorflow | |||||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | ||||
| /// https://www.tensorflow.org/guide/graphs | /// https://www.tensorflow.org/guide/graphs | ||||
| /// </summary> | /// </summary> | ||||
| public class Graph : IDisposable | |||||
| public partial class Graph : IDisposable | |||||
| { | { | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| private Dictionary<int, Operation> _nodes_by_id; | private Dictionary<int, Operation> _nodes_by_id; | ||||
| @@ -211,18 +211,6 @@ namespace Tensorflow | |||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
| } | } | ||||
| public GraphDef ToGraphDef() | |||||
| { | |||||
| var s = new Status(); | |||||
| var buffer = new Buffer(); | |||||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||||
| s.Check(); | |||||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||||
| buffer.Dispose(); | |||||
| s.Dispose(); | |||||
| return def; | |||||
| } | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| c_api.TF_DeleteGraph(_handle); | c_api.TF_DeleteGraph(_handle); | ||||
| @@ -0,0 +1,35 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class ImportGraphDefOptions : IDisposable | |||||
| { | |||||
| private IntPtr _handle; | |||||
| public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||||
| public ImportGraphDefOptions() | |||||
| { | |||||
| _handle = c_api.TF_NewImportGraphDefOptions(); | |||||
| } | |||||
| public ImportGraphDefOptions(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| } | |||||
| public void AddReturnOutput(string name, int index) | |||||
| { | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| c_api.TF_DeleteImportGraphDefOptions(_handle); | |||||
| } | |||||
| public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | |||||
| public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle); | |||||
| } | |||||
| } | |||||
| @@ -45,6 +45,24 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | ||||
| /// <summary> | |||||
| /// Import the graph serialized in `graph_def` into `graph`. | |||||
| /// Convenience function for when only return outputs are needed. | |||||
| /// | |||||
| /// `num_return_outputs` must be the number of return outputs added (i.e. the | |||||
| /// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If | |||||
| /// `num_return_outputs` is non-zero, `return_outputs` must be of length | |||||
| /// `num_return_outputs`. Otherwise it can be null. | |||||
| /// </summary> | |||||
| /// <param name="graph">TF_Graph* graph</param> | |||||
| /// <param name="graph_def">const TF_Buffer*</param> | |||||
| /// <param name="options">const TF_ImportGraphDefOptions*</param> | |||||
| /// <param name="return_outputs">TF_Output*</param> | |||||
| /// <param name="num_return_outputs">int</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, TF_Output* return_outputs, int num_return_outputs, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | ||||
| /// a bad status on error. Otherwise, returns a populated | /// a bad status on error. Otherwise, returns a populated | ||||
| @@ -357,10 +357,55 @@ namespace TensorFlowNET.UnitTest | |||||
| s.Dispose(); | s.Dispose(); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Port from c_api_test.cc | |||||
| /// `TEST(CAPI, ImportGraphDef_WithReturnOutputs)` | |||||
| /// </summary> | |||||
| [TestMethod] | [TestMethod] | ||||
| public void c_api_ImportGraphDef_WithReturnOutputs() | public void c_api_ImportGraphDef_WithReturnOutputs() | ||||
| { | { | ||||
| var s = new Status(); | |||||
| var graph = new Graph(); | |||||
| // Create a graph with two nodes: x and 3 | |||||
| var feed = c_test_util.Placeholder(graph, s); | |||||
| EXPECT_EQ(feed, graph.OperationByName("feed")); | |||||
| var scalar = c_test_util.ScalarConst(3, graph, s); | |||||
| EXPECT_EQ(scalar, graph.OperationByName("scalar")); | |||||
| var neg = c_test_util.Neg(scalar, graph, s); | |||||
| EXPECT_EQ(neg, graph.OperationByName("neg")); | |||||
| // Export to a GraphDef. | |||||
| var graph_def = graph.ToGraphDef(s); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| // Import it in a fresh graph with return outputs. | |||||
| graph.Dispose(); | |||||
| graph = new Graph(); | |||||
| var opts = new ImportGraphDefOptions(); | |||||
| opts.AddReturnOutput("feed", 0); | |||||
| opts.AddReturnOutput("scalar", 0); | |||||
| EXPECT_EQ(2, opts.NumReturnOutputs); | |||||
| var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| scalar = graph.OperationByName("scalar"); | |||||
| feed = graph.OperationByName("feed"); | |||||
| neg = graph.OperationByName("neg"); | |||||
| ASSERT_TRUE(scalar != IntPtr.Zero); | |||||
| ASSERT_TRUE(feed != IntPtr.Zero); | |||||
| ASSERT_TRUE(neg != IntPtr.Zero); | |||||
| // Check return outputs | |||||
| EXPECT_EQ(feed, return_outputs[0].oper); | |||||
| EXPECT_EQ(0, return_outputs[0].index); | |||||
| EXPECT_EQ(scalar, return_outputs[1].oper); | |||||
| EXPECT_EQ(0, return_outputs[1].index); | |||||
| opts.Dispose(); | |||||
| graph_def.Dispose(); | |||||
| graph.Dispose(); | |||||
| s.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||