diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 43848da6..59c107fc 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -83,6 +83,12 @@ namespace Tensorflow throw new NotImplementedException("MakeNdarray"); } + private static readonly TF_DataType[] quantized_types = new TF_DataType[] + { + TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, + TF_DataType.TF_QINT32 + }; + /// /// Create a TensorProto. /// @@ -99,15 +105,6 @@ namespace Tensorflow if (values is TensorProto tp) return tp; - if (dtype != TF_DataType.DtInvalid) - ; - - bool is_quantized = new TF_DataType[] - { - TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, - TF_DataType.TF_QINT32 - }.Contains(dtype); - // We first convert value to a numpy array or scalar. NDArray nparray = null; var np_dt = dtype.as_numpy_dtype(); @@ -227,13 +224,13 @@ namespace Tensorflow } } - var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype); + var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype); if (numpy_dtype == TF_DataType.DtInvalid) throw new TypeError($"Unrecognized data type: {nparray.dtype}"); // If dtype was specified and is a quantized type, we convert // numpy_dtype back into the quantized version. - if (is_quantized) + if (quantized_types.Contains(dtype)) numpy_dtype = dtype; bool is_same_size = false;