From 1142f38d2ff6adafa180afc2ef1807cb71219c67 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 10 Jul 2021 16:39:40 -0500 Subject: [PATCH] unify numpy dtype and tf dtype. --- .../Data/MnistModelLoader.cs | 4 +- .../Implementation/NumPyImpl.Creation.cs | 10 +- src/TensorFlowNET.Core/Numpy/IMemoryBlock.cs | 6 +- src/TensorFlowNET.Core/Numpy/InfoOf.cs | 50 +--- .../Numpy/NDArray.Creation.cs | 12 +- src/TensorFlowNET.Core/Numpy/NDArray.cs | 4 +- .../Numpy/Numpy.Creation.cs | 16 +- src/TensorFlowNET.Core/Numpy/Numpy.cs | 62 ++--- src/TensorFlowNET.Core/Numpy/NumpyDType.cs | 90 ------- .../Operations/array_ops.cs | 2 +- .../Tensors/Tensor.Creation.cs | 2 +- .../Tensors/TensorConverter.cs | 225 ------------------ src/TensorFlowNET.Core/Tensors/constant_op.cs | 2 +- src/TensorFlowNET.Core/Tensors/dtypes.cs | 50 ---- src/TensorFlowNET.Keras/Datasets/Imdb.cs | 4 +- src/TensorFlowNET.Keras/Sequence.cs | 2 +- src/TensorFlowNET.Keras/Utils/np_utils.cs | 4 +- .../Basics/SessionTest.cs | 6 +- .../Utilities/FluentExtension.cs | 51 +--- 19 files changed, 66 insertions(+), 536 deletions(-) delete mode 100644 src/TensorFlowNET.Core/Numpy/NumpyDType.cs delete mode 100644 src/TensorFlowNET.Core/Tensors/TensorConverter.cs diff --git a/src/TensorFlowNET.Core/Data/MnistModelLoader.cs b/src/TensorFlowNET.Core/Data/MnistModelLoader.cs index 5be88bba..73fb52f9 100644 --- a/src/TensorFlowNET.Core/Data/MnistModelLoader.cs +++ b/src/TensorFlowNET.Core/Data/MnistModelLoader.cs @@ -123,7 +123,7 @@ namespace Tensorflow bytestream.Read(buf, 0, buf.Length); - var data = np.frombuffer(buf, np.@byte); + var data = np.frombuffer(buf, np.@byte.as_system_dtype()); data = data.reshape((num_images, rows, cols, 1)); return data; @@ -148,7 +148,7 @@ namespace Tensorflow bytestream.Read(buf, 0, buf.Length); - var labels = np.frombuffer(buf, np.uint8); + var labels = np.frombuffer(buf, np.uint8.as_system_dtype()); if (one_hot) return DenseToOneHot(labels, num_classes); diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs index 8e73325c..6ad41ff5 100644 --- a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs @@ -7,7 +7,7 @@ namespace Tensorflow.NumPy { public partial class NumPyImpl { - public NDArray eye(int N, int? M = null, int k = 0, NumpyDType dtype = NumpyDType.Double) + public NDArray eye(int N, int? M = null, int k = 0, TF_DataType dtype = TF_DataType.TF_DOUBLE) { if (!M.HasValue) M = N; @@ -28,16 +28,16 @@ namespace Tensorflow.NumPy diag_len = N + k; } - var diagonal_ = array_ops.ones(new TensorShape(diag_len), dtype: dtype.as_tf_dtype()); + var diagonal_ = array_ops.ones(new TensorShape(diag_len), dtype: dtype); var tensor = array_ops.matrix_diag(diagonal: diagonal_, num_rows: N, num_cols: M.Value, k: k); return new NDArray(tensor); } public NDArray linspace(T start, T stop, int num = 50, bool endpoint = true, bool retstep = false, - NumpyDType dtype = NumpyDType.Double, int axis = 0) + TF_DataType dtype = TF_DataType.TF_DOUBLE, int axis = 0) { - var start_tensor = array_ops.constant(start, dtype: dtype.as_tf_dtype()); - var stop_tensor = array_ops.constant(stop, dtype: dtype.as_tf_dtype()); + var start_tensor = array_ops.constant(start, dtype: dtype); + var stop_tensor = array_ops.constant(stop, dtype: dtype); var num_tensor = array_ops.constant(num); // var step_tensor = array_ops.constant(np.nan); Tensor result = null; diff --git a/src/TensorFlowNET.Core/Numpy/IMemoryBlock.cs b/src/TensorFlowNET.Core/Numpy/IMemoryBlock.cs index 9ade984c..87c31a21 100644 --- a/src/TensorFlowNET.Core/Numpy/IMemoryBlock.cs +++ b/src/TensorFlowNET.Core/Numpy/IMemoryBlock.cs @@ -9,7 +9,7 @@ namespace Tensorflow.NumPy /// /// The size of a single item stored in . /// - /// Equivalent to extension. + /// Equivalent to extension. int ItemLength { get; } /// @@ -30,8 +30,8 @@ namespace Tensorflow.NumPy long BytesLength { get; } /// - /// The of the type stored inside this memory block. + /// The of the type stored inside this memory block. /// - NumpyDType TypeCode { get; } + TF_DataType TypeCode { get; } } } diff --git a/src/TensorFlowNET.Core/Numpy/InfoOf.cs b/src/TensorFlowNET.Core/Numpy/InfoOf.cs index feee3a97..5286b56d 100644 --- a/src/TensorFlowNET.Core/Numpy/InfoOf.cs +++ b/src/TensorFlowNET.Core/Numpy/InfoOf.cs @@ -8,60 +8,14 @@ namespace Tensorflow.NumPy public class InfoOf { public static readonly int Size; - public static readonly NumpyDType NPTypeCode; + public static readonly TF_DataType NPTypeCode; public static readonly T Zero; public static readonly T MaxValue; public static readonly T MinValue; static InfoOf() { - NPTypeCode = typeof(T).GetTypeCode(); - - switch (NPTypeCode) - { - case NumpyDType.Boolean: - Size = 1; - break; - case NumpyDType.Char: - Size = 2; - break; - case NumpyDType.Byte: - Size = 1; - break; - case NumpyDType.Int16: - Size = 2; - break; - case NumpyDType.UInt16: - Size = 2; - break; - case NumpyDType.Int32: - Size = 4; - break; - case NumpyDType.UInt32: - Size = 4; - break; - case NumpyDType.Int64: - Size = 8; - break; - case NumpyDType.UInt64: - Size = 8; - break; - case NumpyDType.Single: - Size = 4; - break; - case NumpyDType.Double: - Size = 8; - break; - case NumpyDType.Decimal: - Size = 16; - break; - case NumpyDType.String: - break; - case NumpyDType.Complex: - default: - Size = Marshal.SizeOf(); - break; - } + Size = NPTypeCode.get_datatype_size(); } } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs index ea9d9d69..43fdde55 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs @@ -18,10 +18,8 @@ namespace Tensorflow.NumPy public NDArray(Array value, Shape? shape = null) => _tensor = new EagerTensor(value, shape); - public NDArray(Shape shape, NumpyDType dtype = NumpyDType.Float) - { - Initialize(shape, dtype: dtype); - } + public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) + => _tensor = new EagerTensor(shape, dtype: dtype); public NDArray(Tensor value, Shape? shape = null) { @@ -45,11 +43,5 @@ namespace Tensorflow.NumPy _ => throw new NotImplementedException("") }; } - - void Initialize(Shape shape, NumpyDType dtype = NumpyDType.Float) - { - // _tensor = tf.zeros(shape, dtype: dtype.as_tf_dtype()); - _tensor = new EagerTensor(shape, dtype: dtype.as_tf_dtype()); - } } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index 72ab9475..5b493e32 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -9,7 +9,7 @@ namespace Tensorflow.NumPy public partial class NDArray { Tensor _tensor; - public NumpyDType dtype => _tensor.dtype.as_numpy_typecode(); + public TF_DataType dtype => _tensor.dtype; public ulong size => _tensor.size; public ulong dtypesize => _tensor.itemsize; public int ndim => _tensor.NDims; @@ -40,7 +40,7 @@ namespace Tensorflow.NumPy public T MoveNext() => throw new NotImplementedException(""); public NDArray reshape(Shape newshape) => new NDArray(_tensor, newshape); public NDArray astype(Type type) => throw new NotImplementedException(""); - public NDArray astype(NumpyDType type) => throw new NotImplementedException(""); + public NDArray astype(TF_DataType type) => throw new NotImplementedException(""); public bool array_equal(NDArray rhs) => throw new NotImplementedException(""); public NDArray ravel() => throw new NotImplementedException(""); public void shuffle(NDArray nd) => throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs index 2e8a8021..7fd02f8e 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs @@ -24,29 +24,29 @@ namespace Tensorflow.NumPy where T : unmanaged => new NDArray(tf.range(start, limit: end, delta: step)); - public static NDArray empty(Shape shape, NumpyDType dtype = NumpyDType.Double) - => new NDArray(tf.zeros(shape, dtype: dtype.as_tf_dtype())); + public static NDArray empty(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) + => new NDArray(tf.zeros(shape, dtype: dtype)); - public static NDArray eye(int N, int? M = null, int k = 0, NumpyDType dtype = NumpyDType.Double) + public static NDArray eye(int N, int? M = null, int k = 0, TF_DataType dtype = TF_DataType.TF_DOUBLE) => tf.numpy.eye(N, M: M, k: k, dtype: dtype); public static NDArray full(Shape shape, T fill_value) => new NDArray(tf.fill(tf.constant(shape), fill_value)); public static NDArray linspace(T start, T stop, int num = 50, bool endpoint = true, bool retstep = false, - NumpyDType dtype = NumpyDType.Double, int axis = 0) where T : unmanaged + TF_DataType dtype = TF_DataType.TF_DOUBLE, int axis = 0) where T : unmanaged => tf.numpy.linspace(start, stop, num: num, endpoint: endpoint, retstep: retstep, dtype: dtype, axis: axis); public static (NDArray, NDArray) meshgrid(T x, T y, bool copy = true, bool sparse = false) => tf.numpy.meshgrid(new[] { x, y }, copy: copy, sparse: sparse); - public static NDArray ones(Shape shape, NumpyDType dtype = NumpyDType.Double) - => new NDArray(tf.ones(shape, dtype: dtype.as_tf_dtype())); + public static NDArray ones(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) + => new NDArray(tf.ones(shape, dtype: dtype)); public static NDArray ones_like(NDArray a, Type dtype = null) => throw new NotImplementedException(""); - public static NDArray zeros(Shape shape, NumpyDType dtype = NumpyDType.Double) - => new NDArray(tf.zeros(shape, dtype: dtype.as_tf_dtype())); + public static NDArray zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) + => new NDArray(tf.zeros(shape, dtype: dtype)); } } diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.cs b/src/TensorFlowNET.Core/Numpy/Numpy.cs index 14cc69ba..dc8489e7 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.cs @@ -15,46 +15,26 @@ namespace Tensorflow.NumPy public static readonly Slice newaxis = new Slice(null, null, 1) { IsNewAxis = true }; // https://docs.scipy.org/doc/numpy-1.16.0/user/basics.types.html - public static readonly Type bool_ = typeof(bool); - public static readonly Type bool8 = bool_; - public static readonly Type @bool = bool_; - - public static readonly Type @char = typeof(char); - - public static readonly Type @byte = typeof(byte); - public static readonly Type uint8 = typeof(byte); - public static readonly Type ubyte = uint8; - - - public static readonly Type int16 = typeof(short); - - public static readonly Type uint16 = typeof(ushort); - - public static readonly Type int32 = typeof(int); - - public static readonly Type uint32 = typeof(uint); - - public static readonly Type int_ = typeof(long); - public static readonly Type int64 = int_; - public static readonly Type intp = int_; //TODO! IntPtr? - public static readonly Type int0 = int_; - - public static readonly Type uint64 = typeof(ulong); - public static readonly Type uint0 = uint64; - public static readonly Type @uint = uint64; - - public static readonly Type float32 = typeof(float); - - public static readonly Type float_ = typeof(double); - public static readonly Type float64 = float_; - public static readonly Type @double = float_; - - public static readonly Type complex_ = typeof(Complex); - public static readonly Type complex128 = complex_; - public static readonly Type complex64 = complex_; - public static readonly Type @decimal = typeof(decimal); - - public static Type chars => throw new NotSupportedException("Please use char with extra dimension."); + #region data type + public static readonly TF_DataType @bool = TF_DataType.TF_BOOL; + public static readonly TF_DataType @char = TF_DataType.TF_INT8; + public static readonly TF_DataType @byte = TF_DataType.TF_INT8; + public static readonly TF_DataType uint8 = TF_DataType.TF_UINT8; + public static readonly TF_DataType ubyte = TF_DataType.TF_UINT8; + public static readonly TF_DataType int16 = TF_DataType.TF_INT16; + public static readonly TF_DataType uint16 = TF_DataType.TF_UINT16; + public static readonly TF_DataType int32 = TF_DataType.TF_INT32; + public static readonly TF_DataType uint32 = TF_DataType.TF_UINT32; + public static readonly TF_DataType int64 = TF_DataType.TF_INT64; + public static readonly TF_DataType uint64 = TF_DataType.TF_UINT64; + public static readonly TF_DataType float32 = TF_DataType.TF_FLOAT; + public static readonly TF_DataType float64 = TF_DataType.TF_DOUBLE; + public static readonly TF_DataType @double = TF_DataType.TF_DOUBLE; + public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE; + public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX; + public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64; + public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128; + #endregion public static double nan => double.NaN; public static double NAN => double.NaN; @@ -70,8 +50,6 @@ namespace Tensorflow.NumPy public static double Infinity => double.PositiveInfinity; public static double infinity => double.PositiveInfinity; - - public static bool array_equal(NDArray a, NDArray b) => throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/Numpy/NumpyDType.cs b/src/TensorFlowNET.Core/Numpy/NumpyDType.cs deleted file mode 100644 index c933d6a8..00000000 --- a/src/TensorFlowNET.Core/Numpy/NumpyDType.cs +++ /dev/null @@ -1,90 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Numerics; -using System.Text; - -namespace Tensorflow.NumPy -{ - /// - /// Represents all available types in numpy. - /// - /// The int values of the enum are a copy of excluding types not available in numpy. - public enum NumpyDType - { - /// A null reference. - Empty = 0, - - /// A simple type representing Boolean values of true or false. - Boolean = 3, - - /// An integral type representing unsigned 16-bit integers with values between 0 and 65535. The set of possible values for the type corresponds to the Unicode character set. - Char = 4, - - /// An integral type representing unsigned 8-bit integers with values between 0 and 255. - Byte = 6, - - /// An integral type representing signed 16-bit integers with values between -32768 and 32767. - Int16 = 7, - - /// An integral type representing unsigned 16-bit integers with values between 0 and 65535. - UInt16 = 8, - - /// An integral type representing signed 32-bit integers with values between -2147483648 and 2147483647. - Int32 = 9, - - /// An integral type representing unsigned 32-bit integers with values between 0 and 4294967295. - UInt32 = 10, // 0x0000000A - - /// An integral type representing signed 64-bit integers with values between -9223372036854775808 and 9223372036854775807. - Int64 = 11, // 0x0000000B - - /// An integral type representing unsigned 64-bit integers with values between 0 and 18446744073709551615. - UInt64 = 12, // 0x0000000C - - /// A floating point type representing values ranging from approximately 1.5 x 10 -45 to 3.4 x 10 38 with a precision of 7 digits. - Single = 13, // 0x0000000D - Float = 13, // 0x0000000D - - /// A floating point type representing values ranging from approximately 5.0 x 10 -324 to 1.7 x 10 308 with a precision of 15-16 digits. - Double = 14, // 0x0000000E - - /// A simple type representing values ranging from 1.0 x 10 -28 to approximately 7.9 x 10 28 with 28-29 significant digits. - Decimal = 15, // 0x0000000F - - /// A sealed class type representing Unicode character strings. - String = 18, // 0x00000012 - - Complex = 128, //0x00000080 - } - - public static class NTTypeCodeExtension - { - public static NumpyDType GetTypeCode(this Type type) - { - // ReSharper disable once PossibleNullReferenceException - while (type.IsArray) - type = type.GetElementType(); - - var tc = Type.GetTypeCode(type); - if (tc == TypeCode.Object) - { - if (type == typeof(Complex)) - { - return NumpyDType.Complex; - } - - return NumpyDType.Empty; - } - - try - { - return (NumpyDType)(int)tc; - } - catch (InvalidCastException) - { - return NumpyDType.Empty; - } - } - } - -} diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index df229252..be10541e 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -224,7 +224,7 @@ namespace Tensorflow dtype = t.dtype.as_base_dtype(); break; case NDArray t: - dtype = t.dtype.as_tf_dtype(); + dtype = t.dtype; break; } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 2459ae1a..3289c938 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -133,7 +133,7 @@ namespace Tensorflow } public unsafe Tensor(NDArray nd) - => _handle = TF_NewTensor(nd.shape, nd.dtype.as_tf_dtype(), nd.data.ToPointer()); + => _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); #region scala public Tensor(bool value) => _handle = TF_NewTensor(value); diff --git a/src/TensorFlowNET.Core/Tensors/TensorConverter.cs b/src/TensorFlowNET.Core/Tensors/TensorConverter.cs deleted file mode 100644 index 8521f8eb..00000000 --- a/src/TensorFlowNET.Core/Tensors/TensorConverter.cs +++ /dev/null @@ -1,225 +0,0 @@ -using Tensorflow.NumPy; -using System; -using System.Threading.Tasks; -using Tensorflow.Util; - -namespace Tensorflow -{ - /// - /// Provides various methods to conversion between types and . - /// - public static class TensorConverter - { - /// - /// Convert given to . - /// - /// The ndarray to convert, can be regular, jagged or multi-dim array. - /// Convert to given before inserting it into a . - /// - public static Tensor ToTensor(NDArray nd, TF_DataType? astype = null) - { - // return new Tensor(astype == null ? nd : nd.astype(astype.Value.as_numpy_typecode(), false)); - throw new NotImplementedException(""); - } - - /// - /// Convert given to . - /// - /// The ndarray to convert. - /// Convert to given before inserting it into a . - /// - public static Tensor ToTensor(NDArray nd, NumpyDType? astype = null) - { - // return new Tensor(astype == null ? nd : nd.astype(astype.Value, false)); - throw new NotImplementedException(""); - } - - /// - /// Convert given to . - /// - /// The array to convert, can be regular, jagged or multi-dim array. - /// Convert to given before inserting it into a . - /// - public static Tensor ToTensor(Array array, TF_DataType? astype = null) - { - if (array == null) throw new ArgumentNullException(nameof(array)); - var arrtype = array.ResolveElementType(); - - var astype_type = astype?.as_system_dtype() ?? arrtype; - if (astype_type == arrtype) - { - //no conversion required - if (astype == TF_DataType.TF_STRING) - { - throw new NotSupportedException(); //TODO! when string is fully implemented. - } - - if (astype == TF_DataType.TF_INT8) - { - // if (array.Rank != 1 || array.GetType().GetElementType()?.IsArray == true) //is multidim or jagged - // array = Arrays.Flatten(array); - - return new Tensor((sbyte[])array); - } - - //is multidim or jagged, if so - use NDArrays constructor as it records shape. - if (array.Rank != 1 || array.GetType().GetElementType().IsArray) - return new Tensor(array, array.GetShape()); - - switch (arrtype.GetTypeCode()) - { - case NumpyDType.Boolean: return new Tensor((bool[])array); - case NumpyDType.Byte: return new Tensor((byte[])array); - case NumpyDType.Int16: return new Tensor((short[])array); - case NumpyDType.UInt16: return new Tensor((ushort[])array); - case NumpyDType.Int32: return new Tensor((int[])array); - case NumpyDType.UInt32: return new Tensor((uint[])array); - case NumpyDType.Int64: return new Tensor((long[])array); - case NumpyDType.UInt64: return new Tensor((ulong[])array); - // case NPTypeCode.Char: return new Tensor((char[])array); - case NumpyDType.Double: return new Tensor((double[])array); - case NumpyDType.Single: return new Tensor((float[])array); - default: - throw new NotSupportedException(); - } - } - else - { - //conversion is required. - //by this point astype is not null. - - //flatten if required - /*if (array.Rank != 1 || array.GetType().GetElementType()?.IsArray == true) //is multidim or jagged - array = Arrays.Flatten(array); - - try - { - return ToTensor( - ArrayConvert.To(array, astype.Value.as_numpy_typecode()), - null - ); - } - catch (NotSupportedException) - { - //handle dtypes not supported by ArrayConvert - var ret = Array.CreateInstance(astype_type, array.LongLength); - Parallel.For(0, ret.LongLength, i => ret.SetValue(Convert.ChangeType(array.GetValue(i), astype_type), i)); - return ToTensor(ret, null); - }*/ - throw new NotImplementedException(""); - } - } - - /// - /// Convert given to . - /// - /// The constant scalar to convert - /// Convert to given before inserting it into a . - /// - public static Tensor ToTensor(T constant, TF_DataType? astype = null) where T : unmanaged - { - //was conversion requested? - if (astype == null) - { - //No conversion required - var constantType = typeof(T).as_tf_dtype(); - if (constantType == TF_DataType.TF_INT8) - return new Tensor((sbyte)(object)constant); - - if (constantType == TF_DataType.TF_STRING) - return new Tensor((string)(object)constant); - - /*switch (InfoOf.NPTypeCode) - { - case NPTypeCode.Boolean: return new Tensor((bool)(object)constant); - case NPTypeCode.Byte: return new Tensor((byte)(object)constant); - case NPTypeCode.Int16: return new Tensor((short)(object)constant); - case NPTypeCode.UInt16: return new Tensor((ushort)(object)constant); - case NPTypeCode.Int32: return new Tensor((int)(object)constant); - case NPTypeCode.UInt32: return new Tensor((uint)(object)constant); - case NPTypeCode.Int64: return new Tensor((long)(object)constant); - case NPTypeCode.UInt64: return new Tensor((ulong)(object)constant); - // case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant)); - case NPTypeCode.Double: return new Tensor((double)(object)constant); - case NPTypeCode.Single: return new Tensor((float)(object)constant); - default: - throw new NotSupportedException(); - }*/ - throw new NotImplementedException(""); - } - - //conversion required - - /*if (astype == TF_DataType.TF_INT8) - return new Tensor(Converts.ToSByte(constant)); - - if (astype == TF_DataType.TF_STRING) - return new Tensor(Converts.ToString(constant)); - - var astype_np = astype?.as_numpy_typecode(); - - - switch (astype_np) - { - case NPTypeCode.Boolean: return new Tensor(Converts.ToBoolean(constant)); - case NPTypeCode.Byte: return new Tensor(Converts.ToByte(constant)); - case NPTypeCode.Int16: return new Tensor(Converts.ToInt16(constant)); - case NPTypeCode.UInt16: return new Tensor(Converts.ToUInt16(constant)); - case NPTypeCode.Int32: return new Tensor(Converts.ToInt32(constant)); - case NPTypeCode.UInt32: return new Tensor(Converts.ToUInt32(constant)); - case NPTypeCode.Int64: return new Tensor(Converts.ToInt64(constant)); - case NPTypeCode.UInt64: return new Tensor(Converts.ToUInt64(constant)); - case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant)); - case NPTypeCode.Double: return new Tensor(Converts.ToDouble(constant)); - case NPTypeCode.Single: return new Tensor(Converts.ToSingle(constant)); - default: - throw new NotSupportedException(); - }*/ - throw new NotImplementedException(""); - - } - - /// - /// Convert given to . - /// - /// The constant scalar to convert - /// Convert to given before inserting it into a . - /// - public static Tensor ToTensor(string constant, TF_DataType? astype = null) - { - /*switch (astype) - { - //was conversion requested? - case null: - case TF_DataType.TF_STRING: - return new Tensor(constant); - //conversion required - case TF_DataType.TF_INT8: - return new Tensor(Converts.ToSByte(constant)); - default: - { - var astype_np = astype?.as_numpy_typecode(); - - switch (astype_np) - { - case NPTypeCode.Boolean: return new Tensor(Converts.ToBoolean(constant)); - case NPTypeCode.Byte: return new Tensor(Converts.ToByte(constant)); - case NPTypeCode.Int16: return new Tensor(Converts.ToInt16(constant)); - case NPTypeCode.UInt16: return new Tensor(Converts.ToUInt16(constant)); - case NPTypeCode.Int32: return new Tensor(Converts.ToInt32(constant)); - case NPTypeCode.UInt32: return new Tensor(Converts.ToUInt32(constant)); - case NPTypeCode.Int64: return new Tensor(Converts.ToInt64(constant)); - case NPTypeCode.UInt64: return new Tensor(Converts.ToUInt64(constant)); - case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant)); - case NPTypeCode.Double: return new Tensor(Converts.ToDouble(constant)); - case NPTypeCode.Single: return new Tensor(Converts.ToSingle(constant)); - default: - throw new NotSupportedException(); - } - } - }*/ - throw new NotImplementedException(""); - } - - } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index f934de22..574bffc1 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -148,7 +148,7 @@ namespace Tensorflow } else if (dtype != TF_DataType.DtInvalid && value is NDArray nd && - nd.dtype.as_tf_dtype() != dtype) + nd.dtype != dtype) { value = nd.astype(dtype.as_system_dtype()); } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index b811ba7f..a33f3fb8 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -79,46 +79,6 @@ namespace Tensorflow } } - /// - /// - /// - /// - /// - /// When has no equivalent - public static NumpyDType as_numpy_typecode(this TF_DataType type) - { - switch (type) - { - case TF_DataType.TF_BOOL: - return NumpyDType.Boolean; - case TF_DataType.TF_UINT8: - return NumpyDType.Byte; - case TF_DataType.TF_INT64: - return NumpyDType.Int64; - case TF_DataType.TF_INT32: - return NumpyDType.Int32; - case TF_DataType.TF_INT16: - return NumpyDType.Int16; - case TF_DataType.TF_UINT64: - return NumpyDType.UInt64; - case TF_DataType.TF_UINT32: - return NumpyDType.UInt32; - case TF_DataType.TF_UINT16: - return NumpyDType.UInt16; - case TF_DataType.TF_FLOAT: - return NumpyDType.Single; - case TF_DataType.TF_DOUBLE: - return NumpyDType.Double; - case TF_DataType.TF_STRING: - return NumpyDType.String; - case TF_DataType.TF_COMPLEX128: - case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX - return NumpyDType.Complex; - default: - throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); - } - } - /// /// /// @@ -369,15 +329,5 @@ namespace Tensorflow || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64; } - - public static TF_DataType as_tf_dtype(this NumpyDType type) - => type switch - { - NumpyDType.Int32 => TF_DataType.TF_INT32, - NumpyDType.Int64 => TF_DataType.TF_INT64, - NumpyDType.Float => TF_DataType.TF_FLOAT, - NumpyDType.Double => TF_DataType.TF_DOUBLE, - _ => TF_DataType.TF_UINT8 - }; } } diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs index b809b1c4..56b0d2a7 100644 --- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs +++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs @@ -44,7 +44,7 @@ namespace Tensorflow.Keras.Datasets var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt")); var x_train_string = new string[lines.Length]; - var y_train = np.zeros(new int[] { lines.Length }, NumpyDType.Int64); + var y_train = np.zeros(new int[] { lines.Length }, np.int64); for (int i = 0; i < lines.Length; i++) { y_train[i] = long.Parse(lines[i].Substring(0, 1)); @@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Datasets File.ReadAllLines(Path.Combine(dst, "imdb_test.txt")); var x_test_string = new string[lines.Length]; - var y_test = np.zeros(new int[] { lines.Length }, NumpyDType.Int64); + var y_test = np.zeros(new int[] { lines.Length }, np.int64); for (int i = 0; i < lines.Length; i++) { y_test[i] = long.Parse(lines[i].Substring(0, 1)); diff --git a/src/TensorFlowNET.Keras/Sequence.cs b/src/TensorFlowNET.Keras/Sequence.cs index b9036a93..9db34322 100644 --- a/src/TensorFlowNET.Keras/Sequence.cs +++ b/src/TensorFlowNET.Keras/Sequence.cs @@ -55,7 +55,7 @@ namespace Tensorflow.Keras value = 0f; var type = dtypes.tf_dtype_from_name(dtype); - var nd = new NDArray((length.Count(), maxlen.Value), dtype: type.as_numpy_typecode()); + var nd = new NDArray((length.Count(), maxlen.Value), dtype: type); for (int i = 0; i < nd.dims[0]; i++) { diff --git a/src/TensorFlowNET.Keras/Utils/np_utils.cs b/src/TensorFlowNET.Keras/Utils/np_utils.cs index 1ae0c9c4..8430bff0 100644 --- a/src/TensorFlowNET.Keras/Utils/np_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/np_utils.cs @@ -16,9 +16,9 @@ namespace Tensorflow.Keras.Utils /// public static NDArray to_categorical(NDArray y, int num_classes = -1, TF_DataType dtype = TF_DataType.TF_FLOAT) { - var y1 = y.astype(NumpyDType.Int32).ToArray(); + var y1 = y.astype(np.int32).ToArray(); // var input_shape = y.shape[..^1]; - var categorical = np.zeros(((int)y.size, num_classes), dtype: dtype.as_numpy_typecode()); + var categorical = np.zeros(((int)y.size, num_classes), dtype: dtype); // categorical[np.arange(y.size), y] = 1; for (ulong i = 0; i < y.size; i++) { diff --git a/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs b/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs index 031aeaf4..6b642cdd 100644 --- a/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs @@ -94,7 +94,7 @@ namespace TensorFlowNET.UnitTest var input = tf.placeholder(tf.float64, shape: new TensorShape(6)); var op = tf.reshape(input, new int[] { 2, 3 }); sess.run(tf.global_variables_initializer()); - var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NumpyDType.Single) + 0.1f)); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f)); ret.Should().BeOfType().And.BeShaped(2, 3).And.BeOfValuesApproximately(0.001d, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1); print(ret.dtype); @@ -108,7 +108,7 @@ namespace TensorFlowNET.UnitTest var input = tf.placeholder(tf.int64, shape: new TensorShape(6)); var op = tf.reshape(input, new int[] { 2, 3 }); sess.run(tf.global_variables_initializer()); - var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NumpyDType.Single) + 0.1f)); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f)); ret.Should().BeOfType().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); print(ret.dtype); @@ -122,7 +122,7 @@ namespace TensorFlowNET.UnitTest var input = tf.placeholder(tf.byte8, shape: new TensorShape(6)); var op = tf.reshape(input, new int[] { 2, 3 }); sess.run(tf.global_variables_initializer()); - var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NumpyDType.Single) + 0.1f)); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f)); ret.Should().BeOfType().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); print(ret.dtype); diff --git a/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs index 39e72880..ba7b3829 100644 --- a/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs +++ b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs @@ -204,12 +204,6 @@ namespace TensorFlowNET.UnitTest return new AndConstraint(this); } - public AndConstraint BeOfType(NumpyDType typeCode) - { - Subject.dtype.Should().Be(typeCode); - return new AndConstraint(this); - } - public AndConstraint BeOfType(Type typeCode) { Subject.dtype.Should().Be(typeCode); @@ -287,7 +281,7 @@ namespace TensorFlowNET.UnitTest switch (Subject.dtype) { - case NumpyDType.Boolean: + case TF_DataType.TF_BOOL: { var iter = Subject.AsIterator(); var hasnext = iter.HasNext; @@ -308,7 +302,7 @@ namespace TensorFlowNET.UnitTest break; } - case NumpyDType.Byte: + case TF_DataType.TF_INT8: { var iter = Subject.AsIterator(); /*var next = iter.MoveNext; @@ -330,7 +324,7 @@ namespace TensorFlowNET.UnitTest break; } - case NumpyDType.Int16: + case TF_DataType.TF_INT16: { var iter = Subject.AsIterator(); /*var next = iter.MoveNext; @@ -352,7 +346,7 @@ namespace TensorFlowNET.UnitTest break; } - case NumpyDType.UInt16: + case TF_DataType.TF_UINT16: { var iter = Subject.AsIterator(); /*var next = iter.MoveNext; @@ -374,7 +368,7 @@ namespace TensorFlowNET.UnitTest break; } - case NumpyDType.Int32: + case TF_DataType.TF_INT32: { var iter = Subject.AsIterator(); /*var next = iter.MoveNext; @@ -396,7 +390,7 @@ namespace TensorFlowNET.UnitTest break; } - case NumpyDType.UInt32: + case TF_DataType.TF_UINT32: { var iter = Subject.AsIterator(); /*var next = iter.MoveNext; @@ -418,7 +412,7 @@ namespace TensorFlowNET.UnitTest break; } - case NumpyDType.Int64: + case TF_DataType.TF_INT64: { var iter = Subject.AsIterator(); /*var next = iter.MoveNext; @@ -440,7 +434,7 @@ namespace TensorFlowNET.UnitTest break; } - case NumpyDType.UInt64: + case TF_DataType.TF_UINT64: { var iter = Subject.AsIterator(); /*var next = iter.MoveNext; @@ -462,7 +456,7 @@ namespace TensorFlowNET.UnitTest break; } - case NumpyDType.Char: + case TF_DataType.TF_UINT8: { var iter = Subject.AsIterator(); /*var next = iter.MoveNext; @@ -484,7 +478,7 @@ namespace TensorFlowNET.UnitTest break; } - case NumpyDType.Double: + case TF_DataType.TF_DOUBLE: { var iter = Subject.AsIterator(); /*var next = iter.MoveNext; @@ -506,7 +500,7 @@ namespace TensorFlowNET.UnitTest break; } - case NumpyDType.Single: + case TF_DataType.TF_FLOAT: { var iter = Subject.AsIterator(); /*var next = iter.MoveNext; @@ -527,29 +521,6 @@ namespace TensorFlowNET.UnitTest break; } - - case NumpyDType.Decimal: - { - var iter = Subject.AsIterator(); - /*var next = iter.MoveNext; - var hasnext = iter.HasNext; - for (int i = 0; i < values.Length; i++) - { - Execute.Assertion - .ForCondition(hasnext()) - .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); - - var expected = Convert.ToDecimal(values[i]); - var nextval = next(); - - Execute.Assertion - .ForCondition(expected == nextval) - .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Decimal).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); - }*/ - - break; - } - default: throw new NotSupportedException(); }