| @@ -62,7 +62,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | ||||
| // need to create a class ImportGraphDefWithResults with IDisposal | // need to create a class ImportGraphDefWithResults with IDisposal | ||||
| results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle); | |||||
| results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle)); | |||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| @@ -33,7 +33,7 @@ namespace Tensorflow | |||||
| return c_api.TF_NewOperation(_handle, opType, opName); | return c_api.TF_NewOperation(_handle, opType, opName); | ||||
| } | } | ||||
| public Operation[] ReturnOperations(IntPtr results) | |||||
| public Operation[] ReturnOperations(SafeImportGraphDefResultsHandle results) | |||||
| { | { | ||||
| TF_Operation return_oper_handle = new TF_Operation(); | TF_Operation return_oper_handle = new TF_Operation(); | ||||
| int num_return_opers = 0; | int num_return_opers = 0; | ||||
| @@ -413,7 +413,7 @@ namespace Tensorflow | |||||
| return name; | return name; | ||||
| } | } | ||||
| public TF_Output[] ReturnOutputs(IntPtr results) | |||||
| public TF_Output[] ReturnOutputs(SafeImportGraphDefResultsHandle results) | |||||
| { | { | ||||
| IntPtr return_output_handle = IntPtr.Zero; | IntPtr return_output_handle = IntPtr.Zero; | ||||
| int num_return_outputs = 0; | int num_return_outputs = 0; | ||||
| @@ -0,0 +1,40 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public sealed class SafeImportGraphDefResultsHandle : SafeTensorflowHandle | |||||
| { | |||||
| private SafeImportGraphDefResultsHandle() | |||||
| { | |||||
| } | |||||
| public SafeImportGraphDefResultsHandle(IntPtr handle) | |||||
| : base(handle) | |||||
| { | |||||
| } | |||||
| protected override bool ReleaseHandle() | |||||
| { | |||||
| c_api.TF_DeleteImportGraphDefResults(handle); | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,18 +1,35 @@ | |||||
| using System; | |||||
| using System.Runtime.InteropServices; | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class TF_ImportGraphDefResults : DisposableObject | |||||
| public sealed class TF_ImportGraphDefResults : IDisposable | |||||
| { | { | ||||
| /*public IntPtr return_nodes; | /*public IntPtr return_nodes; | ||||
| public IntPtr missing_unused_key_names; | public IntPtr missing_unused_key_names; | ||||
| public IntPtr missing_unused_key_indexes; | public IntPtr missing_unused_key_indexes; | ||||
| public IntPtr missing_unused_key_names_data;*/ | public IntPtr missing_unused_key_names_data;*/ | ||||
| public TF_ImportGraphDefResults(IntPtr handle) | |||||
| private SafeImportGraphDefResultsHandle Handle { get; } | |||||
| public TF_ImportGraphDefResults(SafeImportGraphDefResultsHandle handle) | |||||
| { | { | ||||
| _handle = handle; | |||||
| Handle = handle; | |||||
| } | } | ||||
| public TF_Output[] return_tensors | public TF_Output[] return_tensors | ||||
| @@ -21,7 +38,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| IntPtr return_output_handle = IntPtr.Zero; | IntPtr return_output_handle = IntPtr.Zero; | ||||
| int num_outputs = -1; | int num_outputs = -1; | ||||
| c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle); | |||||
| c_api.TF_ImportGraphDefResultsReturnOutputs(Handle, ref num_outputs, ref return_output_handle); | |||||
| TF_Output[] return_outputs = new TF_Output[num_outputs]; | TF_Output[] return_outputs = new TF_Output[num_outputs]; | ||||
| unsafe | unsafe | ||||
| { | { | ||||
| @@ -52,13 +69,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public static implicit operator TF_ImportGraphDefResults(IntPtr handle) | |||||
| => new TF_ImportGraphDefResults(handle); | |||||
| public static implicit operator IntPtr(TF_ImportGraphDefResults results) | |||||
| => results._handle; | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TF_DeleteImportGraphDefResults(handle); | |||||
| public void Dispose() | |||||
| => Handle.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -92,7 +92,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns>TF_ImportGraphDefResults*</returns> | /// <returns>TF_ImportGraphDefResults*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||||
| public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Import the graph serialized in `graph_def` into `graph`. | /// Import the graph serialized in `graph_def` into `graph`. | ||||
| @@ -258,7 +258,7 @@ namespace Tensorflow | |||||
| /// <param name="num_opers">int*</param> | /// <param name="num_opers">int*</param> | ||||
| /// <param name="opers">TF_Operation***</param> | /// <param name="opers">TF_Operation***</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref TF_Operation opers); | |||||
| public static extern void TF_ImportGraphDefResultsReturnOperations(SafeImportGraphDefResultsHandle results, ref int num_opers, ref TF_Operation opers); | |||||
| /// <summary> | /// <summary> | ||||
| /// Fetches the return outputs requested via | /// Fetches the return outputs requested via | ||||
| @@ -270,7 +270,7 @@ namespace Tensorflow | |||||
| /// <param name="num_outputs">int*</param> | /// <param name="num_outputs">int*</param> | ||||
| /// <param name="outputs">TF_Output**</param> | /// <param name="outputs">TF_Output**</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); | |||||
| public static extern void TF_ImportGraphDefResultsReturnOutputs(SafeImportGraphDefResultsHandle results, ref int num_outputs, ref IntPtr outputs); | |||||
| /// <summary> | /// <summary> | ||||
| /// This function creates a new TF_Session (which is created on success) using | /// This function creates a new TF_Session (which is created on success) using | ||||
| @@ -258,11 +258,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| EXPECT_EQ(0, neg.NumControlOutputs); | EXPECT_EQ(0, neg.NumControlOutputs); | ||||
| EXPECT_EQ(0, neg.GetControlOutputs().Length); | EXPECT_EQ(0, neg.GetControlOutputs().Length); | ||||
| // Import it again, with an input mapping, return outputs, and a return | |||||
| // operation, into the same graph. | |||||
| IntPtr results; | |||||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||||
| static SafeImportGraphDefResultsHandle ImportGraph(Status s, Graph graph, Buffer graph_def, Operation scalar) | |||||
| { | { | ||||
| using var opts = c_api.TF_NewImportGraphDefOptions(); | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); | c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); | ||||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); | c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); | ||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); | c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); | ||||
| @@ -270,32 +268,39 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | ||||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | ||||
| EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | ||||
| results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle); | |||||
| var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | EXPECT_EQ(TF_Code.TF_OK, s.Code); | ||||
| } | |||||
| Operation scalar2 = graph.OperationByName("imported2/scalar"); | |||||
| Operation feed2 = graph.OperationByName("imported2/feed"); | |||||
| Operation neg2 = graph.OperationByName("imported2/neg"); | |||||
| // Check input mapping | |||||
| neg_input = neg.Input(0); | |||||
| EXPECT_EQ(scalar, neg_input.oper); | |||||
| EXPECT_EQ(0, neg_input.index); | |||||
| // Check return outputs | |||||
| var return_outputs = graph.ReturnOutputs(results); | |||||
| ASSERT_EQ(2, return_outputs.Length); | |||||
| EXPECT_EQ(feed2, return_outputs[0].oper); | |||||
| EXPECT_EQ(0, return_outputs[0].index); | |||||
| EXPECT_EQ(scalar, return_outputs[1].oper); // remapped | |||||
| EXPECT_EQ(0, return_outputs[1].index); | |||||
| return results; | |||||
| } | |||||
| // Check return operation | |||||
| var return_opers = graph.ReturnOperations(results); | |||||
| ASSERT_EQ(1, return_opers.Length); | |||||
| EXPECT_EQ(scalar2, return_opers[0]); // not remapped | |||||
| c_api.TF_DeleteImportGraphDefResults(results); | |||||
| // Import it again, with an input mapping, return outputs, and a return | |||||
| // operation, into the same graph. | |||||
| Operation feed2; | |||||
| using (SafeImportGraphDefResultsHandle results = ImportGraph(s, graph, graph_def, scalar)) | |||||
| { | |||||
| Operation scalar2 = graph.OperationByName("imported2/scalar"); | |||||
| feed2 = graph.OperationByName("imported2/feed"); | |||||
| Operation neg2 = graph.OperationByName("imported2/neg"); | |||||
| // Check input mapping | |||||
| neg_input = neg.Input(0); | |||||
| EXPECT_EQ(scalar, neg_input.oper); | |||||
| EXPECT_EQ(0, neg_input.index); | |||||
| // Check return outputs | |||||
| var return_outputs = graph.ReturnOutputs(results); | |||||
| ASSERT_EQ(2, return_outputs.Length); | |||||
| EXPECT_EQ(feed2, return_outputs[0].oper); | |||||
| EXPECT_EQ(0, return_outputs[0].index); | |||||
| EXPECT_EQ(scalar, return_outputs[1].oper); // remapped | |||||
| EXPECT_EQ(0, return_outputs[1].index); | |||||
| // Check return operation | |||||
| var return_opers = graph.ReturnOperations(results); | |||||
| ASSERT_EQ(1, return_opers.Length); | |||||
| EXPECT_EQ(scalar2, return_opers[0]); // not remapped | |||||
| } | |||||
| // Import again, with control dependencies, into the same graph. | // Import again, with control dependencies, into the same graph. | ||||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | using (var opts = c_api.TF_NewImportGraphDefOptions()) | ||||