|
|
@@ -21,6 +21,9 @@ using System.Collections.Generic; |
|
|
using System.Linq; |
|
|
using System.Linq; |
|
|
using System.Numerics; |
|
|
using System.Numerics; |
|
|
using System.Text; |
|
|
using System.Text; |
|
|
|
|
|
using Google.Protobuf; |
|
|
|
|
|
using NumSharp.Backends; |
|
|
|
|
|
using Tensorflow.Util; |
|
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
namespace Tensorflow |
|
|
{ |
|
|
{ |
|
|
@@ -246,111 +249,167 @@ namespace Tensorflow |
|
|
return result; |
|
|
return result; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
private unsafe NDArray fetchValue(IntPtr output) |
|
|
|
|
|
|
|
|
private static unsafe NDArray fetchValue(IntPtr output) |
|
|
{ |
|
|
{ |
|
|
var tensor = new Tensor(output); |
|
|
|
|
|
NDArray nd = null; |
|
|
|
|
|
Type type = tensor.dtype.as_numpy_dtype(); |
|
|
|
|
|
var ndims = tensor.shape; |
|
|
|
|
|
var offset = (byte*) c_api.TF_TensorData(output); |
|
|
|
|
|
|
|
|
|
|
|
if(ndims.Length == 0) |
|
|
|
|
|
|
|
|
NDArray ret; |
|
|
|
|
|
using (var tensor = new Tensor(output)) |
|
|
{ |
|
|
{ |
|
|
switch (tensor.dtype) |
|
|
|
|
|
|
|
|
var ndims = tensor.shape; |
|
|
|
|
|
var srcAddress = c_api.TF_TensorData(output).ToInt64(); |
|
|
|
|
|
|
|
|
|
|
|
if (ndims.Length == 0) |
|
|
{ |
|
|
{ |
|
|
case TF_DataType.TF_BOOL: |
|
|
|
|
|
nd = NDArray.Scalar(*(bool*)offset); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_STRING: |
|
|
|
|
|
var bytes = tensor.BufferToArray(); |
|
|
|
|
|
// wired, don't know why we have to start from offset 9. |
|
|
|
|
|
// length in the begin |
|
|
|
|
|
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); |
|
|
|
|
|
nd = NDArray.FromString(str); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_UINT8: |
|
|
|
|
|
nd = NDArray.Scalar(*(byte*)offset); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_INT16: |
|
|
|
|
|
nd = NDArray.Scalar(*(short*)offset); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_INT32: |
|
|
|
|
|
nd = NDArray.Scalar(*(int*)offset); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_INT64: |
|
|
|
|
|
nd = NDArray.Scalar(*(long*)offset); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_FLOAT: |
|
|
|
|
|
nd = NDArray.Scalar(*(float*)offset); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_DOUBLE: |
|
|
|
|
|
nd = NDArray.Scalar(*(double*)offset); |
|
|
|
|
|
break; |
|
|
|
|
|
default: |
|
|
|
|
|
throw new NotImplementedException("can't fetch output"); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
switch (tensor.dtype) |
|
|
|
|
|
|
|
|
switch (tensor.dtype) |
|
|
|
|
|
{ |
|
|
|
|
|
case TF_DataType.TF_BOOL: |
|
|
|
|
|
ret = NDArray.Scalar(*(bool*) srcAddress); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_STRING: |
|
|
|
|
|
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) |
|
|
|
|
|
ret = NDArray.FromString(reader.ReadString()); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_UINT8: |
|
|
|
|
|
ret = NDArray.Scalar(*(byte*) srcAddress); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_INT16: |
|
|
|
|
|
ret = NDArray.Scalar(*(short*) srcAddress); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_INT32: |
|
|
|
|
|
ret = NDArray.Scalar(*(int*) srcAddress); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_INT64: |
|
|
|
|
|
ret = NDArray.Scalar(*(long*) srcAddress); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_UINT16: |
|
|
|
|
|
ret = NDArray.Scalar(*(ushort*) srcAddress); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_UINT32: |
|
|
|
|
|
ret = NDArray.Scalar(*(uint*) srcAddress); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_UINT64: |
|
|
|
|
|
ret = NDArray.Scalar(*(ulong*) srcAddress); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_FLOAT: |
|
|
|
|
|
ret = NDArray.Scalar(*(float*) srcAddress); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_DOUBLE: |
|
|
|
|
|
ret = NDArray.Scalar(*(double*) srcAddress); |
|
|
|
|
|
break; |
|
|
|
|
|
default: |
|
|
|
|
|
throw new NotImplementedException("can't fetch output"); |
|
|
|
|
|
} |
|
|
|
|
|
} else |
|
|
{ |
|
|
{ |
|
|
case TF_DataType.TF_BOOL: |
|
|
|
|
|
var bools = new bool[tensor.size]; |
|
|
|
|
|
for (ulong i = 0; i < tensor.size; i++) |
|
|
|
|
|
bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); |
|
|
|
|
|
nd = np.array(bools).reshape(ndims); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_STRING: |
|
|
|
|
|
var bytes = tensor.BufferToArray(); |
|
|
|
|
|
// wired, don't know why we have to start from offset 9. |
|
|
|
|
|
// length in the begin |
|
|
|
|
|
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); |
|
|
|
|
|
nd = np.array(str); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_UINT8: |
|
|
|
|
|
var _bytes = new byte[tensor.size]; |
|
|
|
|
|
for (ulong i = 0; i < tensor.size; i++) |
|
|
|
|
|
_bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); |
|
|
|
|
|
nd = np.array(_bytes).reshape(ndims); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_INT16: |
|
|
|
|
|
var shorts = new short[tensor.size]; |
|
|
|
|
|
for (ulong i = 0; i < tensor.size; i++) |
|
|
|
|
|
shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); |
|
|
|
|
|
nd = np.array(shorts).reshape(ndims); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_INT32: |
|
|
|
|
|
var ints = new int[tensor.size]; |
|
|
|
|
|
for (ulong i = 0; i < tensor.size; i++) |
|
|
|
|
|
ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); |
|
|
|
|
|
nd = np.array(ints).reshape(ndims); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_INT64: |
|
|
|
|
|
var longs = new long[tensor.size]; |
|
|
|
|
|
for (ulong i = 0; i < tensor.size; i++) |
|
|
|
|
|
longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); |
|
|
|
|
|
nd = np.array(longs).reshape(ndims); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_FLOAT: |
|
|
|
|
|
var floats = new float[tensor.size]; |
|
|
|
|
|
for (ulong i = 0; i < tensor.size; i++) |
|
|
|
|
|
floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); |
|
|
|
|
|
nd = np.array(floats).reshape(ndims); |
|
|
|
|
|
break; |
|
|
|
|
|
case TF_DataType.TF_DOUBLE: |
|
|
|
|
|
var doubles = new double[tensor.size]; |
|
|
|
|
|
for (ulong i = 0; i < tensor.size; i++) |
|
|
|
|
|
doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); |
|
|
|
|
|
nd = np.array(doubles).reshape(ndims); |
|
|
|
|
|
break; |
|
|
|
|
|
default: |
|
|
|
|
|
throw new NotImplementedException("can't fetch output"); |
|
|
|
|
|
|
|
|
//var size = (long) tensor.size; |
|
|
|
|
|
//var itemsize = (long) tensor.itemsize; |
|
|
|
|
|
var bytesize = (long) tensor.bytesize; |
|
|
|
|
|
var src = (void*) srcAddress; |
|
|
|
|
|
|
|
|
|
|
|
#if _REGEN |
|
|
|
|
|
#region Compute |
|
|
|
|
|
switch (tensor.dtype) |
|
|
|
|
|
{ |
|
|
|
|
|
%foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")% |
|
|
|
|
|
case TF_DataType.#3: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.#1, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
% |
|
|
|
|
|
case TF_DataType.TF_STRING: |
|
|
|
|
|
{ |
|
|
|
|
|
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString |
|
|
|
|
|
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) |
|
|
|
|
|
ret = NDArray.FromString(reader.ReadString()); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
default: |
|
|
|
|
|
throw new NotSupportedException(); |
|
|
|
|
|
} |
|
|
|
|
|
#endregion |
|
|
|
|
|
#else |
|
|
|
|
|
|
|
|
|
|
|
#region Compute |
|
|
|
|
|
switch (tensor.dtype) |
|
|
|
|
|
{ |
|
|
|
|
|
case TF_DataType.TF_BOOL: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.Boolean, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
case TF_DataType.TF_UINT8: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.Byte, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
case TF_DataType.TF_INT16: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.Int16, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
case TF_DataType.TF_UINT16: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.UInt16, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
case TF_DataType.TF_INT32: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.Int32, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
case TF_DataType.TF_UINT32: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.UInt32, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
case TF_DataType.TF_INT64: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.Int64, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
case TF_DataType.TF_UINT64: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.UInt64, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
case TF_DataType.TF_DOUBLE: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.Double, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
case TF_DataType.TF_FLOAT: |
|
|
|
|
|
{ |
|
|
|
|
|
ret = new NDArray(NPTypeCode.Single, ndims, false); |
|
|
|
|
|
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
case TF_DataType.TF_STRING: |
|
|
|
|
|
{ |
|
|
|
|
|
throw new NotImplementedException(); |
|
|
|
|
|
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString |
|
|
|
|
|
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) |
|
|
|
|
|
ret = NDArray.FromString(reader.ReadString()); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
default: |
|
|
|
|
|
throw new NotSupportedException(); |
|
|
|
|
|
} |
|
|
|
|
|
#endregion |
|
|
|
|
|
#endif |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
tensor.Dispose(); |
|
|
|
|
|
|
|
|
|
|
|
return nd; |
|
|
|
|
|
|
|
|
return ret; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// <summary> |
|
|
/// <summary> |
|
|
|