Added transparent dtype conversion to feed_dicttags/v0.12
| @@ -107,7 +107,7 @@ namespace Tensorflow | |||||
| foreach (var subfeed in feed_dict) | foreach (var subfeed in feed_dict) | ||||
| { | { | ||||
| var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | ||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| //var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed.Value; | feed_dict_tensor[subfeed_t] = subfeed.Value; | ||||
| //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | ||||
| } | } | ||||
| @@ -150,58 +150,64 @@ namespace Tensorflow | |||||
| int i = 0; | int i = 0; | ||||
| foreach (var x in feed_dict) | foreach (var x in feed_dict) | ||||
| { | { | ||||
| if (x.Key is Tensor tensor) | |||||
| if (x.Key is Tensor key) | |||||
| { | { | ||||
| switch (x.Value) | switch (x.Value) | ||||
| { | { | ||||
| case Tensor v: | case Tensor v: | ||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
| if (v.dtype != key.dtype) | |||||
| throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); | |||||
| break; | break; | ||||
| case NDArray v: | case NDArray v: | ||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); | |||||
| break; | break; | ||||
| case IntPtr v: | case IntPtr v: | ||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| var tensor = new Tensor(v); | |||||
| if (tensor.dtype != key.dtype) | |||||
| throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor); | |||||
| break; | break; | ||||
| #if _REGEN | #if _REGEN | ||||
| // @formatter:off — disable formatter after this line | // @formatter:off — disable formatter after this line | ||||
| %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| % | |||||
| %types = ["bool", "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| % | |||||
| // @formatter:on — enable formatter after this line | // @formatter:on — enable formatter after this line | ||||
| #else | #else | ||||
| // @formatter:off — disable formatter after this line | // @formatter:off — disable formatter after this line | ||||
| case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case bool[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| // @formatter:on — enable formatter after this line | // @formatter:on — enable formatter after this line | ||||
| #endif | #endif | ||||
| case bool v: | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); | |||||
| break; | |||||
| case string v: | case string v: | ||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); | |||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | ||||
| @@ -214,6 +220,7 @@ namespace Tensorflow | |||||
| return _call_tf_sessionrun(feeds, fetches, target_list); | return _call_tf_sessionrun(feeds, fetches, target_list); | ||||
| } | } | ||||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | ||||
| { | { | ||||
| // Ensure any changes to the graph are reflected in the runtime. | // Ensure any changes to the graph are reflected in the runtime. | ||||
| @@ -0,0 +1,285 @@ | |||||
| using System; | |||||
| using System.Threading.Tasks; | |||||
| using NumSharp; | |||||
| using NumSharp.Backends; | |||||
| using NumSharp.Utilities; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Provides various methods to conversion between types and <see cref="Tensor"/>. | |||||
| /// </summary> | |||||
| public static class TensorConverter | |||||
| { | |||||
| /// <summary> | |||||
| /// Convert given <see cref="Array"/> to <see cref="Tensor"/>. | |||||
| /// </summary> | |||||
| /// <param name="nd">The ndarray to convert, can be regular, jagged or multi-dim array.</param> | |||||
| /// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param> | |||||
| /// <exception cref="NotSupportedException"></exception> | |||||
| public static Tensor ToTensor(NDArray nd, TF_DataType? astype = null) | |||||
| { | |||||
| return new Tensor(astype == null ? nd : nd.astype(astype.Value.as_numpy_typecode(), false)); | |||||
| } | |||||
| /// <summary> | |||||
| /// Convert given <see cref="NDArray"/> to <see cref="Tensor"/>. | |||||
| /// </summary> | |||||
| /// <param name="nd">The ndarray to convert.</param> | |||||
| /// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param> | |||||
| /// <exception cref="NotSupportedException"></exception> | |||||
| public static Tensor ToTensor(NDArray nd, NPTypeCode? astype = null) | |||||
| { | |||||
| return new Tensor(astype == null ? nd : nd.astype(astype.Value, false)); | |||||
| } | |||||
| /// <summary> | |||||
| /// Convert given <see cref="Array"/> to <see cref="Tensor"/>. | |||||
| /// </summary> | |||||
| /// <param name="array">The array to convert, can be regular, jagged or multi-dim array.</param> | |||||
| /// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param> | |||||
| /// <exception cref="NotSupportedException"></exception> | |||||
| public static Tensor ToTensor(Array array, TF_DataType? astype = null) | |||||
| { | |||||
| if (array == null) throw new ArgumentNullException(nameof(array)); | |||||
| var arrtype = array.ResolveElementType(); | |||||
| var astype_type = astype?.as_numpy_dtype() ?? arrtype; | |||||
| if (astype_type == arrtype) | |||||
| { | |||||
| //no conversion required | |||||
| if (astype == TF_DataType.TF_STRING) | |||||
| { | |||||
| throw new NotSupportedException(); //TODO! when string is fully implemented. | |||||
| } | |||||
| if (astype == TF_DataType.TF_INT8) | |||||
| { | |||||
| if (array.Rank != 1 || array.GetType().GetElementType()?.IsArray == true) //is multidim or jagged | |||||
| array = Arrays.Flatten(array); | |||||
| return new Tensor((sbyte[]) array); | |||||
| } | |||||
| //is multidim or jagged, if so - use NDArrays constructor as it records shape. | |||||
| if (array.Rank != 1 || array.GetType().GetElementType().IsArray) | |||||
| return new Tensor(new NDArray(array)); | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (arrtype) | |||||
| { | |||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||||
| case NPTypeCode.#1: return new Tensor((#2[])arr); | |||||
| % | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (arrtype.GetTypeCode()) | |||||
| { | |||||
| case NPTypeCode.Boolean: return new Tensor((bool[]) array); | |||||
| case NPTypeCode.Byte: return new Tensor((byte[]) array); | |||||
| case NPTypeCode.Int16: return new Tensor((short[]) array); | |||||
| case NPTypeCode.UInt16: return new Tensor((ushort[]) array); | |||||
| case NPTypeCode.Int32: return new Tensor((int[]) array); | |||||
| case NPTypeCode.UInt32: return new Tensor((uint[]) array); | |||||
| case NPTypeCode.Int64: return new Tensor((long[]) array); | |||||
| case NPTypeCode.UInt64: return new Tensor((ulong[]) array); | |||||
| case NPTypeCode.Char: return new Tensor((char[]) array); | |||||
| case NPTypeCode.Double: return new Tensor((double[]) array); | |||||
| case NPTypeCode.Single: return new Tensor((float[]) array); | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| } else | |||||
| { | |||||
| //conversion is required. | |||||
| //by this point astype is not null. | |||||
| //flatten if required | |||||
| if (array.Rank != 1 || array.GetType().GetElementType()?.IsArray == true) //is multidim or jagged | |||||
| array = Arrays.Flatten(array); | |||||
| try | |||||
| { | |||||
| return ToTensor( | |||||
| ArrayConvert.To(array, astype.Value.as_numpy_typecode()), | |||||
| null | |||||
| ); | |||||
| } catch (NotSupportedException) | |||||
| { | |||||
| //handle dtypes not supported by ArrayConvert | |||||
| var ret = Array.CreateInstance(astype_type, array.LongLength); | |||||
| Parallel.For(0, ret.LongLength, i => ret.SetValue(Convert.ChangeType(array.GetValue(i), astype_type), i)); | |||||
| return ToTensor(ret, null); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Convert given <see cref="Array"/> to <see cref="Tensor"/>. | |||||
| /// </summary> | |||||
| /// <param name="constant">The constant scalar to convert</param> | |||||
| /// <param name="astype">Convert <paramref name="constant"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param> | |||||
| /// <exception cref="NotSupportedException"></exception> | |||||
| public static Tensor ToTensor<T>(T constant, TF_DataType? astype = null) where T : unmanaged | |||||
| { | |||||
| //was conversion requested? | |||||
| if (astype == null) | |||||
| { | |||||
| //No conversion required | |||||
| var constantType = typeof(T).as_dtype(); | |||||
| if (constantType == TF_DataType.TF_INT8) | |||||
| return new Tensor((sbyte) (object) constant); | |||||
| if (constantType == TF_DataType.TF_STRING) | |||||
| return new Tensor((string) (object) constant); | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (InfoOf<T>.NPTypeCode) | |||||
| { | |||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||||
| case NPTypeCode.#1: return new Tensor((#2)(object)constant); | |||||
| % | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (InfoOf<T>.NPTypeCode) | |||||
| { | |||||
| case NPTypeCode.Boolean: return new Tensor((bool) (object) constant); | |||||
| case NPTypeCode.Byte: return new Tensor((byte) (object) constant); | |||||
| case NPTypeCode.Int16: return new Tensor((short) (object) constant); | |||||
| case NPTypeCode.UInt16: return new Tensor((ushort) (object) constant); | |||||
| case NPTypeCode.Int32: return new Tensor((int) (object) constant); | |||||
| case NPTypeCode.UInt32: return new Tensor((uint) (object) constant); | |||||
| case NPTypeCode.Int64: return new Tensor((long) (object) constant); | |||||
| case NPTypeCode.UInt64: return new Tensor((ulong) (object) constant); | |||||
| case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant)); | |||||
| case NPTypeCode.Double: return new Tensor((double) (object) constant); | |||||
| case NPTypeCode.Single: return new Tensor((float) (object) constant); | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| } | |||||
| //conversion required | |||||
| if (astype == TF_DataType.TF_INT8) | |||||
| return new Tensor(Converts.ToSByte(constant)); | |||||
| if (astype == TF_DataType.TF_STRING) | |||||
| return new Tensor(Converts.ToString(constant)); | |||||
| var astype_np = astype?.as_numpy_typecode(); | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (astype_np) | |||||
| { | |||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||||
| case NPTypeCode.#1: return new Tensor(Converts.To#1(constant)); | |||||
| % | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (astype_np) | |||||
| { | |||||
| case NPTypeCode.Boolean: return new Tensor(Converts.ToBoolean(constant)); | |||||
| case NPTypeCode.Byte: return new Tensor(Converts.ToByte(constant)); | |||||
| case NPTypeCode.Int16: return new Tensor(Converts.ToInt16(constant)); | |||||
| case NPTypeCode.UInt16: return new Tensor(Converts.ToUInt16(constant)); | |||||
| case NPTypeCode.Int32: return new Tensor(Converts.ToInt32(constant)); | |||||
| case NPTypeCode.UInt32: return new Tensor(Converts.ToUInt32(constant)); | |||||
| case NPTypeCode.Int64: return new Tensor(Converts.ToInt64(constant)); | |||||
| case NPTypeCode.UInt64: return new Tensor(Converts.ToUInt64(constant)); | |||||
| case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant)); | |||||
| case NPTypeCode.Double: return new Tensor(Converts.ToDouble(constant)); | |||||
| case NPTypeCode.Single: return new Tensor(Converts.ToSingle(constant)); | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| } | |||||
| /// <summary> | |||||
| /// Convert given <see cref="Array"/> to <see cref="Tensor"/>. | |||||
| /// </summary> | |||||
| /// <param name="constant">The constant scalar to convert</param> | |||||
| /// <param name="astype">Convert <paramref name="constant"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param> | |||||
| /// <exception cref="NotSupportedException"></exception> | |||||
| public static Tensor ToTensor(string constant, TF_DataType? astype = null) | |||||
| { | |||||
| switch (astype) | |||||
| { | |||||
| //was conversion requested? | |||||
| case null: | |||||
| case TF_DataType.TF_STRING: | |||||
| return new Tensor(constant); | |||||
| //conversion required | |||||
| case TF_DataType.TF_INT8: | |||||
| return new Tensor(Converts.ToSByte(constant)); | |||||
| default: | |||||
| { | |||||
| var astype_np = astype?.as_numpy_typecode(); | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (astype_np) | |||||
| { | |||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||||
| case NPTypeCode.#1: return new Tensor(Converts.To#1(constant)); | |||||
| % | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (astype_np) | |||||
| { | |||||
| case NPTypeCode.Boolean: return new Tensor(Converts.ToBoolean(constant)); | |||||
| case NPTypeCode.Byte: return new Tensor(Converts.ToByte(constant)); | |||||
| case NPTypeCode.Int16: return new Tensor(Converts.ToInt16(constant)); | |||||
| case NPTypeCode.UInt16: return new Tensor(Converts.ToUInt16(constant)); | |||||
| case NPTypeCode.Int32: return new Tensor(Converts.ToInt32(constant)); | |||||
| case NPTypeCode.UInt32: return new Tensor(Converts.ToUInt32(constant)); | |||||
| case NPTypeCode.Int64: return new Tensor(Converts.ToInt64(constant)); | |||||
| case NPTypeCode.UInt64: return new Tensor(Converts.ToUInt64(constant)); | |||||
| case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant)); | |||||
| case NPTypeCode.Double: return new Tensor(Converts.ToDouble(constant)); | |||||
| case NPTypeCode.Single: return new Tensor(Converts.ToSingle(constant)); | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -45,6 +45,8 @@ namespace Tensorflow | |||||
| return typeof(bool); | return typeof(bool); | ||||
| case TF_DataType.TF_UINT8: | case TF_DataType.TF_UINT8: | ||||
| return typeof(byte); | return typeof(byte); | ||||
| case TF_DataType.TF_INT8: | |||||
| return typeof(sbyte); | |||||
| case TF_DataType.TF_INT64: | case TF_DataType.TF_INT64: | ||||
| return typeof(long); | return typeof(long); | ||||
| case TF_DataType.TF_UINT64: | case TF_DataType.TF_UINT64: | ||||
| @@ -7,6 +7,7 @@ using System.Runtime.CompilerServices; | |||||
| using System.Text; | using System.Text; | ||||
| using FluentAssertions; | using FluentAssertions; | ||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using NumSharp.Backends; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -131,5 +132,61 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Autocast_Case1() | |||||
| { | |||||
| var sess = tf.Session().as_default(); | |||||
| var input = tf.placeholder(tf.float64, shape: new TensorShape(6)); | |||||
| var op = tf.reshape(input, new int[] {2, 3}); | |||||
| sess.run(tf.global_variables_initializer()); | |||||
| var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6))); | |||||
| ret.Should().BeOfType<double>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); | |||||
| print(ret.dtype); | |||||
| print(ret); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Autocast_Case2() | |||||
| { | |||||
| var sess = tf.Session().as_default(); | |||||
| var input = tf.placeholder(tf.float64, shape: new TensorShape(6)); | |||||
| var op = tf.reshape(input, new int[] {2, 3}); | |||||
| sess.run(tf.global_variables_initializer()); | |||||
| var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f)); | |||||
| ret.Should().BeOfType<double>().And.BeShaped(2, 3).And.BeOfValuesApproximately(0.001d, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1); | |||||
| print(ret.dtype); | |||||
| print(ret); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Autocast_Case3() | |||||
| { | |||||
| var sess = tf.Session().as_default(); | |||||
| var input = tf.placeholder(tf.int16, shape: new TensorShape(6)); | |||||
| var op = tf.reshape(input, new int[] {2, 3}); | |||||
| sess.run(tf.global_variables_initializer()); | |||||
| var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f)); | |||||
| ret.Should().BeOfType<short>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); | |||||
| print(ret.dtype); | |||||
| print(ret); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Autocast_Case4() | |||||
| { | |||||
| var sess = tf.Session().as_default(); | |||||
| var input = tf.placeholder(tf.@byte, shape: new TensorShape(6)); | |||||
| var op = tf.reshape(input, new int[] {2, 3}); | |||||
| sess.run(tf.global_variables_initializer()); | |||||
| var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f)); | |||||
| ret.Should().BeOfType<byte>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); | |||||
| print(ret.dtype); | |||||
| print(ret); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||