| @@ -10,6 +10,7 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_BOOL); | |||
| return *(bool*) tensor.buffer; | |||
| } | |||
| } | |||
| @@ -19,6 +20,7 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_INT8); | |||
| return *(sbyte*) tensor.buffer; | |||
| } | |||
| } | |||
| @@ -28,6 +30,7 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_UINT8); | |||
| return *(byte*) tensor.buffer; | |||
| } | |||
| } | |||
| @@ -37,6 +40,7 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_UINT16); | |||
| return *(ushort*) tensor.buffer; | |||
| } | |||
| } | |||
| @@ -46,6 +50,7 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_INT16); | |||
| return *(short*) tensor.buffer; | |||
| } | |||
| } | |||
| @@ -55,6 +60,7 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_INT32); | |||
| return *(int*) tensor.buffer; | |||
| } | |||
| } | |||
| @@ -64,6 +70,7 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_UINT32); | |||
| return *(uint*) tensor.buffer; | |||
| } | |||
| } | |||
| @@ -73,6 +80,7 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_INT64); | |||
| return *(long*) tensor.buffer; | |||
| } | |||
| } | |||
| @@ -82,6 +90,7 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_UINT64); | |||
| return *(ulong*) tensor.buffer; | |||
| } | |||
| } | |||
| @@ -91,6 +100,7 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||
| return *(float*) tensor.buffer; | |||
| } | |||
| } | |||
| @@ -100,27 +110,29 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| EnsureScalar(tensor); | |||
| EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||
| return *(double*) tensor.buffer; | |||
| } | |||
| } | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| private static void EnsureDType(Tensor tensor, TF_DataType @is) | |||
| { | |||
| if (tensor._dtype != @is) | |||
| throw new InvalidCastException($"Unable to cast scalar tensor {tensor._dtype} to {@is}"); | |||
| } | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| private static void EnsureScalar(Tensor tensor) | |||
| { | |||
| if (tensor == null) | |||
| { | |||
| throw new ArgumentNullException(nameof(tensor)); | |||
| } | |||
| if (tensor.TensorShape.ndim != 0) | |||
| { | |||
| throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); | |||
| } | |||
| if (tensor.TensorShape.size != 1) | |||
| { | |||
| throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); | |||
| } | |||
| } | |||
| } | |||