diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index f81e1bf9..58177df2 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -21,16 +21,12 @@ using System.Collections.Generic; using System.Linq; using System.Numerics; using System.Text; -using System.Threading.Tasks; -using NumSharp.Backends; -using NumSharp.Backends.Unmanaged; namespace Tensorflow { public class BaseSession : DisposableObject { protected Graph _graph; - protected SessionOptions _options; protected bool _opened; protected bool _closed; protected int _current_version; @@ -39,13 +35,21 @@ namespace Tensorflow public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) { - _graph = g ?? ops.get_default_graph(); + _graph = g is null ? ops.get_default_graph() : g; _graph.as_default(); - _target = Encoding.UTF8.GetBytes(target); - _options = opts = opts ?? new SessionOptions(); + _target = UTF8Encoding.UTF8.GetBytes(target); + + SessionOptions newOpts = null; + if (opts == null) + newOpts = new SessionOptions(); + var status = new Status(); - _handle = c_api.TF_NewSession(_graph, opts, status); + _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); + + // dispose newOpts + if (opts == null) + newOpts.Dispose(); status.Check(true); } @@ -67,19 +71,19 @@ namespace Tensorflow public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) { - var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4}, feed_dict); + var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); return (results[0], results[1], results[2], results[3]); } public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) { - var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3}, feed_dict); + var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); return (results[0], results[1], results[2]); } public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) { - var results = _run(new object[] {fetches.Item1, fetches.Item2}, feed_dict); + var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); return (results[0], results[1]); } @@ -90,7 +94,8 @@ namespace Tensorflow public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) { - var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); + var feed_items = feed_dict == null ? new FeedItem[0] : + feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); return _run(fetches, feed_items); } @@ -99,15 +104,23 @@ namespace Tensorflow var feed_dict_tensor = new Dictionary(); var feed_map = new Dictionary(); + Func> feed_fn = (item) => + { + return new (object, object)[] { (item.Key, item.Value) }; + }; + // Validate and process feed_dict. - if (feed_dict != null && feed_dict.Length > 0) + if (feed_dict != null) { - foreach (var subfeed in feed_dict) + foreach (var feed in feed_dict) { - var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); - //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used - feed_dict_tensor[subfeed_t] = subfeed.Value; - feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); + foreach (var (subfeed, subfeed_val) in feed_fn(feed)) + { + var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); + //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used + feed_dict_tensor[subfeed_t] = subfeed_val; + feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); + } } } @@ -124,7 +137,7 @@ namespace Tensorflow // We only want to really perform the run if fetches or targets are provided, // or if the call is a partial run that specifies feeds. - var results = _do_run(final_targets.Select(x => (Operation) x).ToList(), final_fetches, feed_dict_tensor); + var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); return fetch_handler.build_results(this, results); } @@ -144,58 +157,84 @@ namespace Tensorflow /// private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) { - var feeds = new KeyValuePair[feed_dict.Count]; - int i = 0; - foreach (var x in feed_dict) + var feeds = feed_dict.Select(x => { - var tensor = (Tensor) x.Key; - switch (x.Value) + if (x.Key is Tensor tensor) { + switch (x.Value) + { #if _REGEN - %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] - %foreach types% - case #1 v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case #1[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - % + %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] + %foreach types% + case #1 v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case #1[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + % #else - case sbyte v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case sbyte[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case byte v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case byte[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case short v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case short[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case ushort v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case ushort[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case int v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case int[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case uint v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case uint[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case long v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case long[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case ulong v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case ulong[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case float v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case float[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case double v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case double[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case Complex v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case Complex[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case sbyte v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case sbyte[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case byte v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case byte[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case short v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case short[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ushort v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ushort[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case int v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case int[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case uint v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case uint[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case long v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case long[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ulong v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ulong[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case float v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case float[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case double v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case double[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Complex v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Complex[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); #endif - case bool v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break; - case string v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case IntPtr v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case Tensor v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), v); break; - case NDArray v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break; - default: - throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? ""}"); + case bool v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); + case string v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case IntPtr v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Tensor v: + return new KeyValuePair(tensor._as_tf_output(), v); + case NDArray v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); + default: + throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "")}"); + } } - } - - var fetches = new TF_Output[fetch_list.Count]; - for (i = 0; i < fetch_list.Count; i++) - fetches[i] = fetch_list[i]._as_tf_output(); - - //var targets = target_list; + throw new NotImplementedException("_do_run.feed_dict"); + }).ToArray(); + var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); + var targets = target_list; return _call_tf_sessionrun(feeds, fetches, target_list); } @@ -206,27 +245,27 @@ namespace Tensorflow _extend_graph(); var status = new Status(); - var fetch_len = fetch_list.Length; - var output_values = new IntPtr[fetch_len]; + + var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); c_api.TF_SessionRun(_handle, run_options: null, inputs: feed_dict.Select(f => f.Key).ToArray(), - input_values: feed_dict.Select(f => (IntPtr) f.Value).ToArray(), + input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), ninputs: feed_dict.Length, outputs: fetch_list, output_values: output_values, - noutputs: fetch_len, - target_opers: target_list.Select(f => (IntPtr) f).ToArray(), + noutputs: fetch_list.Length, + target_opers: target_list.Select(f => (IntPtr)f).ToArray(), ntargets: target_list.Count, run_metadata: IntPtr.Zero, status: status); status.Check(true); - var result = new NDArray[fetch_len]; + var result = new NDArray[fetch_list.Length]; - for (int i = 0; i < fetch_len; i++) + for (int i = 0; i < fetch_list.Length; i++) result[i] = fetchValue(output_values[i]); for (int i = 0; i < feed_dict.Length; i++) @@ -237,191 +276,109 @@ namespace Tensorflow private unsafe NDArray fetchValue(IntPtr output) { - NDArray ret; - using (var tensor = new Tensor(output)) - { - var ndims = tensor.shape; - var srcAddress = c_api.TF_TensorData(output).ToInt64(); + var tensor = new Tensor(output); + NDArray nd = null; + Type type = tensor.dtype.as_numpy_dtype(); + var ndims = tensor.shape; + var offset = c_api.TF_TensorData(output); - if (ndims.Length == 0) + if(ndims.Length == 0) + { + switch (tensor.dtype) { - switch (tensor.dtype) - { - case TF_DataType.TF_BOOL: - ret = NDArray.Scalar(*(bool*) srcAddress); - break; - case TF_DataType.TF_STRING: - var bytes = tensor.BufferToArray(); - // offset has to start from 9/ - var str = Encoding.Default.GetString(bytes, 9, bytes[8]); - ret = NDArray.FromString(str); - break; - case TF_DataType.TF_UINT8: - ret = NDArray.Scalar(*(byte*) srcAddress); - break; - case TF_DataType.TF_INT16: - ret = NDArray.Scalar(*(short*) srcAddress); - break; - case TF_DataType.TF_INT32: - ret = NDArray.Scalar(*(int*) srcAddress); - break; - case TF_DataType.TF_INT64: - ret = NDArray.Scalar(*(long*) srcAddress); - break; - case TF_DataType.TF_UINT16: - ret = NDArray.Scalar(*(ushort*) srcAddress); - break; - case TF_DataType.TF_UINT32: - ret = NDArray.Scalar(*(uint*) srcAddress); - break; - case TF_DataType.TF_UINT64: - ret = NDArray.Scalar(*(ulong*) srcAddress); - break; - case TF_DataType.TF_FLOAT: - ret = NDArray.Scalar(*(float*) srcAddress); - break; - case TF_DataType.TF_DOUBLE: - ret = NDArray.Scalar(*(double*) srcAddress); - break; - default: - throw new NotImplementedException("can't fetch output"); - } - } else + case TF_DataType.TF_BOOL: + nd = NDArray.Scalar(*(bool*)offset); + break; + case TF_DataType.TF_STRING: + var bytes = tensor.BufferToArray(); + // wired, don't know why we have to start from offset 9. + // length in the begin + var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); + nd = NDArray.FromString(str); + break; + case TF_DataType.TF_UINT8: + nd = NDArray.Scalar(*(byte*)offset); + break; + case TF_DataType.TF_INT16: + nd = NDArray.Scalar(*(short*)offset); + break; + case TF_DataType.TF_INT32: + nd = NDArray.Scalar(*(int*)offset); + break; + case TF_DataType.TF_INT64: + nd = NDArray.Scalar(*(long*)offset); + break; + case TF_DataType.TF_FLOAT: + nd = NDArray.Scalar(*(float*)offset); + break; + case TF_DataType.TF_DOUBLE: + nd = NDArray.Scalar(*(double*)offset); + break; + default: + throw new NotImplementedException("can't fetch output"); + } + } + else + { + switch (tensor.dtype) { - //var size = (long) tensor.size; - //var itemsize = (long) tensor.itemsize; - var bytesize = (long) tensor.bytesize; - var src = (void*) srcAddress; - -#if _REGEN - #region Compute - switch (tensor.dtype) - { - %foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")% - case TF_DataType.#3: - { - ret = new NDArray(NPTypeCode.#1, ndims, false); - System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize); - break; - } - % - case TF_DataType.TF_STRING: - { - ret = new NDArray(NPTypeCode.Char, Shape.Vector((int) size), false); //TODO! Eli: when numsharp supports long size, remove (int) cast. - //var bytes = tensor.BufferToArray(); - //// wired, don't know why we have to start from offset 9. - //// length in the begin - //var str = Encoding.Default.GetString(bytes, 9, bytes[8]); - //ret = np.array(str); - - //TODO! Eli: this has to be unit-tested. - var len = sizeof(char) * size; - var dst = ret.Unsafe.Address; - System.Buffer.MemoryCopy((byte*) src + 8, dst, len, len); - break; - } - default: - throw new NotSupportedException(); - } - #endregion -#else - - #region Compute - - switch (tensor.dtype) - { - case TF_DataType.TF_BOOL: - { - ret = new NDArray(NPTypeCode.Boolean, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_UINT8: - { - ret = new NDArray(NPTypeCode.Byte, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_INT16: - { - ret = new NDArray(NPTypeCode.Int16, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_UINT16: - { - ret = new NDArray(NPTypeCode.UInt16, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_INT32: - { - ret = new NDArray(NPTypeCode.Int32, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_UINT32: - { - ret = new NDArray(NPTypeCode.UInt32, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_INT64: - { - ret = new NDArray(NPTypeCode.Int64, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_UINT64: - { - ret = new NDArray(NPTypeCode.UInt64, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_DOUBLE: - { - ret = new NDArray(NPTypeCode.Double, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_FLOAT: - { - ret = new NDArray(NPTypeCode.Single, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_STRING: - { - ret = new NDArray(NPTypeCode.Char, Shape.Vector((int) (bytesize - 8) / sizeof(char)), false); //TODO! Eli: when numsharp supports long size, remove (int) cast. - - //TODO! Eli: this has to be unit-tested. - var len = bytesize - 8; - var dst = ret.Unsafe.Address; - System.Buffer.MemoryCopy((byte*) src + 8, dst, len, len); - break; - } - - default: - throw new NotSupportedException(); - } - - #endregion - -#endif + case TF_DataType.TF_BOOL: + var bools = new bool[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(bools).reshape(ndims); + break; + case TF_DataType.TF_STRING: + var bytes = tensor.BufferToArray(); + // wired, don't know why we have to start from offset 9. + // length in the begin + var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); + nd = np.array(str); + break; + case TF_DataType.TF_UINT8: + var _bytes = new byte[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(_bytes).reshape(ndims); + break; + case TF_DataType.TF_INT16: + var shorts = new short[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(shorts).reshape(ndims); + break; + case TF_DataType.TF_INT32: + var ints = new int[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(ints).reshape(ndims); + break; + case TF_DataType.TF_INT64: + var longs = new long[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(longs).reshape(ndims); + break; + case TF_DataType.TF_FLOAT: + var floats = new float[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(floats).reshape(ndims); + break; + case TF_DataType.TF_DOUBLE: + var doubles = new double[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(doubles).reshape(ndims); + break; + default: + throw new NotImplementedException("can't fetch output"); } } + + tensor.Dispose(); - return ret; + return nd; } /// @@ -435,7 +392,9 @@ namespace Tensorflow } private void _extend_graph() - { } + { + + } public void close() { @@ -449,8 +408,6 @@ namespace Tensorflow c_api.TF_DeleteSession(handle, status); status.Check(true); } - - _options.Dispose(); } } } \ No newline at end of file