| @@ -39,15 +39,13 @@ namespace Tensorflow | |||||
| _graph.as_default(); | _graph.as_default(); | ||||
| _target = UTF8Encoding.UTF8.GetBytes(target); | _target = UTF8Encoding.UTF8.GetBytes(target); | ||||
| SessionOptions newOpts = null; | |||||
| if (opts == null) | |||||
| newOpts = new SessionOptions(); | |||||
| SessionOptions newOpts = opts ?? new SessionOptions(); | |||||
| var status = new Status(); | var status = new Status(); | ||||
| _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | ||||
| // dispose newOpts | |||||
| // dispose opts only if not provided externally. | |||||
| if (opts == null) | if (opts == null) | ||||
| newOpts.Dispose(); | newOpts.Dispose(); | ||||
| @@ -102,25 +100,17 @@ namespace Tensorflow | |||||
| private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | ||||
| { | { | ||||
| var feed_dict_tensor = new Dictionary<object, object>(); | var feed_dict_tensor = new Dictionary<object, object>(); | ||||
| var feed_map = new Dictionary<object, object>(); | |||||
| Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) => | |||||
| { | |||||
| return new (object, object)[] { (item.Key, item.Value) }; | |||||
| }; | |||||
| //var feed_map = new Dictionary<object, object>(); | |||||
| // Validate and process feed_dict. | // Validate and process feed_dict. | ||||
| if (feed_dict != null) | if (feed_dict != null) | ||||
| { | { | ||||
| foreach (var feed in feed_dict) | |||||
| foreach (var subfeed in feed_dict) | |||||
| { | { | ||||
| foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | |||||
| { | |||||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | |||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed_val; | |||||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||||
| } | |||||
| var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | |||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed.Value; | |||||
| //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | |||||
| } | } | ||||
| } | } | ||||
| @@ -157,86 +147,71 @@ namespace Tensorflow | |||||
| /// </returns> | /// </returns> | ||||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | ||||
| { | { | ||||
| var feeds = feed_dict.Select(x => | |||||
| var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | |||||
| var ignoreDispose = new bool[feed_dict.Count]; | |||||
| int i = 0; | |||||
| foreach (var x in feed_dict) | |||||
| { | { | ||||
| if (x.Key is Tensor tensor) | if (x.Key is Tensor tensor) | ||||
| { | { | ||||
| switch (x.Value) | switch (x.Value) | ||||
| { | { | ||||
| case Tensor v: ignoreDispose[i] = true; feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break; | |||||
| case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break; | |||||
| #if _REGEN | #if _REGEN | ||||
| %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case #1[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| % | |||||
| %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| % | |||||
| #else | #else | ||||
| case sbyte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case sbyte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| #endif | #endif | ||||
| case bool v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); | |||||
| case string v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case IntPtr v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Tensor v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
| case NDArray v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
| case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break; | |||||
| case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| default: | default: | ||||
| throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); | |||||
| throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | |||||
| } | } | ||||
| } | } | ||||
| throw new NotImplementedException("_do_run.feed_dict"); | |||||
| }).ToArray(); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||||
| var targets = target_list; | |||||
| } | |||||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||||
| //var targets = target_list; | |||||
| try | |||||
| { | |||||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||||
| } finally | |||||
| { | |||||
| for (var idx = 0; idx < feeds.Length; idx++) | |||||
| { | |||||
| if (ignoreDispose[idx]) | |||||
| continue; | |||||
| feeds[idx].Value.Dispose(); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | ||||
| @@ -268,9 +243,6 @@ namespace Tensorflow | |||||
| for (int i = 0; i < fetch_list.Length; i++) | for (int i = 0; i < fetch_list.Length; i++) | ||||
| result[i] = fetchValue(output_values[i]); | result[i] = fetchValue(output_values[i]); | ||||
| for (int i = 0; i < feed_dict.Length; i++) | |||||
| feed_dict[i].Value.Dispose(); | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -280,7 +252,7 @@ namespace Tensorflow | |||||
| NDArray nd = null; | NDArray nd = null; | ||||
| Type type = tensor.dtype.as_numpy_dtype(); | Type type = tensor.dtype.as_numpy_dtype(); | ||||
| var ndims = tensor.shape; | var ndims = tensor.shape; | ||||
| var offset = c_api.TF_TensorData(output); | |||||
| var offset = (byte*) c_api.TF_TensorData(output); | |||||
| if(ndims.Length == 0) | if(ndims.Length == 0) | ||||
| { | { | ||||