diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs index a4e9c428..cc2c4cb6 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs @@ -10,6 +10,7 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_BOOL); return *(bool*) tensor.buffer; } } @@ -19,6 +20,7 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT8); return *(sbyte*) tensor.buffer; } } @@ -28,6 +30,7 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT8); return *(byte*) tensor.buffer; } } @@ -37,6 +40,7 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT16); return *(ushort*) tensor.buffer; } } @@ -46,6 +50,7 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT16); return *(short*) tensor.buffer; } } @@ -55,6 +60,7 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT32); return *(int*) tensor.buffer; } } @@ -64,6 +70,7 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT32); return *(uint*) tensor.buffer; } } @@ -73,6 +80,7 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT64); return *(long*) tensor.buffer; } } @@ -82,6 +90,7 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT64); return *(ulong*) tensor.buffer; } } @@ -91,6 +100,7 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_FLOAT); return *(float*) tensor.buffer; } } @@ -100,27 +110,29 @@ namespace Tensorflow unsafe { EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_DOUBLE); return *(double*) tensor.buffer; } } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void EnsureDType(Tensor tensor, TF_DataType @is) + { + if (tensor._dtype != @is) + throw new InvalidCastException($"Unable to cast scalar tensor {tensor._dtype} to {@is}"); + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static void EnsureScalar(Tensor tensor) { if (tensor == null) - { throw new ArgumentNullException(nameof(tensor)); - } if (tensor.TensorShape.ndim != 0) - { throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); - } if (tensor.TensorShape.size != 1) - { throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); - } } }