Browse Source

Implement SafeBufferHandle as a wrapper for TF_Buffer

tags/v0.20
Sam Harwell Haiping 5 years ago
parent
commit
caae2dbd66
18 changed files with 121 additions and 89 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Attributes/c_api.ops.cs
  2. +35
    -53
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  3. +40
    -0
      src/TensorFlowNET.Core/Buffers/SafeBufferHandle.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Buffers/c_api.buffer.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Framework/importer.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
  7. +1
    -1
      src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs
  8. +1
    -1
      src/TensorFlowNET.Core/GraphTransformation/c_api.transform_graph.cs
  9. +2
    -2
      src/TensorFlowNET.Core/Graphs/Graph.Export.cs
  10. +2
    -2
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  11. +5
    -5
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  12. +4
    -4
      src/TensorFlowNET.Core/Operations/Operation.cs
  13. +2
    -2
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  14. +10
    -0
      src/TensorFlowNET.Core/Util/SafeHandleLease.cs
  15. +6
    -6
      test/TensorFlowNET.UnitTest/GraphTest.cs
  16. +2
    -2
      test/TensorFlowNET.UnitTest/NativeAPI/CApiGradientsTest.cs
  17. +4
    -4
      test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs
  18. +2
    -2
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Attributes/c_api.ops.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow
/// <param name="oper"></param> /// <param name="oper"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, SafeStatusHandle status);
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, SafeBufferHandle output_attr_value, SafeStatusHandle status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);


+ 35
- 53
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -14,10 +14,11 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using NumSharp.Backends.Unmanaged;
using System; using System;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using NumSharp.Backends.Unmanaged;
using Tensorflow.Util;
using static Tensorflow.c_api; using static Tensorflow.c_api;


namespace Tensorflow namespace Tensorflow
@@ -25,34 +26,30 @@ namespace Tensorflow
/// <summary> /// <summary>
/// Represents a TF_Buffer that can be passed to Tensorflow. /// Represents a TF_Buffer that can be passed to Tensorflow.
/// </summary> /// </summary>
public class Buffer : DisposableObject
public sealed class Buffer : IDisposable
{ {
private unsafe TF_Buffer buffer
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => *bufferptr;
}
public SafeBufferHandle Handle { get; }


private unsafe TF_Buffer* bufferptr
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => (TF_Buffer*) _handle;
}
/// <remarks>
/// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/>
/// </remarks>
private unsafe ref readonly TF_Buffer DangerousBuffer
=> ref Unsafe.AsRef<TF_Buffer>(Handle.DangerousGetHandle().ToPointer());


/// <summary> /// <summary>
/// The memory block representing this buffer. /// The memory block representing this buffer.
/// </summary> /// </summary>
/// <remarks>The deallocator is set to null.</remarks>
public UnmanagedMemoryBlock<byte> MemoryBlock
/// <remarks>
/// <para>The deallocator is set to null.</para>
///
/// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/>
/// </remarks>
public unsafe UnmanagedMemoryBlock<byte> DangerousMemoryBlock
{ {
get get
{ {
unsafe
{
EnsureNotDisposed();
var buff = (TF_Buffer*) _handle;
return new UnmanagedMemoryBlock<byte>((byte*) buff->data.ToPointer(), (long) buff->length);
}
ref readonly TF_Buffer buffer = ref DangerousBuffer;
return new UnmanagedMemoryBlock<byte>((byte*)buffer.data.ToPointer(), (long)buffer.length);
} }
} }


@@ -63,25 +60,23 @@ namespace Tensorflow
{ {
get get
{ {
EnsureNotDisposed();
return buffer.length;
using (Handle.Lease())
{
return DangerousBuffer.length;
}
} }
} }


public Buffer() => _handle = TF_NewBuffer();

public Buffer(IntPtr handle)
{
if (handle == IntPtr.Zero)
throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle));
public Buffer()
=> Handle = TF_NewBuffer();


_handle = handle;
}
public Buffer(SafeBufferHandle handle)
=> Handle = handle;


public Buffer(byte[] data) : this(_toBuffer(data))
{ }
public Buffer(byte[] data)
=> Handle = _toBuffer(data);


private static IntPtr _toBuffer(byte[] data)
private static SafeBufferHandle _toBuffer(byte[] data)
{ {
if (data == null) if (data == null)
throw new ArgumentNullException(nameof(data)); throw new ArgumentNullException(nameof(data));
@@ -93,38 +88,25 @@ namespace Tensorflow
} }
} }


public static implicit operator IntPtr(Buffer buffer)
{
buffer.EnsureNotDisposed();
return buffer._handle;
}

public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost.

/// <summary> /// <summary>
/// Copies this buffer's contents onto a <see cref="byte"/> array. /// Copies this buffer's contents onto a <see cref="byte"/> array.
/// </summary> /// </summary>
public byte[] ToArray() public byte[] ToArray()
{ {
EnsureNotDisposed();

unsafe
using (Handle.Lease())
{ {
var len = buffer.length;
var block = DangerousMemoryBlock;
var len = block.Count;
if (len == 0) if (len == 0)
return Array.Empty<byte>(); return Array.Empty<byte>();


byte[] data = new byte[len];
fixed (byte* dst = data)
System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len);

var data = new byte[len];
block.CopyTo(data, 0);
return data; return data;
} }
} }


protected override void DisposeUnmanagedResources(IntPtr handle)
{
TF_DeleteBuffer(handle);
}
public void Dispose()
=> Handle.Dispose();
} }
} }

+ 40
- 0
src/TensorFlowNET.Core/Buffers/SafeBufferHandle.cs View File

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Util;

namespace Tensorflow
{
public sealed class SafeBufferHandle : SafeTensorflowHandle
{
private SafeBufferHandle()
{
}

public SafeBufferHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteBuffer(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Buffers/c_api.buffer.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewBuffer();
public static extern SafeBufferHandle TF_NewBuffer();


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GetBuffer(TF_Buffer buffer); public static extern IntPtr TF_GetBuffer(TF_Buffer buffer);
@@ -42,6 +42,6 @@ namespace Tensorflow
/// <param name="proto_len">size_t</param> /// <param name="proto_len">size_t</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewBufferFromString(IntPtr proto, ulong proto_len);
public static extern SafeBufferHandle TF_NewBufferFromString(IntPtr proto, ulong proto_len);
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Framework/importer.cs View File

@@ -62,7 +62,7 @@ namespace Tensorflow
{ {
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
// need to create a class ImportGraphDefWithResults with IDisposal // need to create a class ImportGraphDefWithResults with IDisposal
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options.Handle, status.Handle);
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle);
status.Check(true); status.Check(true);
} }




+ 1
- 1
src/TensorFlowNET.Core/Framework/op_def_registry.py.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow
{ {
_registered_ops = new Dictionary<string, OpDef>(); _registered_ops = new Dictionary<string, OpDef>();
using var buffer = new Buffer(c_api.TF_GetAllOpList()); using var buffer = new Buffer(c_api.TF_GetAllOpList());
using var stream = buffer.MemoryBlock.Stream();
using var stream = buffer.DangerousMemoryBlock.Stream();
var op_list = OpList.Parser.ParseFrom(stream); var op_list = OpList.Parser.ParseFrom(stream);
foreach (var op_def in op_list.Op) foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def; _registered_ops[op_def.Name] = op_def;


+ 1
- 1
src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs View File

@@ -34,7 +34,7 @@ namespace Tensorflow
inputs_string, inputs_string,
outputs_string, outputs_string,
transforms_string, transforms_string,
buffer,
buffer.Handle,
status.Handle); status.Handle);


status.Check(false); status.Check(false);


+ 1
- 1
src/TensorFlowNET.Core/GraphTransformation/c_api.transform_graph.cs View File

@@ -27,7 +27,7 @@ namespace Tensorflow
string inputs_string, string inputs_string,
string outputs_string, string outputs_string,
string transforms_string, string transforms_string,
IntPtr output_buffer,
SafeBufferHandle output_buffer,
SafeStatusHandle status); SafeStatusHandle status);
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/Graphs/Graph.Export.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow
public Buffer ToGraphDef(Status s) public Buffer ToGraphDef(Status s)
{ {
var buffer = new Buffer(); var buffer = new Buffer();
c_api.TF_GraphToGraphDef(_handle, buffer, s.Handle);
c_api.TF_GraphToGraphDef(_handle, buffer.Handle, s.Handle);
s.Check(true); s.Check(true);


return buffer; return buffer;
@@ -39,7 +39,7 @@ namespace Tensorflow
{ {
status.Check(true); status.Check(true);
// limit size to 250M, recursion to max 100 // limit size to 250M, recursion to max 100
var inputStream = CodedInputStream.CreateWithLimits(buffer.MemoryBlock.Stream(), 250 * 1024 * 1024, 100);
var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock.Stream(), 250 * 1024 * 1024, 100);
def = GraphDef.Parser.ParseFrom(inputStream); def = GraphDef.Parser.ParseFrom(inputStream);
} }




+ 2
- 2
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow
int size = Marshal.SizeOf<TF_Output>(); int size = Marshal.SizeOf<TF_Output>();
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);


c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts.Handle, return_output_handle, num_return_outputs, s.Handle);
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def.Handle, opts.Handle, return_output_handle, num_return_outputs, s.Handle);


var tf_output_ptr = (TF_Output*) return_output_handle; var tf_output_ptr = (TF_Output*) return_output_handle;
for (int i = 0; i < num_return_outputs; i++) for (int i = 0; i < num_return_outputs; i++)
@@ -54,7 +54,7 @@ namespace Tensorflow
{ {
as_default(); as_default();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix); c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix);
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts.Handle, status.Handle);
c_api.TF_GraphImportGraphDef(_handle, graph_def.Handle, opts.Handle, status.Handle);
status.Check(true); status.Check(true);
return status.Code == TF_Code.TF_OK; return status.Code == TF_Code.TF_OK;
} }


+ 5
- 5
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -47,7 +47,7 @@ namespace Tensorflow
public static extern string TF_GraphDebugString(IntPtr graph, out int len); public static extern string TF_GraphDebugString(IntPtr graph, out int len);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, SafeStatusHandle status);
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, SafeBufferHandle output_op_def, SafeStatusHandle status);


/// <summary> /// <summary>
/// Returns the shape of the Tensor referenced by `output` in `graph` /// Returns the shape of the Tensor referenced by `output` in `graph`
@@ -78,7 +78,7 @@ namespace Tensorflow
/// <param name="num_return_outputs">int</param> /// <param name="num_return_outputs">int</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status);
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status);


/// <summary> /// <summary>
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and
@@ -92,7 +92,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns>TF_ImportGraphDefResults*</returns> /// <returns>TF_ImportGraphDefResults*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);


/// <summary> /// <summary>
/// Import the graph serialized in `graph_def` into `graph`. /// Import the graph serialized in `graph_def` into `graph`.
@@ -102,7 +102,7 @@ namespace Tensorflow
/// <param name="options">TF_ImportGraphDefOptions*</param> /// <param name="options">TF_ImportGraphDefOptions*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
public static extern void TF_GraphImportGraphDef(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
/// <summary> /// <summary>
/// Iterate through the operations of a graph. /// Iterate through the operations of a graph.
@@ -138,7 +138,7 @@ namespace Tensorflow
/// <param name="output_graph_def">TF_Buffer*</param> /// <param name="output_graph_def">TF_Buffer*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, SafeStatusHandle status);
public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status);
/// <summary> /// <summary>
/// Returns the number of dimensions of the Tensor referenced by `output` /// Returns the number of dimensions of the Tensor referenced by `output`


+ 4
- 4
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -236,10 +236,10 @@ namespace Tensorflow
lock (Locks.ProcessWide) lock (Locks.ProcessWide)
{ {
using var buf = new Buffer(); using var buf = new Buffer();
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.status.Handle);
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.status.Handle);
tf.status.Check(true); tf.status.Check(true);


x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream());
x = AttrValue.Parser.ParseFrom(buf.DangerousMemoryBlock.Stream());
} }


string oneof_value = x.ValueCase.ToString(); string oneof_value = x.ValueCase.ToString();
@@ -269,10 +269,10 @@ namespace Tensorflow
using (var s = new Status()) using (var s = new Status())
using (var buffer = new Buffer()) using (var buffer = new Buffer())
{ {
c_api.TF_OperationToNodeDef(_handle, buffer, s.Handle);
c_api.TF_OperationToNodeDef(_handle, buffer.Handle, s.Handle);
s.Check(); s.Check();


return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
return NodeDef.Parser.ParseFrom(buffer.DangerousMemoryBlock.Stream());
} }
} }




+ 2
- 2
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GetAllOpList();
public static extern SafeBufferHandle TF_GetAllOpList();


/// <summary> /// <summary>
/// For inputs that take a single tensor. /// For inputs that take a single tensor.
@@ -204,7 +204,7 @@ namespace Tensorflow
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_OperationToNodeDef(IntPtr oper, IntPtr buffer, SafeStatusHandle status);
public static extern void TF_OperationToNodeDef(IntPtr oper, SafeBufferHandle buffer, SafeStatusHandle status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status);


+ 10
- 0
src/TensorFlowNET.Core/Util/SafeHandleLease.cs View File

@@ -23,6 +23,16 @@ namespace Tensorflow.Util
/// Represents a lease of a <see cref="SafeHandle"/>. /// Represents a lease of a <see cref="SafeHandle"/>.
/// </summary> /// </summary>
/// <seealso cref="SafeHandleExtensions.Lease"/> /// <seealso cref="SafeHandleExtensions.Lease"/>
/// <devdoc>
/// <para>Elements in this section may be referenced by <c>&lt;inheritdoc&gt;</c> elements to provide common
/// language in documentation remarks.</para>
///
/// <usage>
/// <para>The result of this method is only valid when the underlying handle has not been disposed. If the lifetime
/// of the object is unclear, a lease may be used to prevent disposal while the object is in use. See
/// <see cref="SafeHandleExtensions.Lease(SafeHandle)"/>.</para>
/// </usage>
/// </devdoc>
public readonly struct SafeHandleLease : IDisposable public readonly struct SafeHandleLease : IDisposable
{ {
private readonly SafeHandle _handle; private readonly SafeHandle _handle;


+ 6
- 6
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -216,7 +216,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI


// Export to a GraphDef. // Export to a GraphDef.
var graph_def = new Buffer(); var graph_def = new Buffer();
c_api.TF_GraphToGraphDef(graph, graph_def, s.Handle);
c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);


// Import it, with a prefix, in a fresh graph. // Import it, with a prefix, in a fresh graph.
@@ -225,7 +225,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
using (var opts = c_api.TF_NewImportGraphDefOptions()) using (var opts = c_api.TF_NewImportGraphDefOptions())
{ {
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle);
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);
} }


@@ -270,7 +270,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s.Handle);
results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);
} }


@@ -303,7 +303,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle);
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);
} }


@@ -328,7 +328,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI


// Export to a graph def so we can import a graph with control dependencies // Export to a graph def so we can import a graph with control dependencies
graph_def = new Buffer(); graph_def = new Buffer();
c_api.TF_GraphToGraphDef(graph, graph_def, s.Handle);
c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle);
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);


// Import again, with remapped control dependency, into the same graph // Import again, with remapped control dependency, into the same graph
@@ -336,7 +336,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI
{ {
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s.Handle);
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle);
ASSERT_EQ(TF_Code.TF_OK, s.Code); ASSERT_EQ(TF_Code.TF_OK, s.Code);
} }




+ 2
- 2
test/TensorFlowNET.UnitTest/NativeAPI/CApiGradientsTest.cs View File

@@ -50,11 +50,11 @@ namespace TensorFlowNET.UnitTest.NativeAPI
{ {
using (var buffer = new Buffer()) using (var buffer = new Buffer())
{ {
c_api.TF_GraphToGraphDef(graph, buffer, s.Handle);
c_api.TF_GraphToGraphDef(graph, buffer.Handle, s.Handle);
bool ret = TF_GetCode(s) == TF_OK; bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s)); EXPECT_EQ(TF_OK, TF_GetCode(s));
if (ret) if (ret)
graph_def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
graph_def = GraphDef.Parser.ParseFrom(buffer.DangerousMemoryBlock.Stream());
return ret; return ret;
} }
} }


+ 4
- 4
test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs View File

@@ -38,8 +38,8 @@ namespace TensorFlowNET.UnitTest
{ {
using (var buffer = new Buffer()) using (var buffer = new Buffer())
{ {
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s.Handle);
attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream());
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer.Handle, s.Handle);
attr_value = AttrValue.Parser.ParseFrom(buffer.DangerousMemoryBlock.Stream());
} }


return s.Code == TF_Code.TF_OK; return s.Code == TF_Code.TF_OK;
@@ -53,9 +53,9 @@ namespace TensorFlowNET.UnitTest
using (var s = new Status()) using (var s = new Status())
using (var buffer = new Buffer()) using (var buffer = new Buffer())
{ {
c_api.TF_GraphToGraphDef(graph, buffer, s.Handle);
c_api.TF_GraphToGraphDef(graph, buffer.Handle, s.Handle);
s.Check(); s.Check();
return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
return GraphDef.Parser.ParseFrom(buffer.DangerousMemoryBlock.Stream());
} }
} }
} }


+ 2
- 2
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -22,8 +22,8 @@ namespace TensorFlowNET.UnitTest.Basics
public void GetAllOpList() public void GetAllOpList()
{ {
var handle = c_api.TF_GetAllOpList(); var handle = c_api.TF_GetAllOpList();
var buffer = new Buffer(handle);
var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream());
using var buffer = new Buffer(handle);
var op_list = OpList.Parser.ParseFrom(buffer.DangerousMemoryBlock.Stream());


var _registered_ops = new Dictionary<string, OpDef>(); var _registered_ops = new Dictionary<string, OpDef>();
foreach (var op_def in op_list.Op) foreach (var op_def in op_list.Op)


Loading…
Cancel
Save