| @@ -10,6 +10,7 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_BOOL); | |||||
| return *(bool*) tensor.buffer; | return *(bool*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| @@ -19,6 +20,7 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_INT8); | |||||
| return *(sbyte*) tensor.buffer; | return *(sbyte*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| @@ -28,6 +30,7 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_UINT8); | |||||
| return *(byte*) tensor.buffer; | return *(byte*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| @@ -37,6 +40,7 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_UINT16); | |||||
| return *(ushort*) tensor.buffer; | return *(ushort*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| @@ -46,6 +50,7 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_INT16); | |||||
| return *(short*) tensor.buffer; | return *(short*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| @@ -55,6 +60,7 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_INT32); | |||||
| return *(int*) tensor.buffer; | return *(int*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| @@ -64,6 +70,7 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_UINT32); | |||||
| return *(uint*) tensor.buffer; | return *(uint*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| @@ -73,6 +80,7 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_INT64); | |||||
| return *(long*) tensor.buffer; | return *(long*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| @@ -82,6 +90,7 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_UINT64); | |||||
| return *(ulong*) tensor.buffer; | return *(ulong*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| @@ -91,6 +100,7 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||||
| return *(float*) tensor.buffer; | return *(float*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| @@ -100,27 +110,29 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| EnsureScalar(tensor); | EnsureScalar(tensor); | ||||
| EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||||
| return *(double*) tensor.buffer; | return *(double*) tensor.buffer; | ||||
| } | } | ||||
| } | } | ||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| private static void EnsureDType(Tensor tensor, TF_DataType @is) | |||||
| { | |||||
| if (tensor._dtype != @is) | |||||
| throw new InvalidCastException($"Unable to cast scalar tensor {tensor._dtype} to {@is}"); | |||||
| } | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | [MethodImpl(MethodImplOptions.AggressiveInlining)] | ||||
| private static void EnsureScalar(Tensor tensor) | private static void EnsureScalar(Tensor tensor) | ||||
| { | { | ||||
| if (tensor == null) | if (tensor == null) | ||||
| { | |||||
| throw new ArgumentNullException(nameof(tensor)); | throw new ArgumentNullException(nameof(tensor)); | ||||
| } | |||||
| if (tensor.TensorShape.ndim != 0) | if (tensor.TensorShape.ndim != 0) | ||||
| { | |||||
| throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); | throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); | ||||
| } | |||||
| if (tensor.TensorShape.size != 1) | if (tensor.TensorShape.size != 1) | ||||
| { | |||||
| throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); | throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); | ||||
| } | |||||
| } | } | ||||
| } | } | ||||