| @@ -1,10 +0,0 @@ | |||||
| namespace Tensorflow.Framework.Models | |||||
| { | |||||
| public class ScopedTFImportGraphDefOptions : ImportGraphDefOptions | |||||
| { | |||||
| public ScopedTFImportGraphDefOptions() : base() | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,17 +0,0 @@ | |||||
| using System; | |||||
| namespace Tensorflow.Framework.Models | |||||
| { | |||||
| public class ScopedTFImportGraphDefResults : ImportGraphDefOptions | |||||
| { | |||||
| public ScopedTFImportGraphDefResults() : base() | |||||
| { | |||||
| } | |||||
| public ScopedTFImportGraphDefResults(IntPtr results) : base(results) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -62,7 +62,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | ||||
| // need to create a class ImportGraphDefWithResults with IDisposal | // need to create a class ImportGraphDefWithResults with IDisposal | ||||
| results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status.Handle); | |||||
| results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options.Handle, status.Handle); | |||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| @@ -114,8 +114,8 @@ namespace Tensorflow | |||||
| Dictionary<string, Tensor> input_map, | Dictionary<string, Tensor> input_map, | ||||
| string[] return_elements) | string[] return_elements) | ||||
| { | { | ||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | |||||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix); | |||||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1); | |||||
| foreach(var input in input_map) | foreach(var input in input_map) | ||||
| { | { | ||||
| @@ -130,11 +130,11 @@ namespace Tensorflow | |||||
| if(name.Contains(":")) | if(name.Contains(":")) | ||||
| { | { | ||||
| var (op_name, index) = _ParseTensorName(name); | var (op_name, index) = _ParseTensorName(name); | ||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -29,7 +29,7 @@ namespace Tensorflow | |||||
| int size = Marshal.SizeOf<TF_Output>(); | int size = Marshal.SizeOf<TF_Output>(); | ||||
| var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | ||||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s.Handle); | |||||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts.Handle, return_output_handle, num_return_outputs, s.Handle); | |||||
| var tf_output_ptr = (TF_Output*) return_output_handle; | var tf_output_ptr = (TF_Output*) return_output_handle; | ||||
| for (int i = 0; i < num_return_outputs; i++) | for (int i = 0; i < num_return_outputs; i++) | ||||
| @@ -53,8 +53,8 @@ namespace Tensorflow | |||||
| using (var graph_def = new Buffer(bytes)) | using (var graph_def = new Buffer(bytes)) | ||||
| { | { | ||||
| as_default(); | as_default(); | ||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); | |||||
| c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status.Handle); | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix); | |||||
| c_api.TF_GraphImportGraphDef(_handle, graph_def, opts.Handle, status.Handle); | |||||
| status.Check(true); | status.Check(true); | ||||
| return status.Code == TF_Code.TF_OK; | return status.Code == TF_Code.TF_OK; | ||||
| } | } | ||||
| @@ -18,30 +18,24 @@ using System; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class ImportGraphDefOptions : DisposableObject | |||||
| public sealed class ImportGraphDefOptions : IDisposable | |||||
| { | { | ||||
| public SafeImportGraphDefOptionsHandle Handle { get; } | |||||
| public int NumReturnOutputs | public int NumReturnOutputs | ||||
| => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||||
| => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(Handle); | |||||
| public ImportGraphDefOptions() | public ImportGraphDefOptions() | ||||
| { | { | ||||
| _handle = c_api.TF_NewImportGraphDefOptions(); | |||||
| } | |||||
| public ImportGraphDefOptions(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| Handle = c_api.TF_NewImportGraphDefOptions(); | |||||
| } | } | ||||
| public void AddReturnOutput(string name, int index) | public void AddReturnOutput(string name, int index) | ||||
| { | { | ||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(Handle, name, index); | |||||
| } | } | ||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TF_DeleteImportGraphDefOptions(handle); | |||||
| public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | |||||
| public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle); | |||||
| public void Dispose() | |||||
| => Handle.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,40 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public sealed class SafeImportGraphDefOptionsHandle : SafeTensorflowHandle | |||||
| { | |||||
| public SafeImportGraphDefOptionsHandle() | |||||
| { | |||||
| } | |||||
| public SafeImportGraphDefOptionsHandle(IntPtr handle) | |||||
| : base(handle) | |||||
| { | |||||
| } | |||||
| protected override bool ReleaseHandle() | |||||
| { | |||||
| c_api.TF_DeleteImportGraphDefOptions(handle); | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -78,7 +78,7 @@ namespace Tensorflow | |||||
| /// <param name="num_return_outputs">int</param> | /// <param name="num_return_outputs">int</param> | ||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); | |||||
| public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | ||||
| @@ -92,7 +92,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns>TF_ImportGraphDefResults*</returns> | /// <returns>TF_ImportGraphDefResults*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, IntPtr options, SafeStatusHandle status); | |||||
| public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Import the graph serialized in `graph_def` into `graph`. | /// Import the graph serialized in `graph_def` into `graph`. | ||||
| @@ -102,7 +102,7 @@ namespace Tensorflow | |||||
| /// <param name="options">TF_ImportGraphDefOptions*</param> | /// <param name="options">TF_ImportGraphDefOptions*</param> | ||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, SafeStatusHandle status); | |||||
| public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Iterate through the operations of a graph. | /// Iterate through the operations of a graph. | ||||
| @@ -160,7 +160,7 @@ namespace Tensorflow | |||||
| /// <param name="opts"></param> | /// <param name="opts"></param> | ||||
| /// <param name="oper"></param> | /// <param name="oper"></param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefOptionsAddControlDependency(IntPtr opts, IntPtr oper); | |||||
| public static extern void TF_ImportGraphDefOptionsAddControlDependency(SafeImportGraphDefOptionsHandle opts, IntPtr oper); | |||||
| /// <summary> | /// <summary> | ||||
| /// Set any imported nodes with input `src_name:src_index` to have that input | /// Set any imported nodes with input `src_name:src_index` to have that input | ||||
| @@ -173,7 +173,7 @@ namespace Tensorflow | |||||
| /// <param name="src_index">int</param> | /// <param name="src_index">int</param> | ||||
| /// <param name="dst">TF_Output</param> | /// <param name="dst">TF_Output</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefOptionsAddInputMapping(IntPtr opts, string src_name, int src_index, TF_Output dst); | |||||
| public static extern void TF_ImportGraphDefOptionsAddInputMapping(SafeImportGraphDefOptionsHandle opts, string src_name, int src_index, TF_Output dst); | |||||
| /// <summary> | /// <summary> | ||||
| /// Add an operation in `graph_def` to be returned via the `return_opers` output | /// Add an operation in `graph_def` to be returned via the `return_opers` output | ||||
| @@ -183,7 +183,7 @@ namespace Tensorflow | |||||
| /// <param name="opts">TF_ImportGraphDefOptions* opts</param> | /// <param name="opts">TF_ImportGraphDefOptions* opts</param> | ||||
| /// <param name="oper_name">const char*</param> | /// <param name="oper_name">const char*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefOptionsAddReturnOperation(IntPtr opts, string oper_name); | |||||
| public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); | |||||
| /// <summary> | /// <summary> | ||||
| /// Add an output in `graph_def` to be returned via the `return_outputs` output | /// Add an output in `graph_def` to be returned via the `return_outputs` output | ||||
| @@ -195,7 +195,7 @@ namespace Tensorflow | |||||
| /// <param name="oper_name">const char*</param> | /// <param name="oper_name">const char*</param> | ||||
| /// <param name="index">int</param> | /// <param name="index">int</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefOptionsAddReturnOutput(IntPtr opts, string oper_name, int index); | |||||
| public static extern void TF_ImportGraphDefOptionsAddReturnOutput(SafeImportGraphDefOptionsHandle opts, string oper_name, int index); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the number of return operations added via | /// Returns the number of return operations added via | ||||
| @@ -204,7 +204,7 @@ namespace Tensorflow | |||||
| /// <param name="opts"></param> | /// <param name="opts"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TF_ImportGraphDefOptionsNumReturnOperations(IntPtr opts); | |||||
| public static extern int TF_ImportGraphDefOptionsNumReturnOperations(SafeImportGraphDefOptionsHandle opts); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the number of return outputs added via | /// Returns the number of return outputs added via | ||||
| @@ -213,7 +213,7 @@ namespace Tensorflow | |||||
| /// <param name="opts">const TF_ImportGraphDefOptions*</param> | /// <param name="opts">const TF_ImportGraphDefOptions*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts); | |||||
| public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(SafeImportGraphDefOptionsHandle opts); | |||||
| /// <summary> | /// <summary> | ||||
| /// Set any imported nodes with control input `src_name` to have that input | /// Set any imported nodes with control input `src_name` to have that input | ||||
| @@ -225,7 +225,7 @@ namespace Tensorflow | |||||
| /// <param name="src_name">const char*</param> | /// <param name="src_name">const char*</param> | ||||
| /// <param name="dst">TF_Operation*</param> | /// <param name="dst">TF_Operation*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefOptionsRemapControlDependency(IntPtr opts, string src_name, IntPtr dst); | |||||
| public static extern void TF_ImportGraphDefOptionsRemapControlDependency(SafeImportGraphDefOptionsHandle opts, string src_name, IntPtr dst); | |||||
| /// <summary> | /// <summary> | ||||
| /// Set the prefix to be prepended to the names of nodes in `graph_def` that will | /// Set the prefix to be prepended to the names of nodes in `graph_def` that will | ||||
| @@ -234,7 +234,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ops"></param> | /// <param name="ops"></param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix); | |||||
| public static extern void TF_ImportGraphDefOptionsSetPrefix(SafeImportGraphDefOptionsHandle ops, string prefix); | |||||
| /// <summary> | /// <summary> | ||||
| /// Set whether to uniquify imported operation names. If true, imported operation | /// Set whether to uniquify imported operation names. If true, imported operation | ||||
| @@ -246,7 +246,7 @@ namespace Tensorflow | |||||
| /// <param name="ops">TF_ImportGraphDefOptions*</param> | /// <param name="ops">TF_ImportGraphDefOptions*</param> | ||||
| /// <param name="uniquify_prefix">unsigned char</param> | /// <param name="uniquify_prefix">unsigned char</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(IntPtr ops, char uniquify_prefix); | |||||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix); | |||||
| /// <summary> | /// <summary> | ||||
| /// Fetches the return operations requested via | /// Fetches the return operations requested via | ||||
| @@ -295,7 +295,7 @@ namespace Tensorflow | |||||
| public static extern IntPtr TF_NewGraph(); | public static extern IntPtr TF_NewGraph(); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_NewImportGraphDefOptions(); | |||||
| public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Updates 'dst' to consume 'new_src'. | /// Updates 'dst' to consume 'new_src'. | ||||
| @@ -222,10 +222,12 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| // Import it, with a prefix, in a fresh graph. | // Import it, with a prefix, in a fresh graph. | ||||
| graph.Dispose(); | graph.Dispose(); | ||||
| graph = new Graph().as_default(); | graph = new Graph().as_default(); | ||||
| var opts = c_api.TF_NewImportGraphDefOptions(); | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | |||||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||||
| { | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | |||||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| } | |||||
| Operation scalar = graph.OperationByName("imported/scalar"); | Operation scalar = graph.OperationByName("imported/scalar"); | ||||
| Operation feed = graph.OperationByName("imported/feed"); | Operation feed = graph.OperationByName("imported/feed"); | ||||
| @@ -258,17 +260,19 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| // Import it again, with an input mapping, return outputs, and a return | // Import it again, with an input mapping, return outputs, and a return | ||||
| // operation, into the same graph. | // operation, into the same graph. | ||||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||||
| opts = c_api.TF_NewImportGraphDefOptions(); | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); | |||||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); | |||||
| EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | |||||
| EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||||
| var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s.Handle); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| IntPtr results; | |||||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||||
| { | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); | |||||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); | |||||
| EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | |||||
| EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||||
| results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s.Handle); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| } | |||||
| Operation scalar2 = graph.OperationByName("imported2/scalar"); | Operation scalar2 = graph.OperationByName("imported2/scalar"); | ||||
| Operation feed2 = graph.OperationByName("imported2/feed"); | Operation feed2 = graph.OperationByName("imported2/feed"); | ||||
| @@ -294,13 +298,14 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| c_api.TF_DeleteImportGraphDefResults(results); | c_api.TF_DeleteImportGraphDefResults(results); | ||||
| // Import again, with control dependencies, into the same graph. | // Import again, with control dependencies, into the same graph. | ||||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||||
| opts = c_api.TF_NewImportGraphDefOptions(); | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); | |||||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); | |||||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | |||||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||||
| { | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); | |||||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); | |||||
| c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | |||||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| } | |||||
| var scalar3 = graph.OperationByName("imported3/scalar"); | var scalar3 = graph.OperationByName("imported3/scalar"); | ||||
| var feed3 = graph.OperationByName("imported3/feed"); | var feed3 = graph.OperationByName("imported3/feed"); | ||||
| @@ -327,12 +332,13 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | EXPECT_EQ(TF_Code.TF_OK, s.Code); | ||||
| // Import again, with remapped control dependency, into the same graph | // Import again, with remapped control dependency, into the same graph | ||||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||||
| opts = c_api.TF_NewImportGraphDefOptions(); | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); | |||||
| c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); | |||||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||||
| { | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); | |||||
| c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); | |||||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| } | |||||
| var scalar4 = graph.OperationByName("imported4/imported3/scalar"); | var scalar4 = graph.OperationByName("imported4/imported3/scalar"); | ||||
| var feed4 = graph.OperationByName("imported4/imported2/feed"); | var feed4 = graph.OperationByName("imported4/imported2/feed"); | ||||
| @@ -344,8 +350,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| EXPECT_EQ(feed, control_inputs[0]); | EXPECT_EQ(feed, control_inputs[0]); | ||||
| EXPECT_EQ(feed4, control_inputs[1]); | EXPECT_EQ(feed4, control_inputs[1]); | ||||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||||
| // Can add nodes to the imported graph without trouble. | // Can add nodes to the imported graph without trouble. | ||||
| c_test_util.Add(feed, scalar, graph, s); | c_test_util.Add(feed, scalar, graph, s); | ||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | ASSERT_EQ(TF_Code.TF_OK, s.Code); | ||||