Browse Source

BaseSession: Added transparent dtype conversion to feed_dict

tags/v0.12
Eli Belash 6 years ago
parent
commit
b252a07d29
1 changed files with 43 additions and 36 deletions
  1. +43
    -36
      src/TensorFlowNET.Core/Sessions/BaseSession.cs

+ 43
- 36
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -107,7 +107,7 @@ namespace Tensorflow
foreach (var subfeed in feed_dict)
{
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
//var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed.Value;
//feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
}
@@ -150,58 +150,64 @@ namespace Tensorflow
int i = 0;
foreach (var x in feed_dict)
{
if (x.Key is Tensor tensor)
if (x.Key is Tensor key)
{
switch (x.Value)
{
case Tensor v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
if (v.dtype != key.dtype)
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}");
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v);
break;
case NDArray v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype));
break;
case IntPtr v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
var tensor = new Tensor(v);
if (tensor.dtype != key.dtype)
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}");

feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor);
break;
#if _REGEN
// @formatter:off — disable formatter after this line
%types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
%
%types = ["bool", "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
%
// @formatter:on — enable formatter after this line
#else
// @formatter:off — disable formatter after this line
case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case bool[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
// @formatter:on — enable formatter after this line
#endif
case bool v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL));
break;

case string v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype));
break;
default:
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
@@ -214,6 +220,7 @@ namespace Tensorflow
return _call_tf_sessionrun(feeds, fetches, target_list);
}


private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
// Ensure any changes to the graph are reflected in the runtime.


Loading…
Cancel
Save