@@ -39,15 +39,13 @@ namespace Tensorflow
             _graph.as_default();
             _target = UTF8Encoding.UTF8.GetBytes(target);
-            SessionOptions newOpts = null;
-            if (opts == null)
-                newOpts = new SessionOptions();
+            SessionOptions newOpts = opts ?? new SessionOptions();
             var status = new Status();
             _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status);
-            // dispose newOpts
+            // dispose opts only if not provided externally.
             if (opts == null)
                 newOpts.Dispose();
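Note on the hunk above: the refactor keeps the rule that the session disposes a SessionOptions only when it created that instance itself. The following is a minimal, self-contained sketch of that ownership pattern, using a hypothetical Options type in place of SessionOptions; it illustrates the idea and is not TensorFlow.NET code.

using System;

// Hypothetical stand-in for SessionOptions, used only to illustrate the
// "dispose only what you created" ownership rule from the hunk above.
class Options : IDisposable
{
    public void Dispose() => Console.WriteLine("Options disposed");
}

class OwnershipSketch
{
    static void OpenSession(Options opts = null)
    {
        // Fall back to a local instance when the caller passed nothing.
        var effective = opts ?? new Options();
        bool ownsOptions = opts == null;   // we created it, so we must free it
        try
        {
            // ... hand `effective` to the native session-creation call ...
        }
        finally
        {
            if (ownsOptions)
                effective.Dispose();       // never dispose a caller-owned object
        }
    }

    static void Main()
    {
        OpenSession();                     // prints "Options disposed" once
        using (var mine = new Options())
            OpenSession(mine);             // the sketch leaves `mine` alone; `using` frees it
    }
}

Tracking ownership in a boolean (ownsOptions here) rather than re-testing opts == null at the dispose site is one way to keep the condition obvious if more fallback logic is added later.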
@@ -102,25 +100,17 @@ namespace Tensorflow
         private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
         {
             var feed_dict_tensor = new Dictionary<object, object>();
-            var feed_map = new Dictionary<object, object>();
-            Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) =>
-            {
-                return new (object, object)[] { (item.Key, item.Value) };
-            };
+            //var feed_map = new Dictionary<object, object>();
+            // Validate and process feed_dict.
             if (feed_dict != null)
             {
-                foreach (var feed in feed_dict)
+                foreach (var subfeed in feed_dict)
                 {
-                    foreach (var (subfeed, subfeed_val) in feed_fn(feed))
-                    {
-                        var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
-                        //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
-                        feed_dict_tensor[subfeed_t] = subfeed_val;
-                        feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
-                    }
+                    var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
+                    //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
+                    feed_dict_tensor[subfeed_t] = subfeed.Value;
+                    //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
                 }
             }
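The simplified loop above resolves each FeedItem key to a graph element and records its value directly, dropping the removed feed_fn indirection. Below is a rough standalone approximation of that flattening step; FeedItem here is a hypothetical stand-in with only Key/Value properties, and resolve plays the role of _graph.as_graph_element, so treat this as a sketch rather than the library's types.

using System;
using System.Collections.Generic;

// Hypothetical stand-in for TensorFlow.NET's FeedItem: the placeholder or
// tensor being fed (Key) paired with the value to feed into it (Value).
class FeedItem
{
    public object Key { get; set; }
    public object Value { get; set; }
}

static class FeedDictSketch
{
    // `resolve` approximates _graph.as_graph_element(...): it maps the
    // user-supplied key to the concrete graph tensor that receives the data.
    public static Dictionary<object, object> Build(FeedItem[] feed_dict, Func<object, object> resolve)
    {
        var feed_dict_tensor = new Dictionary<object, object>();
        if (feed_dict == null)
            return feed_dict_tensor;

        foreach (var subfeed in feed_dict)
        {
            var subfeed_t = resolve(subfeed.Key);
            feed_dict_tensor[subfeed_t] = subfeed.Value;
        }
        return feed_dict_tensor;
    }
}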
@@ -157,86 +147,71 @@ namespace Tensorflow
         /// </returns>
         private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
         {
-            var feeds = feed_dict.Select(x =>
+            var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count];
+            var ignoreDispose = new bool[feed_dict.Count];
+            int i = 0;
+            foreach (var x in feed_dict)
             {
                 if (x.Key is Tensor tensor)
                 {
                     switch (x.Value)
                     {
+                        case Tensor v: ignoreDispose[i] = true; feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break;
+                        case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break;
 #if _REGEN
-                        %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
-                        %foreach types%
-                        case #1 v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case #1[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        %
+                        %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
+                        %foreach types%
+                        case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        %
 #else
-                        case sbyte v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case sbyte[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case byte v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case byte[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case short v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case short[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case ushort v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case ushort[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case int v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case int[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case uint v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case uint[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case long v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case long[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case ulong v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case ulong[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case float v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case float[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case double v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case double[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case Complex v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case Complex[] v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
 #endif
-                        case bool v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL));
-                        case string v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case IntPtr v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
-                        case Tensor v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
-                        case NDArray v:
-                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
+                        case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break;
+                        case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
                         default:
-                            throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}");
+                            throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
                     }
                 }
-                throw new NotImplementedException("_do_run.feed_dict");
-            }).ToArray();
-            var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
-            var targets = target_list;
+            }
-            return _call_tf_sessionrun(feeds, fetches, target_list);
+            var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
+            //var targets = target_list;
+            try
+            {
+                return _call_tf_sessionrun(feeds, fetches, target_list);
+            } finally
+            {
+                for (var idx = 0; idx < feeds.Length; idx++)
+                {
+                    if (ignoreDispose[idx])
+                        continue;
+                    feeds[idx].Value.Dispose();
+                }
+            }
         }
         private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
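The rewritten _do_run builds the feeds array up front, records in ignoreDispose which entries wrap a caller-supplied Tensor, and frees only the temporaries it created, inside a finally block so they are released even when the run throws. The sketch below captures that ownership idea with a hypothetical NativeTensor type; it illustrates the pattern rather than the library's actual API.

using System;
using System.Collections.Generic;

// Hypothetical stand-in for a tensor that owns a native buffer.
class NativeTensor : IDisposable
{
    public bool SuppliedByCaller;                 // plays the role of ignoreDispose[idx]
    public void Dispose() => Console.WriteLine("buffer released");
}

class FeedDisposalSketch
{
    static void RunWithFeeds(IReadOnlyList<NativeTensor> feeds, Action run)
    {
        try
        {
            run();                                // stands in for _call_tf_sessionrun(...)
        }
        finally
        {
            // Free only the tensors this method created; leave caller-owned ones alone.
            foreach (var feed in feeds)
                if (!feed.SuppliedByCaller)
                    feed.Dispose();
        }
    }
}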
@@ -268,9 +243,6 @@ namespace Tensorflow
             for (int i = 0; i < fetch_list.Length; i++)
                 result[i] = fetchValue(output_values[i]);
-            for (int i = 0; i < feed_dict.Length; i++)
-                feed_dict[i].Value.Dispose();
             return result;
         }
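Removing the unconditional dispose loop here lines up with the new finally block in _do_run: each feed tensor now has a single owner that frees it, so disposing again in _call_tf_sessionrun would release caller-supplied tensors and double-dispose the temporaries. A toy illustration of that hazard follows, with a hypothetical Buffer type rather than library code.

using System;

// Toy disposable that reports when it is freed more than once.
class Buffer : IDisposable
{
    bool _freed;
    public void Dispose()
    {
        Console.WriteLine(_freed ? "double dispose!" : "freed");
        _freed = true;
    }
}

class SingleOwnerSketch
{
    static void Main()
    {
        var temp = new Buffer();
        temp.Dispose();   // the single owner (e.g. _do_run's finally) frees it
        temp.Dispose();   // a second dispose site, like the removed loop, would hit this
    }
}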
@@ -280,7 +252,7 @@ namespace Tensorflow
             NDArray nd = null;
             Type type = tensor.dtype.as_numpy_dtype();
             var ndims = tensor.shape;
-            var offset = c_api.TF_TensorData(output);
+            var offset = (byte*) c_api.TF_TensorData(output);
             if(ndims.Length == 0)
             {
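Casting the pointer returned by TF_TensorData to byte* instead of keeping it as an IntPtr allows plain pointer arithmetic, presumably so the fetch path can step through the native buffer when copying values out. A small self-contained illustration of that technique (unrelated to the TF API itself; compile with unsafe enabled):

using System;

class PointerSketch
{
    // A byte* cursor can be offset element by element, which is awkward with a raw IntPtr.
    static unsafe int[] ReadInts(IntPtr data, int count)
    {
        var result = new int[count];
        byte* offset = (byte*)data;          // same kind of cast as in the hunk above
        for (int i = 0; i < count; i++)
        {
            result[i] = *(int*)offset;       // reinterpret the current bytes as an int
            offset += sizeof(int);           // advance the cursor by one element
        }
        return result;
    }

    static unsafe void Main()
    {
        // Demo on a pinned managed array standing in for a native tensor buffer.
        int[] source = { 1, 2, 3, 4 };
        fixed (int* p = source)
        {
            var copy = ReadInts((IntPtr)p, source.Length);
            Console.WriteLine(string.Join(", ", copy));   // 1, 2, 3, 4
        }
    }
}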