| @@ -1,10 +0,0 @@ | |||
| namespace Tensorflow.Framework.Models | |||
| { | |||
| public class ScopedTFImportGraphDefOptions : ImportGraphDefOptions | |||
| { | |||
| public ScopedTFImportGraphDefOptions() : base() | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -1,17 +0,0 @@ | |||
| using System; | |||
| namespace Tensorflow.Framework.Models | |||
| { | |||
| public class ScopedTFImportGraphDefResults : ImportGraphDefOptions | |||
| { | |||
| public ScopedTFImportGraphDefResults() : base() | |||
| { | |||
| } | |||
| public ScopedTFImportGraphDefResults(IntPtr results) : base(results) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -62,7 +62,7 @@ namespace Tensorflow | |||
| { | |||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
| // need to create a class ImportGraphDefWithResults with IDisposal | |||
| results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status.Handle); | |||
| results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options.Handle, status.Handle); | |||
| status.Check(true); | |||
| } | |||
| @@ -114,8 +114,8 @@ namespace Tensorflow | |||
| Dictionary<string, Tensor> input_map, | |||
| string[] return_elements) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | |||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix); | |||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1); | |||
| foreach(var input in input_map) | |||
| { | |||
| @@ -130,11 +130,11 @@ namespace Tensorflow | |||
| if(name.Contains(":")) | |||
| { | |||
| var (op_name, index) = _ParseTensorName(name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index); | |||
| } | |||
| else | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name); | |||
| } | |||
| } | |||
| @@ -29,7 +29,7 @@ namespace Tensorflow | |||
| int size = Marshal.SizeOf<TF_Output>(); | |||
| var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | |||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s.Handle); | |||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts.Handle, return_output_handle, num_return_outputs, s.Handle); | |||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| @@ -53,8 +53,8 @@ namespace Tensorflow | |||
| using (var graph_def = new Buffer(bytes)) | |||
| { | |||
| as_default(); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); | |||
| c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status.Handle); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix); | |||
| c_api.TF_GraphImportGraphDef(_handle, graph_def, opts.Handle, status.Handle); | |||
| status.Check(true); | |||
| return status.Code == TF_Code.TF_OK; | |||
| } | |||
| @@ -18,30 +18,24 @@ using System; | |||
| namespace Tensorflow | |||
| { | |||
| public class ImportGraphDefOptions : DisposableObject | |||
| public sealed class ImportGraphDefOptions : IDisposable | |||
| { | |||
| public SafeImportGraphDefOptionsHandle Handle { get; } | |||
| public int NumReturnOutputs | |||
| => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||
| => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(Handle); | |||
| public ImportGraphDefOptions() | |||
| { | |||
| _handle = c_api.TF_NewImportGraphDefOptions(); | |||
| } | |||
| public ImportGraphDefOptions(IntPtr handle) | |||
| { | |||
| _handle = handle; | |||
| Handle = c_api.TF_NewImportGraphDefOptions(); | |||
| } | |||
| public void AddReturnOutput(string name, int index) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(Handle, name, index); | |||
| } | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TF_DeleteImportGraphDefOptions(handle); | |||
| public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | |||
| public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle); | |||
| public void Dispose() | |||
| => Handle.Dispose(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,40 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| { | |||
| public sealed class SafeImportGraphDefOptionsHandle : SafeTensorflowHandle | |||
| { | |||
| public SafeImportGraphDefOptionsHandle() | |||
| { | |||
| } | |||
| public SafeImportGraphDefOptionsHandle(IntPtr handle) | |||
| : base(handle) | |||
| { | |||
| } | |||
| protected override bool ReleaseHandle() | |||
| { | |||
| c_api.TF_DeleteImportGraphDefOptions(handle); | |||
| SetHandle(IntPtr.Zero); | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| @@ -78,7 +78,7 @@ namespace Tensorflow | |||
| /// <param name="num_return_outputs">int</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); | |||
| public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | |||
| @@ -92,7 +92,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns>TF_ImportGraphDefResults*</returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, IntPtr options, SafeStatusHandle status); | |||
| public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Import the graph serialized in `graph_def` into `graph`. | |||
| @@ -102,7 +102,7 @@ namespace Tensorflow | |||
| /// <param name="options">TF_ImportGraphDefOptions*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, SafeStatusHandle status); | |||
| public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Iterate through the operations of a graph. | |||
| @@ -160,7 +160,7 @@ namespace Tensorflow | |||
| /// <param name="opts"></param> | |||
| /// <param name="oper"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsAddControlDependency(IntPtr opts, IntPtr oper); | |||
| public static extern void TF_ImportGraphDefOptionsAddControlDependency(SafeImportGraphDefOptionsHandle opts, IntPtr oper); | |||
| /// <summary> | |||
| /// Set any imported nodes with input `src_name:src_index` to have that input | |||
| @@ -173,7 +173,7 @@ namespace Tensorflow | |||
| /// <param name="src_index">int</param> | |||
| /// <param name="dst">TF_Output</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsAddInputMapping(IntPtr opts, string src_name, int src_index, TF_Output dst); | |||
| public static extern void TF_ImportGraphDefOptionsAddInputMapping(SafeImportGraphDefOptionsHandle opts, string src_name, int src_index, TF_Output dst); | |||
| /// <summary> | |||
| /// Add an operation in `graph_def` to be returned via the `return_opers` output | |||
| @@ -183,7 +183,7 @@ namespace Tensorflow | |||
| /// <param name="opts">TF_ImportGraphDefOptions* opts</param> | |||
| /// <param name="oper_name">const char*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsAddReturnOperation(IntPtr opts, string oper_name); | |||
| public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); | |||
| /// <summary> | |||
| /// Add an output in `graph_def` to be returned via the `return_outputs` output | |||
| @@ -195,7 +195,7 @@ namespace Tensorflow | |||
| /// <param name="oper_name">const char*</param> | |||
| /// <param name="index">int</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsAddReturnOutput(IntPtr opts, string oper_name, int index); | |||
| public static extern void TF_ImportGraphDefOptionsAddReturnOutput(SafeImportGraphDefOptionsHandle opts, string oper_name, int index); | |||
| /// <summary> | |||
| /// Returns the number of return operations added via | |||
| @@ -204,7 +204,7 @@ namespace Tensorflow | |||
| /// <param name="opts"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_ImportGraphDefOptionsNumReturnOperations(IntPtr opts); | |||
| public static extern int TF_ImportGraphDefOptionsNumReturnOperations(SafeImportGraphDefOptionsHandle opts); | |||
| /// <summary> | |||
| /// Returns the number of return outputs added via | |||
| @@ -213,7 +213,7 @@ namespace Tensorflow | |||
| /// <param name="opts">const TF_ImportGraphDefOptions*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts); | |||
| public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(SafeImportGraphDefOptionsHandle opts); | |||
| /// <summary> | |||
| /// Set any imported nodes with control input `src_name` to have that input | |||
| @@ -225,7 +225,7 @@ namespace Tensorflow | |||
| /// <param name="src_name">const char*</param> | |||
| /// <param name="dst">TF_Operation*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsRemapControlDependency(IntPtr opts, string src_name, IntPtr dst); | |||
| public static extern void TF_ImportGraphDefOptionsRemapControlDependency(SafeImportGraphDefOptionsHandle opts, string src_name, IntPtr dst); | |||
| /// <summary> | |||
| /// Set the prefix to be prepended to the names of nodes in `graph_def` that will | |||
| @@ -234,7 +234,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="ops"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix); | |||
| public static extern void TF_ImportGraphDefOptionsSetPrefix(SafeImportGraphDefOptionsHandle ops, string prefix); | |||
| /// <summary> | |||
| /// Set whether to uniquify imported operation names. If true, imported operation | |||
| @@ -246,7 +246,7 @@ namespace Tensorflow | |||
| /// <param name="ops">TF_ImportGraphDefOptions*</param> | |||
| /// <param name="uniquify_prefix">unsigned char</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(IntPtr ops, char uniquify_prefix); | |||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix); | |||
| /// <summary> | |||
| /// Fetches the return operations requested via | |||
| @@ -295,7 +295,7 @@ namespace Tensorflow | |||
| public static extern IntPtr TF_NewGraph(); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_NewImportGraphDefOptions(); | |||
| public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions(); | |||
| /// <summary> | |||
| /// Updates 'dst' to consume 'new_src'. | |||
| @@ -222,10 +222,12 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| // Import it, with a prefix, in a fresh graph. | |||
| graph.Dispose(); | |||
| graph = new Graph().as_default(); | |||
| var opts = c_api.TF_NewImportGraphDefOptions(); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| } | |||
| Operation scalar = graph.OperationByName("imported/scalar"); | |||
| Operation feed = graph.OperationByName("imported/feed"); | |||
| @@ -258,17 +260,19 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| // Import it again, with an input mapping, return outputs, and a return | |||
| // operation, into the same graph. | |||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||
| opts = c_api.TF_NewImportGraphDefOptions(); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); | |||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); | |||
| EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | |||
| EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||
| var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s.Handle); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| IntPtr results; | |||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); | |||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); | |||
| EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | |||
| EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||
| results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s.Handle); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| } | |||
| Operation scalar2 = graph.OperationByName("imported2/scalar"); | |||
| Operation feed2 = graph.OperationByName("imported2/feed"); | |||
| @@ -294,13 +298,14 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| c_api.TF_DeleteImportGraphDefResults(results); | |||
| // Import again, with control dependencies, into the same graph. | |||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||
| opts = c_api.TF_NewImportGraphDefOptions(); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); | |||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); | |||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); | |||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); | |||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| } | |||
| var scalar3 = graph.OperationByName("imported3/scalar"); | |||
| var feed3 = graph.OperationByName("imported3/feed"); | |||
| @@ -327,12 +332,13 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Import again, with remapped control dependency, into the same graph | |||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||
| opts = c_api.TF_NewImportGraphDefOptions(); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); | |||
| c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); | |||
| c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| } | |||
| var scalar4 = graph.OperationByName("imported4/imported3/scalar"); | |||
| var feed4 = graph.OperationByName("imported4/imported2/feed"); | |||
| @@ -344,8 +350,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| EXPECT_EQ(feed, control_inputs[0]); | |||
| EXPECT_EQ(feed4, control_inputs[1]); | |||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||
| // Can add nodes to the imported graph without trouble. | |||
| c_test_util.Add(feed, scalar, graph, s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||