Browse Source

Tensor.Creation.cs: perf-ops

tags/v0.12
Eli Belash 6 years ago
parent
commit
533f8fdd6b
1 changed files with 41 additions and 41 deletions
  1. +41
    -41
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

+ 41
- 41
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -16,6 +16,7 @@


using NumSharp; using NumSharp;
using System; using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq; using System.Linq;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
@@ -463,7 +464,7 @@ namespace Tensorflow
*v = value; *v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true; IsMemoryOwner=true;
}
}
#endif #endif


/// <summary> /// <summary>
@@ -532,11 +533,10 @@ namespace Tensorflow


private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype)
{ {
if (nd.dtype.Name == "String")
throw new NotImplementedException("Support for NDArray of type string not implemented yet");
if (nd.dtype.Name == "String")
throw new NotImplementedException("Support for NDArray of type string not implemented yet");
IArraySlice arraySlice; IArraySlice arraySlice;
var shape = nd.Unsafe.Storage.Shape;
if (shape.IsSliced || shape.IsBroadcasted)
if (nd.Unsafe.Storage.Shape.IsContiguous == false)
{ {
// the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block.
arraySlice = nd.CloneData(); arraySlice = nd.CloneData();
@@ -553,40 +553,40 @@ namespace Tensorflow
var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs);
IsMemoryOwner = false; IsMemoryOwner = false;
return handle; return handle;
}
public unsafe Tensor(byte[][] buffer, long[] shape)
{
int size = 0;
foreach (var b in buffer)
{
size += (int)TF_StringEncodedSize((UIntPtr)b.Length);
}
int totalSize = size + buffer.Length * 8;
ulong offset = 0;
IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize);
// Clear offset table
IntPtr pOffset = TF_TensorData(handle);
IntPtr dst = pOffset + buffer.Length * 8;
IntPtr dstLimit = pOffset + totalSize;
for (int i = 0; i < buffer.Length; i++)
{
Marshal.WriteInt64(pOffset, (long)offset);
using (var status = new Status())
{
fixed (byte* src = &buffer[i][0])
{
var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status);
status.Check(true);
pOffset += 8;
dst += (int)written;
offset += written;
}
}
}
_handle = handle;
}
public unsafe Tensor(byte[][] buffer, long[] shape)
{
int size = 0;
foreach (var b in buffer)
{
size += (int)TF_StringEncodedSize((UIntPtr)b.Length);
}
int totalSize = size + buffer.Length * 8;
ulong offset = 0;
IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize);
// Clear offset table
IntPtr pOffset = TF_TensorData(handle);
IntPtr dst = pOffset + buffer.Length * 8;
IntPtr dstLimit = pOffset + totalSize;
for (int i = 0; i < buffer.Length; i++)
{
Marshal.WriteInt64(pOffset, (long)offset);
using (var status = new Status())
{
fixed (byte* src = &buffer[i][0])
{
var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status);
status.Check(true);
pOffset += 8;
dst += (int)written;
offset += written;
}
}
}
_handle = handle;
} }


public Tensor(Operation op, int value_index, TF_DataType dtype) public Tensor(Operation op, int value_index, TF_DataType dtype)
@@ -611,11 +611,11 @@ namespace Tensorflow
/// specified dimensions. /// specified dimensions.
/// </remarks> /// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
[SuppressMessage("ReSharper", "LocalVariableHidesMember")]
protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size)
{ {
if (dt == TF_DataType.TF_STRING && data is byte[])
if (dt == TF_DataType.TF_STRING && data is byte[] buffer)
{ {
var buffer = (byte[])data;
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));




Loading…
Cancel
Save