diff --git a/src/TensorFlowNET.Core/Framework/importer.cs b/src/TensorFlowNET.Core/Framework/importer.cs
index 8dd30858..c28579be 100644
--- a/src/TensorFlowNET.Core/Framework/importer.cs
+++ b/src/TensorFlowNET.Core/Framework/importer.cs
@@ -62,7 +62,7 @@ namespace Tensorflow
{
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
// need to create a class ImportGraphDefWithResults with IDisposal
- results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle);
+ results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle));
status.Check(true);
}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
index d7f2ce2a..d17f6591 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
@@ -33,7 +33,7 @@ namespace Tensorflow
return c_api.TF_NewOperation(_handle, opType, opName);
}
- public Operation[] ReturnOperations(IntPtr results)
+ public Operation[] ReturnOperations(SafeImportGraphDefResultsHandle results)
{
TF_Operation return_oper_handle = new TF_Operation();
int num_return_opers = 0;
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 69192634..cfa782ea 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -413,7 +413,7 @@ namespace Tensorflow
return name;
}
- public TF_Output[] ReturnOutputs(IntPtr results)
+ public TF_Output[] ReturnOutputs(SafeImportGraphDefResultsHandle results)
{
IntPtr return_output_handle = IntPtr.Zero;
int num_return_outputs = 0;
diff --git a/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs b/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs
new file mode 100644
index 00000000..8a84eff8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs
@@ -0,0 +1,40 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using Tensorflow.Util;
+
+namespace Tensorflow
+{
+ public sealed class SafeImportGraphDefResultsHandle : SafeTensorflowHandle
+ {
+ private SafeImportGraphDefResultsHandle()
+ {
+ }
+
+ public SafeImportGraphDefResultsHandle(IntPtr handle)
+ : base(handle)
+ {
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ c_api.TF_DeleteImportGraphDefResults(handle);
+ SetHandle(IntPtr.Zero);
+ return true;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs b/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs
index 71ea5306..eff8be94 100644
--- a/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs
+++ b/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs
@@ -1,18 +1,35 @@
-using System;
-using System.Runtime.InteropServices;
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
namespace Tensorflow
{
- public class TF_ImportGraphDefResults : DisposableObject
+ public sealed class TF_ImportGraphDefResults : IDisposable
{
/*public IntPtr return_nodes;
public IntPtr missing_unused_key_names;
public IntPtr missing_unused_key_indexes;
public IntPtr missing_unused_key_names_data;*/
- public TF_ImportGraphDefResults(IntPtr handle)
+ private SafeImportGraphDefResultsHandle Handle { get; }
+
+ public TF_ImportGraphDefResults(SafeImportGraphDefResultsHandle handle)
{
- _handle = handle;
+ Handle = handle;
}
public TF_Output[] return_tensors
@@ -21,7 +38,7 @@ namespace Tensorflow
{
IntPtr return_output_handle = IntPtr.Zero;
int num_outputs = -1;
- c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle);
+ c_api.TF_ImportGraphDefResultsReturnOutputs(Handle, ref num_outputs, ref return_output_handle);
TF_Output[] return_outputs = new TF_Output[num_outputs];
unsafe
{
@@ -52,13 +69,7 @@ namespace Tensorflow
}
}
- public static implicit operator TF_ImportGraphDefResults(IntPtr handle)
- => new TF_ImportGraphDefResults(handle);
-
- public static implicit operator IntPtr(TF_ImportGraphDefResults results)
- => results._handle;
-
- protected override void DisposeUnmanagedResources(IntPtr handle)
- => c_api.TF_DeleteImportGraphDefResults(handle);
+ public void Dispose()
+ => Handle.Dispose();
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
index 94d6d63e..74e21f03 100644
--- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
@@ -92,7 +92,7 @@ namespace Tensorflow
/// TF_Status*
/// TF_ImportGraphDefResults*
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
+ public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
///
/// Import the graph serialized in `graph_def` into `graph`.
@@ -258,7 +258,7 @@ namespace Tensorflow
/// int*
/// TF_Operation***
[DllImport(TensorFlowLibName)]
- public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref TF_Operation opers);
+ public static extern void TF_ImportGraphDefResultsReturnOperations(SafeImportGraphDefResultsHandle results, ref int num_opers, ref TF_Operation opers);
///
/// Fetches the return outputs requested via
@@ -270,7 +270,7 @@ namespace Tensorflow
/// int*
/// TF_Output**
[DllImport(TensorFlowLibName)]
- public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs);
+ public static extern void TF_ImportGraphDefResultsReturnOutputs(SafeImportGraphDefResultsHandle results, ref int num_outputs, ref IntPtr outputs);
///
/// This function creates a new TF_Session (which is created on success) using
diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs
index 868fd1b8..5bd72a2b 100644
--- a/test/TensorFlowNET.UnitTest/GraphTest.cs
+++ b/test/TensorFlowNET.UnitTest/GraphTest.cs
@@ -258,11 +258,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_EQ(0, neg.NumControlOutputs);
EXPECT_EQ(0, neg.GetControlOutputs().Length);
- // Import it again, with an input mapping, return outputs, and a return
- // operation, into the same graph.
- IntPtr results;
- using (var opts = c_api.TF_NewImportGraphDefOptions())
+ static SafeImportGraphDefResultsHandle ImportGraph(Status s, Graph graph, Buffer graph_def, Operation scalar)
{
+ using var opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0));
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
@@ -270,32 +268,39 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
- results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle);
+ var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
- }
-
- Operation scalar2 = graph.OperationByName("imported2/scalar");
- Operation feed2 = graph.OperationByName("imported2/feed");
- Operation neg2 = graph.OperationByName("imported2/neg");
-
- // Check input mapping
- neg_input = neg.Input(0);
- EXPECT_EQ(scalar, neg_input.oper);
- EXPECT_EQ(0, neg_input.index);
- // Check return outputs
- var return_outputs = graph.ReturnOutputs(results);
- ASSERT_EQ(2, return_outputs.Length);
- EXPECT_EQ(feed2, return_outputs[0].oper);
- EXPECT_EQ(0, return_outputs[0].index);
- EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
- EXPECT_EQ(0, return_outputs[1].index);
+ return results;
+ }
- // Check return operation
- var return_opers = graph.ReturnOperations(results);
- ASSERT_EQ(1, return_opers.Length);
- EXPECT_EQ(scalar2, return_opers[0]); // not remapped
- c_api.TF_DeleteImportGraphDefResults(results);
+ // Import it again, with an input mapping, return outputs, and a return
+ // operation, into the same graph.
+ Operation feed2;
+ using (SafeImportGraphDefResultsHandle results = ImportGraph(s, graph, graph_def, scalar))
+ {
+ Operation scalar2 = graph.OperationByName("imported2/scalar");
+ feed2 = graph.OperationByName("imported2/feed");
+ Operation neg2 = graph.OperationByName("imported2/neg");
+
+ // Check input mapping
+ neg_input = neg.Input(0);
+ EXPECT_EQ(scalar, neg_input.oper);
+ EXPECT_EQ(0, neg_input.index);
+
+ // Check return outputs
+ var return_outputs = graph.ReturnOutputs(results);
+ ASSERT_EQ(2, return_outputs.Length);
+ EXPECT_EQ(feed2, return_outputs[0].oper);
+ EXPECT_EQ(0, return_outputs[0].index);
+ EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
+ EXPECT_EQ(0, return_outputs[1].index);
+
+ // Check return operation
+ var return_opers = graph.ReturnOperations(results);
+ ASSERT_EQ(1, return_opers.Length);
+ EXPECT_EQ(scalar2, return_opers[0]); // not remapped
+ }
// Import again, with control dependencies, into the same graph.
using (var opts = c_api.TF_NewImportGraphDefOptions())