| @@ -107,7 +107,7 @@ namespace Tensorflow | |||||
| foreach (var subfeed in feed_dict) | foreach (var subfeed in feed_dict) | ||||
| { | { | ||||
| var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | ||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| //var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed.Value; | feed_dict_tensor[subfeed_t] = subfeed.Value; | ||||
| //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | ||||
| } | } | ||||
| @@ -150,58 +150,64 @@ namespace Tensorflow | |||||
| int i = 0; | int i = 0; | ||||
| foreach (var x in feed_dict) | foreach (var x in feed_dict) | ||||
| { | { | ||||
| if (x.Key is Tensor tensor) | |||||
| if (x.Key is Tensor key) | |||||
| { | { | ||||
| switch (x.Value) | switch (x.Value) | ||||
| { | { | ||||
| case Tensor v: | case Tensor v: | ||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
| if (v.dtype != key.dtype) | |||||
| throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); | |||||
| break; | break; | ||||
| case NDArray v: | case NDArray v: | ||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); | |||||
| break; | break; | ||||
| case IntPtr v: | case IntPtr v: | ||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| var tensor = new Tensor(v); | |||||
| if (tensor.dtype != key.dtype) | |||||
| throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor); | |||||
| break; | break; | ||||
| #if _REGEN | #if _REGEN | ||||
| // @formatter:off — disable formatter after this line | // @formatter:off — disable formatter after this line | ||||
| %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| % | |||||
| %types = ["bool", "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| % | |||||
| // @formatter:on — enable formatter after this line | // @formatter:on — enable formatter after this line | ||||
| #else | #else | ||||
| // @formatter:off — disable formatter after this line | // @formatter:off — disable formatter after this line | ||||
| case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case bool[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||||
| // @formatter:on — enable formatter after this line | // @formatter:on — enable formatter after this line | ||||
| #endif | #endif | ||||
| case bool v: | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); | |||||
| break; | |||||
| case string v: | case string v: | ||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); | |||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | ||||
| @@ -214,6 +220,7 @@ namespace Tensorflow | |||||
| return _call_tf_sessionrun(feeds, fetches, target_list); | return _call_tf_sessionrun(feeds, fetches, target_list); | ||||
| } | } | ||||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | ||||
| { | { | ||||
| // Ensure any changes to the graph are reflected in the runtime. | // Ensure any changes to the graph are reflected in the runtime. | ||||