From 98c383ccd32055dcdd02853f2c28d86163164501 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Thu, 29 Aug 2019 12:26:46 +0300 Subject: [PATCH] BaseSession: Perf-op --- .../Sessions/BaseSession.cs | 152 +++++++----------- 1 file changed, 62 insertions(+), 90 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 58177df2..c6dd3fd1 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -39,15 +39,13 @@ namespace Tensorflow _graph.as_default(); _target = UTF8Encoding.UTF8.GetBytes(target); - SessionOptions newOpts = null; - if (opts == null) - newOpts = new SessionOptions(); + SessionOptions newOpts = opts ?? new SessionOptions(); var status = new Status(); _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); - // dispose newOpts + // dispose opts only if not provided externally. if (opts == null) newOpts.Dispose(); @@ -102,25 +100,17 @@ namespace Tensorflow private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) { var feed_dict_tensor = new Dictionary(); - var feed_map = new Dictionary(); - - Func> feed_fn = (item) => - { - return new (object, object)[] { (item.Key, item.Value) }; - }; + //var feed_map = new Dictionary(); // Validate and process feed_dict. if (feed_dict != null) { - foreach (var feed in feed_dict) + foreach (var subfeed in feed_dict) { - foreach (var (subfeed, subfeed_val) in feed_fn(feed)) - { - var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); - //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used - feed_dict_tensor[subfeed_t] = subfeed_val; - feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); - } + var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); + //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used + feed_dict_tensor[subfeed_t] = subfeed.Value; + //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); } } @@ -157,86 +147,71 @@ namespace Tensorflow /// private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) { - var feeds = feed_dict.Select(x => + + var feeds = new KeyValuePair[feed_dict.Count]; + var ignoreDispose = new bool[feed_dict.Count]; + int i = 0; + foreach (var x in feed_dict) { if (x.Key is Tensor tensor) { switch (x.Value) { + case Tensor v: ignoreDispose[i] = true; feeds[i++] = new KeyValuePair(tensor._as_tf_output(), v); break; + case NDArray v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break; #if _REGEN - %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] - %foreach types% - case #1 v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case #1[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - % + %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] + %foreach types% + case #1 v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case #1[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + % #else - case sbyte v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case sbyte[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case byte v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case byte[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case short v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case short[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ushort v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ushort[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case int v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case int[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case uint v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case uint[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case long v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case long[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ulong v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ulong[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case float v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case float[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case double v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case double[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Complex v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Complex[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case sbyte v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case sbyte[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case byte v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case byte[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case short v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case short[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case ushort v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case ushort[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case int v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case int[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case uint v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case uint[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case long v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case long[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case ulong v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case ulong[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case float v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case float[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case double v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case double[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case Complex v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case Complex[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; #endif - case bool v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); - case string v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case IntPtr v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Tensor v: - return new KeyValuePair(tensor._as_tf_output(), v); - case NDArray v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); + case bool v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break; + case string v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case IntPtr v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; default: - throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "")}"); + throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? ""}"); } } - throw new NotImplementedException("_do_run.feed_dict"); - }).ToArray(); - var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); - var targets = target_list; + } - return _call_tf_sessionrun(feeds, fetches, target_list); + var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); + //var targets = target_list; + try + { + return _call_tf_sessionrun(feeds, fetches, target_list); + } finally + { + for (var idx = 0; idx < feeds.Length; idx++) + { + if (ignoreDispose[idx]) + continue; + feeds[idx].Value.Dispose(); + } + } } private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) @@ -268,9 +243,6 @@ namespace Tensorflow for (int i = 0; i < fetch_list.Length; i++) result[i] = fetchValue(output_values[i]); - for (int i = 0; i < feed_dict.Length; i++) - feed_dict[i].Value.Dispose(); - return result; } @@ -280,7 +252,7 @@ namespace Tensorflow NDArray nd = null; Type type = tensor.dtype.as_numpy_dtype(); var ndims = tensor.shape; - var offset = c_api.TF_TensorData(output); + var offset = (byte*) c_api.TF_TensorData(output); if(ndims.Length == 0) {