From b252a07d29489d69b42e2a378ab586e8cc5974d6 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Sun, 1 Sep 2019 23:17:49 +0300 Subject: [PATCH] BaseSession: Added transparent dtype conversion to feed_dict --- .../Sessions/BaseSession.cs | 79 ++++++++++--------- 1 file changed, 43 insertions(+), 36 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 4066c1df..1701c625 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -107,7 +107,7 @@ namespace Tensorflow foreach (var subfeed in feed_dict) { var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); - //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used + //var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used feed_dict_tensor[subfeed_t] = subfeed.Value; //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); } @@ -150,58 +150,64 @@ namespace Tensorflow int i = 0; foreach (var x in feed_dict) { - if (x.Key is Tensor tensor) + if (x.Key is Tensor key) { switch (x.Value) { case Tensor v: - feeds[i++] = new KeyValuePair(tensor._as_tf_output(), v); + if (v.dtype != key.dtype) + throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); + feeds[i++] = new KeyValuePair(key._as_tf_output(), v); break; case NDArray v: - feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); + feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; case IntPtr v: - feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + var tensor = new Tensor(v); + if (tensor.dtype != key.dtype) + throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); + + feeds[i++] = new KeyValuePair(key._as_tf_output(), tensor); break; #if _REGEN // @formatter:off — disable formatter after this line - %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] - %foreach types% - case #1 v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case #1[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - % + %types = ["bool", "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] + %foreach types% + case #1 v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case #1[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + % // @formatter:on — enable formatter after this line #else // @formatter:off — disable formatter after this line - case sbyte v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case sbyte[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case byte v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case byte[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case short v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case short[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case ushort v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case ushort[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case int v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case int[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case uint v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case uint[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case long v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case long[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case ulong v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case ulong[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case float v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case float[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case double v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case double[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case Complex v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case Complex[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case bool v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case bool[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case sbyte v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case sbyte[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case byte v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case byte[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case short v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case short[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case ushort v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case ushort[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case int v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case int[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case uint v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case uint[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case long v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case long[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case ulong v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case ulong[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case float v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case float[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case double v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case double[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case Complex v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case Complex[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; // @formatter:on — enable formatter after this line #endif - case bool v: - feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); - break; + case string v: - feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; default: throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? ""}"); @@ -214,6 +220,7 @@ namespace Tensorflow return _call_tf_sessionrun(feeds, fetches, target_list); } + private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) { // Ensure any changes to the graph are reflected in the runtime.