diff --git a/src/TensorFlowNET.Core/Framework/importer.cs b/src/TensorFlowNET.Core/Framework/importer.cs index 8dd30858..c28579be 100644 --- a/src/TensorFlowNET.Core/Framework/importer.cs +++ b/src/TensorFlowNET.Core/Framework/importer.cs @@ -62,7 +62,7 @@ namespace Tensorflow { _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); // need to create a class ImportGraphDefWithResults with IDisposal - results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle); + results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle)); status.Check(true); } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index d7f2ce2a..d17f6591 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -33,7 +33,7 @@ namespace Tensorflow return c_api.TF_NewOperation(_handle, opType, opName); } - public Operation[] ReturnOperations(IntPtr results) + public Operation[] ReturnOperations(SafeImportGraphDefResultsHandle results) { TF_Operation return_oper_handle = new TF_Operation(); int num_return_opers = 0; diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 69192634..cfa782ea 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -413,7 +413,7 @@ namespace Tensorflow return name; } - public TF_Output[] ReturnOutputs(IntPtr results) + public TF_Output[] ReturnOutputs(SafeImportGraphDefResultsHandle results) { IntPtr return_output_handle = IntPtr.Zero; int num_return_outputs = 0; diff --git a/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs b/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs new file mode 100644 index 00000000..8a84eff8 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow +{ + public sealed class SafeImportGraphDefResultsHandle : SafeTensorflowHandle + { + private SafeImportGraphDefResultsHandle() + { + } + + public SafeImportGraphDefResultsHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteImportGraphDefResults(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs b/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs index 71ea5306..eff8be94 100644 --- a/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs +++ b/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs @@ -1,18 +1,35 @@ -using System; -using System.Runtime.InteropServices; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; namespace Tensorflow { - public class TF_ImportGraphDefResults : DisposableObject + public sealed class TF_ImportGraphDefResults : IDisposable { /*public IntPtr return_nodes; public IntPtr missing_unused_key_names; public IntPtr missing_unused_key_indexes; public IntPtr missing_unused_key_names_data;*/ - public TF_ImportGraphDefResults(IntPtr handle) + private SafeImportGraphDefResultsHandle Handle { get; } + + public TF_ImportGraphDefResults(SafeImportGraphDefResultsHandle handle) { - _handle = handle; + Handle = handle; } public TF_Output[] return_tensors @@ -21,7 +38,7 @@ namespace Tensorflow { IntPtr return_output_handle = IntPtr.Zero; int num_outputs = -1; - c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle); + c_api.TF_ImportGraphDefResultsReturnOutputs(Handle, ref num_outputs, ref return_output_handle); TF_Output[] return_outputs = new TF_Output[num_outputs]; unsafe { @@ -52,13 +69,7 @@ namespace Tensorflow } } - public static implicit operator TF_ImportGraphDefResults(IntPtr handle) - => new TF_ImportGraphDefResults(handle); - - public static implicit operator IntPtr(TF_ImportGraphDefResults results) - => results._handle; - - protected override void DisposeUnmanagedResources(IntPtr handle) - => c_api.TF_DeleteImportGraphDefResults(handle); + public void Dispose() + => Handle.Dispose(); } } diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 94d6d63e..74e21f03 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -92,7 +92,7 @@ namespace Tensorflow /// TF_Status* /// TF_ImportGraphDefResults* [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); + public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); /// /// Import the graph serialized in `graph_def` into `graph`. @@ -258,7 +258,7 @@ namespace Tensorflow /// int* /// TF_Operation*** [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref TF_Operation opers); + public static extern void TF_ImportGraphDefResultsReturnOperations(SafeImportGraphDefResultsHandle results, ref int num_opers, ref TF_Operation opers); /// /// Fetches the return outputs requested via @@ -270,7 +270,7 @@ namespace Tensorflow /// int* /// TF_Output** [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); + public static extern void TF_ImportGraphDefResultsReturnOutputs(SafeImportGraphDefResultsHandle results, ref int num_outputs, ref IntPtr outputs); /// /// This function creates a new TF_Session (which is created on success) using diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 868fd1b8..5bd72a2b 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -258,11 +258,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI EXPECT_EQ(0, neg.NumControlOutputs); EXPECT_EQ(0, neg.GetControlOutputs().Length); - // Import it again, with an input mapping, return outputs, and a return - // operation, into the same graph. - IntPtr results; - using (var opts = c_api.TF_NewImportGraphDefOptions()) + static SafeImportGraphDefResultsHandle ImportGraph(Status s, Graph graph, Buffer graph_def, Operation scalar) { + using var opts = c_api.TF_NewImportGraphDefOptions(); c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); @@ -270,32 +268,39 @@ namespace TensorFlowNET.UnitTest.NativeAPI EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); - results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle); + var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle); EXPECT_EQ(TF_Code.TF_OK, s.Code); - } - - Operation scalar2 = graph.OperationByName("imported2/scalar"); - Operation feed2 = graph.OperationByName("imported2/feed"); - Operation neg2 = graph.OperationByName("imported2/neg"); - - // Check input mapping - neg_input = neg.Input(0); - EXPECT_EQ(scalar, neg_input.oper); - EXPECT_EQ(0, neg_input.index); - // Check return outputs - var return_outputs = graph.ReturnOutputs(results); - ASSERT_EQ(2, return_outputs.Length); - EXPECT_EQ(feed2, return_outputs[0].oper); - EXPECT_EQ(0, return_outputs[0].index); - EXPECT_EQ(scalar, return_outputs[1].oper); // remapped - EXPECT_EQ(0, return_outputs[1].index); + return results; + } - // Check return operation - var return_opers = graph.ReturnOperations(results); - ASSERT_EQ(1, return_opers.Length); - EXPECT_EQ(scalar2, return_opers[0]); // not remapped - c_api.TF_DeleteImportGraphDefResults(results); + // Import it again, with an input mapping, return outputs, and a return + // operation, into the same graph. + Operation feed2; + using (SafeImportGraphDefResultsHandle results = ImportGraph(s, graph, graph_def, scalar)) + { + Operation scalar2 = graph.OperationByName("imported2/scalar"); + feed2 = graph.OperationByName("imported2/feed"); + Operation neg2 = graph.OperationByName("imported2/neg"); + + // Check input mapping + neg_input = neg.Input(0); + EXPECT_EQ(scalar, neg_input.oper); + EXPECT_EQ(0, neg_input.index); + + // Check return outputs + var return_outputs = graph.ReturnOutputs(results); + ASSERT_EQ(2, return_outputs.Length); + EXPECT_EQ(feed2, return_outputs[0].oper); + EXPECT_EQ(0, return_outputs[0].index); + EXPECT_EQ(scalar, return_outputs[1].oper); // remapped + EXPECT_EQ(0, return_outputs[1].index); + + // Check return operation + var return_opers = graph.ReturnOperations(results); + ASSERT_EQ(1, return_opers.Length); + EXPECT_EQ(scalar2, return_opers[0]); // not remapped + } // Import again, with control dependencies, into the same graph. using (var opts = c_api.TF_NewImportGraphDefOptions())