Browse Source

Merge pull request #564 from sharwell/safe-import-graph-def-options-handle

Implement SafeImportGraphDefOptionsHandle as a wrapper for TF_ImportGraphDefOptions
tags/v0.20
Haiping GitHub 5 years ago
parent
commit
7e18dcf3ee
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 103 additions and 92 deletions
  1. +0
    -10
      src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs
  2. +0
    -17
      src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefResults.cs
  3. +5
    -5
      src/TensorFlowNET.Core/Framework/importer.cs
  4. +3
    -3
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  5. +8
    -14
      src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
  6. +40
    -0
      src/TensorFlowNET.Core/Graphs/SafeImportGraphDefOptionsHandle.cs
  7. +13
    -13
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  8. +34
    -30
      test/TensorFlowNET.UnitTest/GraphTest.cs

+ 0
- 10
src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs View File

@@ -1,10 +0,0 @@
namespace Tensorflow.Framework.Models
{
public class ScopedTFImportGraphDefOptions : ImportGraphDefOptions
{
public ScopedTFImportGraphDefOptions() : base()
{

}
}
}

+ 0
- 17
src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefResults.cs View File

@@ -1,17 +0,0 @@
using System;

namespace Tensorflow.Framework.Models
{
public class ScopedTFImportGraphDefResults : ImportGraphDefOptions
{
public ScopedTFImportGraphDefResults() : base()
{
}

public ScopedTFImportGraphDefResults(IntPtr results) : base(results)
{

}
}
}

+ 5
- 5
src/TensorFlowNET.Core/Framework/importer.cs View File

@@ -62,7 +62,7 @@ namespace Tensorflow
{ {
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
// need to create a class ImportGraphDefWithResults with IDisposal // need to create a class ImportGraphDefWithResults with IDisposal
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status.Handle);
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options.Handle, status.Handle);
status.Check(true); status.Check(true);
} }


@@ -114,8 +114,8 @@ namespace Tensorflow
Dictionary<string, Tensor> input_map, Dictionary<string, Tensor> input_map,
string[] return_elements) string[] return_elements)
{ {
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix);
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1);
c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix);
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1);


foreach(var input in input_map) foreach(var input in input_map)
{ {
@@ -130,11 +130,11 @@ namespace Tensorflow
if(name.Contains(":")) if(name.Contains(":"))
{ {
var (op_name, index) = _ParseTensorName(name); var (op_name, index) = _ParseTensorName(name);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index);
} }
else else
{ {
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name);
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name);
} }
} }




+ 3
- 3
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow
int size = Marshal.SizeOf<TF_Output>(); int size = Marshal.SizeOf<TF_Output>();
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);


c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s.Handle);
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts.Handle, return_output_handle, num_return_outputs, s.Handle);


var tf_output_ptr = (TF_Output*) return_output_handle; var tf_output_ptr = (TF_Output*) return_output_handle;
for (int i = 0; i < num_return_outputs; i++) for (int i = 0; i < num_return_outputs; i++)
@@ -53,8 +53,8 @@ namespace Tensorflow
using (var graph_def = new Buffer(bytes)) using (var graph_def = new Buffer(bytes))
{ {
as_default(); as_default();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix);
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status.Handle);
c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix);
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts.Handle, status.Handle);
status.Check(true); status.Check(true);
return status.Code == TF_Code.TF_OK; return status.Code == TF_Code.TF_OK;
} }


+ 8
- 14
src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs View File

@@ -18,30 +18,24 @@ using System;


namespace Tensorflow namespace Tensorflow
{ {
public class ImportGraphDefOptions : DisposableObject
public sealed class ImportGraphDefOptions : IDisposable
{ {
public SafeImportGraphDefOptionsHandle Handle { get; }

public int NumReturnOutputs public int NumReturnOutputs
=> c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle);
=> c_api.TF_ImportGraphDefOptionsNumReturnOutputs(Handle);


public ImportGraphDefOptions() public ImportGraphDefOptions()
{ {
_handle = c_api.TF_NewImportGraphDefOptions();
}

public ImportGraphDefOptions(IntPtr handle)
{
_handle = handle;
Handle = c_api.TF_NewImportGraphDefOptions();
} }


public void AddReturnOutput(string name, int index) public void AddReturnOutput(string name, int index)
{ {
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(Handle, name, index);
} }


protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteImportGraphDefOptions(handle);

public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle;
public static implicit operator ImportGraphDefOptions(IntPtr handle) => new ImportGraphDefOptions(handle);
public void Dispose()
=> Handle.Dispose();
} }
} }

+ 40
- 0
src/TensorFlowNET.Core/Graphs/SafeImportGraphDefOptionsHandle.cs View File

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Util;

namespace Tensorflow
{
public sealed class SafeImportGraphDefOptionsHandle : SafeTensorflowHandle
{
public SafeImportGraphDefOptionsHandle()
{
}

public SafeImportGraphDefOptionsHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteImportGraphDefOptions(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 13
- 13
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -78,7 +78,7 @@ namespace Tensorflow
/// <param name="num_return_outputs">int</param> /// <param name="num_return_outputs">int</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status);
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status);


/// <summary> /// <summary>
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and
@@ -92,7 +92,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns>TF_ImportGraphDefResults*</returns> /// <returns>TF_ImportGraphDefResults*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, IntPtr options, SafeStatusHandle status);
public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);


/// <summary> /// <summary>
/// Import the graph serialized in `graph_def` into `graph`. /// Import the graph serialized in `graph_def` into `graph`.
@@ -102,7 +102,7 @@ namespace Tensorflow
/// <param name="options">TF_ImportGraphDefOptions*</param> /// <param name="options">TF_ImportGraphDefOptions*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, SafeStatusHandle status);
public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
/// <summary> /// <summary>
/// Iterate through the operations of a graph. /// Iterate through the operations of a graph.
@@ -160,7 +160,7 @@ namespace Tensorflow
/// <param name="opts"></param> /// <param name="opts"></param>
/// <param name="oper"></param> /// <param name="oper"></param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsAddControlDependency(IntPtr opts, IntPtr oper);
public static extern void TF_ImportGraphDefOptionsAddControlDependency(SafeImportGraphDefOptionsHandle opts, IntPtr oper);


/// <summary> /// <summary>
/// Set any imported nodes with input `src_name:src_index` to have that input /// Set any imported nodes with input `src_name:src_index` to have that input
@@ -173,7 +173,7 @@ namespace Tensorflow
/// <param name="src_index">int</param> /// <param name="src_index">int</param>
/// <param name="dst">TF_Output</param> /// <param name="dst">TF_Output</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsAddInputMapping(IntPtr opts, string src_name, int src_index, TF_Output dst);
public static extern void TF_ImportGraphDefOptionsAddInputMapping(SafeImportGraphDefOptionsHandle opts, string src_name, int src_index, TF_Output dst);


/// <summary> /// <summary>
/// Add an operation in `graph_def` to be returned via the `return_opers` output /// Add an operation in `graph_def` to be returned via the `return_opers` output
@@ -183,7 +183,7 @@ namespace Tensorflow
/// <param name="opts">TF_ImportGraphDefOptions* opts</param> /// <param name="opts">TF_ImportGraphDefOptions* opts</param>
/// <param name="oper_name">const char*</param> /// <param name="oper_name">const char*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsAddReturnOperation(IntPtr opts, string oper_name);
public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name);


/// <summary> /// <summary>
/// Add an output in `graph_def` to be returned via the `return_outputs` output /// Add an output in `graph_def` to be returned via the `return_outputs` output
@@ -195,7 +195,7 @@ namespace Tensorflow
/// <param name="oper_name">const char*</param> /// <param name="oper_name">const char*</param>
/// <param name="index">int</param> /// <param name="index">int</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsAddReturnOutput(IntPtr opts, string oper_name, int index);
public static extern void TF_ImportGraphDefOptionsAddReturnOutput(SafeImportGraphDefOptionsHandle opts, string oper_name, int index);


/// <summary> /// <summary>
/// Returns the number of return operations added via /// Returns the number of return operations added via
@@ -204,7 +204,7 @@ namespace Tensorflow
/// <param name="opts"></param> /// <param name="opts"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TF_ImportGraphDefOptionsNumReturnOperations(IntPtr opts);
public static extern int TF_ImportGraphDefOptionsNumReturnOperations(SafeImportGraphDefOptionsHandle opts);


/// <summary> /// <summary>
/// Returns the number of return outputs added via /// Returns the number of return outputs added via
@@ -213,7 +213,7 @@ namespace Tensorflow
/// <param name="opts">const TF_ImportGraphDefOptions*</param> /// <param name="opts">const TF_ImportGraphDefOptions*</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts);
public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(SafeImportGraphDefOptionsHandle opts);


/// <summary> /// <summary>
/// Set any imported nodes with control input `src_name` to have that input /// Set any imported nodes with control input `src_name` to have that input
@@ -225,7 +225,7 @@ namespace Tensorflow
/// <param name="src_name">const char*</param> /// <param name="src_name">const char*</param>
/// <param name="dst">TF_Operation*</param> /// <param name="dst">TF_Operation*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsRemapControlDependency(IntPtr opts, string src_name, IntPtr dst);
public static extern void TF_ImportGraphDefOptionsRemapControlDependency(SafeImportGraphDefOptionsHandle opts, string src_name, IntPtr dst);


/// <summary> /// <summary>
/// Set the prefix to be prepended to the names of nodes in `graph_def` that will /// Set the prefix to be prepended to the names of nodes in `graph_def` that will
@@ -234,7 +234,7 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <param name="ops"></param> /// <param name="ops"></param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix);
public static extern void TF_ImportGraphDefOptionsSetPrefix(SafeImportGraphDefOptionsHandle ops, string prefix);


/// <summary> /// <summary>
/// Set whether to uniquify imported operation names. If true, imported operation /// Set whether to uniquify imported operation names. If true, imported operation
@@ -246,7 +246,7 @@ namespace Tensorflow
/// <param name="ops">TF_ImportGraphDefOptions*</param> /// <param name="ops">TF_ImportGraphDefOptions*</param>
/// <param name="uniquify_prefix">unsigned char</param> /// <param name="uniquify_prefix">unsigned char</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(IntPtr ops, char uniquify_prefix);
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix);


/// <summary> /// <summary>
/// Fetches the return operations requested via /// Fetches the return operations requested via
@@ -295,7 +295,7 @@ namespace Tensorflow
public static extern IntPtr TF_NewGraph(); public static extern IntPtr TF_NewGraph();


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewImportGraphDefOptions();
public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions();


/// <summary> /// <summary>
/// Updates 'dst' to consume 'new_src'. /// Updates 'dst' to consume 'new_src'.


+ 34
- 30
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -222,10 +222,12 @@ namespace TensorFlowNET.UnitTest.NativeAPI
// Import it, with a prefix, in a fresh graph. // Import it, with a prefix, in a fresh graph.
graph.Dispose(); graph.Dispose();
graph = new Graph().as_default(); graph = new Graph().as_default();
var opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
using (var opts = c_api.TF_NewImportGraphDefOptions())
{
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
}


Operation scalar = graph.OperationByName("imported/scalar"); Operation scalar = graph.OperationByName("imported/scalar");
Operation feed = graph.OperationByName("imported/feed"); Operation feed = graph.OperationByName("imported/feed");
@@ -258,17 +260,19 @@ namespace TensorFlowNET.UnitTest.NativeAPI


// Import it again, with an input mapping, return outputs, and a return // Import it again, with an input mapping, return outputs, and a return
// operation, into the same graph. // operation, into the same graph.
c_api.TF_DeleteImportGraphDefOptions(opts);
opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0));
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
IntPtr results;
using (var opts = c_api.TF_NewImportGraphDefOptions())
{
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0));
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
}


Operation scalar2 = graph.OperationByName("imported2/scalar"); Operation scalar2 = graph.OperationByName("imported2/scalar");
Operation feed2 = graph.OperationByName("imported2/feed"); Operation feed2 = graph.OperationByName("imported2/feed");
@@ -294,13 +298,14 @@ namespace TensorFlowNET.UnitTest.NativeAPI
c_api.TF_DeleteImportGraphDefResults(results); c_api.TF_DeleteImportGraphDefResults(results);


// Import again, with control dependencies, into the same graph. // Import again, with control dependencies, into the same graph.
c_api.TF_DeleteImportGraphDefOptions(opts);
opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
using (var opts = c_api.TF_NewImportGraphDefOptions())
{
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
}


var scalar3 = graph.OperationByName("imported3/scalar"); var scalar3 = graph.OperationByName("imported3/scalar");
var feed3 = graph.OperationByName("imported3/feed"); var feed3 = graph.OperationByName("imported3/feed");
@@ -327,12 +332,13 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);


// Import again, with remapped control dependency, into the same graph // Import again, with remapped control dependency, into the same graph
c_api.TF_DeleteImportGraphDefOptions(opts);
opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
using (var opts = c_api.TF_NewImportGraphDefOptions())
{
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
}


var scalar4 = graph.OperationByName("imported4/imported3/scalar"); var scalar4 = graph.OperationByName("imported4/imported3/scalar");
var feed4 = graph.OperationByName("imported4/imported2/feed"); var feed4 = graph.OperationByName("imported4/imported2/feed");
@@ -344,8 +350,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_EQ(feed, control_inputs[0]); EXPECT_EQ(feed, control_inputs[0]);
EXPECT_EQ(feed4, control_inputs[1]); EXPECT_EQ(feed4, control_inputs[1]);


c_api.TF_DeleteImportGraphDefOptions(opts);

// Can add nodes to the imported graph without trouble. // Can add nodes to the imported graph without trouble.
c_test_util.Add(feed, scalar, graph, s); c_test_util.Add(feed, scalar, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code); ASSERT_EQ(TF_Code.TF_OK, s.Code);


Loading…
Cancel
Save