diff --git a/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs deleted file mode 100644 index 145a3058..00000000 --- a/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs +++ /dev/null @@ -1,10 +0,0 @@ -namespace Tensorflow.Framework.Models -{ - public class ScopedTFImportGraphDefOptions : ImportGraphDefOptions - { - public ScopedTFImportGraphDefOptions() : base() - { - - } - } -} diff --git a/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefResults.cs b/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefResults.cs deleted file mode 100644 index dc1236e3..00000000 --- a/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefResults.cs +++ /dev/null @@ -1,17 +0,0 @@ -using System; - -namespace Tensorflow.Framework.Models -{ - public class ScopedTFImportGraphDefResults : ImportGraphDefOptions - { - public ScopedTFImportGraphDefResults() : base() - { - - } - - public ScopedTFImportGraphDefResults(IntPtr results) : base(results) - { - - } - } -} diff --git a/src/TensorFlowNET.Core/Framework/importer.cs b/src/TensorFlowNET.Core/Framework/importer.cs index ff1ba4f5..4a35f8d8 100644 --- a/src/TensorFlowNET.Core/Framework/importer.cs +++ b/src/TensorFlowNET.Core/Framework/importer.cs @@ -62,7 +62,7 @@ namespace Tensorflow { _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); // need to create a class ImportGraphDefWithResults with IDisposal - results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status.Handle); + results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options.Handle, status.Handle); status.Check(true); } @@ -114,8 +114,8 @@ namespace Tensorflow Dictionary input_map, string[] return_elements) { - c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); - c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); + c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix); + c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1); foreach(var input in input_map) { @@ -130,11 +130,11 @@ namespace Tensorflow if(name.Contains(":")) { var (op_name, index) = _ParseTensorName(name); - c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index); } else { - c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); + c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name); } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs index 4df5e1eb..3a919145 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -29,7 +29,7 @@ namespace Tensorflow int size = Marshal.SizeOf(); var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); - c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s.Handle); + c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts.Handle, return_output_handle, num_return_outputs, s.Handle); var tf_output_ptr = (TF_Output*) return_output_handle; for (int i = 0; i < num_return_outputs; i++) @@ -53,8 +53,8 @@ namespace Tensorflow using (var graph_def = new Buffer(bytes)) { as_default(); - c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); - c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status.Handle); + c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix); + c_api.TF_GraphImportGraphDef(_handle, graph_def, opts.Handle, status.Handle); status.Check(true); return status.Code == TF_Code.TF_OK; } diff --git a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs index 70802597..38e30343 100644 --- a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs +++ b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs @@ -18,30 +18,24 @@ using System; namespace Tensorflow { - public class ImportGraphDefOptions : DisposableObject + public sealed class ImportGraphDefOptions : IDisposable { + public SafeImportGraphDefOptionsHandle Handle { get; } + public int NumReturnOutputs - => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); + => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(Handle); public ImportGraphDefOptions() { - _handle = c_api.TF_NewImportGraphDefOptions(); - } - - public ImportGraphDefOptions(IntPtr handle) - { - _handle = handle; + Handle = c_api.TF_NewImportGraphDefOptions(); } public void AddReturnOutput(string name, int index) { - c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(Handle, name, index); } - protected override void DisposeUnmanagedResources(IntPtr handle) - => c_api.TF_DeleteImportGraphDefOptions(handle); - - public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; - public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle); + public void Dispose() + => Handle.Dispose(); } } diff --git a/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefOptionsHandle.cs b/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefOptionsHandle.cs new file mode 100644 index 00000000..c1aac9d1 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefOptionsHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow +{ + public sealed class SafeImportGraphDefOptionsHandle : SafeTensorflowHandle + { + public SafeImportGraphDefOptionsHandle() + { + } + + public SafeImportGraphDefOptionsHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteImportGraphDefOptions(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 971e5ddf..4547d92d 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -78,7 +78,7 @@ namespace Tensorflow /// int /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); + public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); /// /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and @@ -92,7 +92,7 @@ namespace Tensorflow /// TF_Status* /// TF_ImportGraphDefResults* [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, IntPtr options, SafeStatusHandle status); + public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); /// /// Import the graph serialized in `graph_def` into `graph`. @@ -102,7 +102,7 @@ namespace Tensorflow /// TF_ImportGraphDefOptions* /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, SafeStatusHandle status); + public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); /// /// Iterate through the operations of a graph. @@ -160,7 +160,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefOptionsAddControlDependency(IntPtr opts, IntPtr oper); + public static extern void TF_ImportGraphDefOptionsAddControlDependency(SafeImportGraphDefOptionsHandle opts, IntPtr oper); /// /// Set any imported nodes with input `src_name:src_index` to have that input @@ -173,7 +173,7 @@ namespace Tensorflow /// int /// TF_Output [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefOptionsAddInputMapping(IntPtr opts, string src_name, int src_index, TF_Output dst); + public static extern void TF_ImportGraphDefOptionsAddInputMapping(SafeImportGraphDefOptionsHandle opts, string src_name, int src_index, TF_Output dst); /// /// Add an operation in `graph_def` to be returned via the `return_opers` output @@ -183,7 +183,7 @@ namespace Tensorflow /// TF_ImportGraphDefOptions* opts /// const char* [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefOptionsAddReturnOperation(IntPtr opts, string oper_name); + public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); /// /// Add an output in `graph_def` to be returned via the `return_outputs` output @@ -195,7 +195,7 @@ namespace Tensorflow /// const char* /// int [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefOptionsAddReturnOutput(IntPtr opts, string oper_name, int index); + public static extern void TF_ImportGraphDefOptionsAddReturnOutput(SafeImportGraphDefOptionsHandle opts, string oper_name, int index); /// /// Returns the number of return operations added via @@ -204,7 +204,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern int TF_ImportGraphDefOptionsNumReturnOperations(IntPtr opts); + public static extern int TF_ImportGraphDefOptionsNumReturnOperations(SafeImportGraphDefOptionsHandle opts); /// /// Returns the number of return outputs added via @@ -213,7 +213,7 @@ namespace Tensorflow /// const TF_ImportGraphDefOptions* /// [DllImport(TensorFlowLibName)] - public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts); + public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(SafeImportGraphDefOptionsHandle opts); /// /// Set any imported nodes with control input `src_name` to have that input @@ -225,7 +225,7 @@ namespace Tensorflow /// const char* /// TF_Operation* [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefOptionsRemapControlDependency(IntPtr opts, string src_name, IntPtr dst); + public static extern void TF_ImportGraphDefOptionsRemapControlDependency(SafeImportGraphDefOptionsHandle opts, string src_name, IntPtr dst); /// /// Set the prefix to be prepended to the names of nodes in `graph_def` that will @@ -234,7 +234,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix); + public static extern void TF_ImportGraphDefOptionsSetPrefix(SafeImportGraphDefOptionsHandle ops, string prefix); /// /// Set whether to uniquify imported operation names. If true, imported operation @@ -246,7 +246,7 @@ namespace Tensorflow /// TF_ImportGraphDefOptions* /// unsigned char [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(IntPtr ops, char uniquify_prefix); + public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix); /// /// Fetches the return operations requested via @@ -295,7 +295,7 @@ namespace Tensorflow public static extern IntPtr TF_NewGraph(); [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_NewImportGraphDefOptions(); + public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions(); /// /// Updates 'dst' to consume 'new_src'. diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 1ff0d40a..e2855425 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -222,10 +222,12 @@ namespace TensorFlowNET.UnitTest.NativeAPI // Import it, with a prefix, in a fresh graph. graph.Dispose(); graph = new Graph().as_default(); - var opts = c_api.TF_NewImportGraphDefOptions(); - c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); - c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); - EXPECT_EQ(TF_Code.TF_OK, s.Code); + using (var opts = c_api.TF_NewImportGraphDefOptions()) + { + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); + c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + } Operation scalar = graph.OperationByName("imported/scalar"); Operation feed = graph.OperationByName("imported/feed"); @@ -258,17 +260,19 @@ namespace TensorFlowNET.UnitTest.NativeAPI // Import it again, with an input mapping, return outputs, and a return // operation, into the same graph. - c_api.TF_DeleteImportGraphDefOptions(opts); - opts = c_api.TF_NewImportGraphDefOptions(); - c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); - c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); - c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); - c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); - EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); - c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); - EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); - var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s.Handle); - EXPECT_EQ(TF_Code.TF_OK, s.Code); + IntPtr results; + using (var opts = c_api.TF_NewImportGraphDefOptions()) + { + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); + c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); + EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); + c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); + EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); + results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s.Handle); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + } Operation scalar2 = graph.OperationByName("imported2/scalar"); Operation feed2 = graph.OperationByName("imported2/feed"); @@ -294,13 +298,14 @@ namespace TensorFlowNET.UnitTest.NativeAPI c_api.TF_DeleteImportGraphDefResults(results); // Import again, with control dependencies, into the same graph. - c_api.TF_DeleteImportGraphDefOptions(opts); - opts = c_api.TF_NewImportGraphDefOptions(); - c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); - c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); - c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); - c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); - EXPECT_EQ(TF_Code.TF_OK, s.Code); + using (var opts = c_api.TF_NewImportGraphDefOptions()) + { + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); + c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); + c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); + c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + } var scalar3 = graph.OperationByName("imported3/scalar"); var feed3 = graph.OperationByName("imported3/feed"); @@ -327,12 +332,13 @@ namespace TensorFlowNET.UnitTest.NativeAPI EXPECT_EQ(TF_Code.TF_OK, s.Code); // Import again, with remapped control dependency, into the same graph - c_api.TF_DeleteImportGraphDefOptions(opts); - opts = c_api.TF_NewImportGraphDefOptions(); - c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); - c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); - c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); - ASSERT_EQ(TF_Code.TF_OK, s.Code); + using (var opts = c_api.TF_NewImportGraphDefOptions()) + { + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); + c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); + c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + } var scalar4 = graph.OperationByName("imported4/imported3/scalar"); var feed4 = graph.OperationByName("imported4/imported2/feed"); @@ -344,8 +350,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI EXPECT_EQ(feed, control_inputs[0]); EXPECT_EQ(feed4, control_inputs[1]); - c_api.TF_DeleteImportGraphDefOptions(opts); - // Can add nodes to the imported graph without trouble. c_test_util.Add(feed, scalar, graph, s); ASSERT_EQ(TF_Code.TF_OK, s.Code);