| @@ -21,16 +21,12 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Numerics; | using System.Numerics; | ||||
| using System.Text; | using System.Text; | ||||
| using System.Threading.Tasks; | |||||
| using NumSharp.Backends; | |||||
| using NumSharp.Backends.Unmanaged; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class BaseSession : DisposableObject | public class BaseSession : DisposableObject | ||||
| { | { | ||||
| protected Graph _graph; | protected Graph _graph; | ||||
| protected SessionOptions _options; | |||||
| protected bool _opened; | protected bool _opened; | ||||
| protected bool _closed; | protected bool _closed; | ||||
| protected int _current_version; | protected int _current_version; | ||||
| @@ -39,13 +35,21 @@ namespace Tensorflow | |||||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | ||||
| { | { | ||||
| _graph = g ?? ops.get_default_graph(); | |||||
| _graph = g is null ? ops.get_default_graph() : g; | |||||
| _graph.as_default(); | _graph.as_default(); | ||||
| _target = Encoding.UTF8.GetBytes(target); | |||||
| _options = opts = opts ?? new SessionOptions(); | |||||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||||
| SessionOptions newOpts = null; | |||||
| if (opts == null) | |||||
| newOpts = new SessionOptions(); | |||||
| var status = new Status(); | var status = new Status(); | ||||
| _handle = c_api.TF_NewSession(_graph, opts, status); | |||||
| _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||||
| // dispose newOpts | |||||
| if (opts == null) | |||||
| newOpts.Dispose(); | |||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| @@ -67,19 +71,19 @@ namespace Tensorflow | |||||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | ||||
| { | { | ||||
| var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4}, feed_dict); | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||||
| return (results[0], results[1], results[2], results[3]); | return (results[0], results[1], results[2], results[3]); | ||||
| } | } | ||||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | ||||
| { | { | ||||
| var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3}, feed_dict); | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||||
| return (results[0], results[1], results[2]); | return (results[0], results[1], results[2]); | ||||
| } | } | ||||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | ||||
| { | { | ||||
| var results = _run(new object[] {fetches.Item1, fetches.Item2}, feed_dict); | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||||
| return (results[0], results[1]); | return (results[0], results[1]); | ||||
| } | } | ||||
| @@ -90,7 +94,8 @@ namespace Tensorflow | |||||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | ||||
| { | { | ||||
| var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
| return _run(fetches, feed_items); | return _run(fetches, feed_items); | ||||
| } | } | ||||
| @@ -99,15 +104,23 @@ namespace Tensorflow | |||||
| var feed_dict_tensor = new Dictionary<object, object>(); | var feed_dict_tensor = new Dictionary<object, object>(); | ||||
| var feed_map = new Dictionary<object, object>(); | var feed_map = new Dictionary<object, object>(); | ||||
| Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) => | |||||
| { | |||||
| return new (object, object)[] { (item.Key, item.Value) }; | |||||
| }; | |||||
| // Validate and process feed_dict. | // Validate and process feed_dict. | ||||
| if (feed_dict != null && feed_dict.Length > 0) | |||||
| if (feed_dict != null) | |||||
| { | { | ||||
| foreach (var subfeed in feed_dict) | |||||
| foreach (var feed in feed_dict) | |||||
| { | { | ||||
| var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | |||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed.Value; | |||||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | |||||
| foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | |||||
| { | |||||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | |||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed_val; | |||||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -124,7 +137,7 @@ namespace Tensorflow | |||||
| // We only want to really perform the run if fetches or targets are provided, | // We only want to really perform the run if fetches or targets are provided, | ||||
| // or if the call is a partial run that specifies feeds. | // or if the call is a partial run that specifies feeds. | ||||
| var results = _do_run(final_targets.Select(x => (Operation) x).ToList(), final_fetches, feed_dict_tensor); | |||||
| var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||||
| return fetch_handler.build_results(this, results); | return fetch_handler.build_results(this, results); | ||||
| } | } | ||||
| @@ -144,58 +157,84 @@ namespace Tensorflow | |||||
| /// </returns> | /// </returns> | ||||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | ||||
| { | { | ||||
| var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | |||||
| int i = 0; | |||||
| foreach (var x in feed_dict) | |||||
| var feeds = feed_dict.Select(x => | |||||
| { | { | ||||
| var tensor = (Tensor) x.Key; | |||||
| switch (x.Value) | |||||
| if (x.Key is Tensor tensor) | |||||
| { | { | ||||
| switch (x.Value) | |||||
| { | |||||
| #if _REGEN | #if _REGEN | ||||
| %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| % | |||||
| %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case #1[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| % | |||||
| #else | #else | ||||
| case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case sbyte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case sbyte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| #endif | #endif | ||||
| case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break; | |||||
| case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case Tensor v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break; | |||||
| case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break; | |||||
| default: | |||||
| throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | |||||
| case bool v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); | |||||
| case string v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case IntPtr v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Tensor v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
| case NDArray v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
| default: | |||||
| throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| var fetches = new TF_Output[fetch_list.Count]; | |||||
| for (i = 0; i < fetch_list.Count; i++) | |||||
| fetches[i] = fetch_list[i]._as_tf_output(); | |||||
| //var targets = target_list; | |||||
| throw new NotImplementedException("_do_run.feed_dict"); | |||||
| }).ToArray(); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||||
| var targets = target_list; | |||||
| return _call_tf_sessionrun(feeds, fetches, target_list); | return _call_tf_sessionrun(feeds, fetches, target_list); | ||||
| } | } | ||||
| @@ -206,27 +245,27 @@ namespace Tensorflow | |||||
| _extend_graph(); | _extend_graph(); | ||||
| var status = new Status(); | var status = new Status(); | ||||
| var fetch_len = fetch_list.Length; | |||||
| var output_values = new IntPtr[fetch_len]; | |||||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||||
| c_api.TF_SessionRun(_handle, | c_api.TF_SessionRun(_handle, | ||||
| run_options: null, | run_options: null, | ||||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | inputs: feed_dict.Select(f => f.Key).ToArray(), | ||||
| input_values: feed_dict.Select(f => (IntPtr) f.Value).ToArray(), | |||||
| input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||||
| ninputs: feed_dict.Length, | ninputs: feed_dict.Length, | ||||
| outputs: fetch_list, | outputs: fetch_list, | ||||
| output_values: output_values, | output_values: output_values, | ||||
| noutputs: fetch_len, | |||||
| target_opers: target_list.Select(f => (IntPtr) f).ToArray(), | |||||
| noutputs: fetch_list.Length, | |||||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||||
| ntargets: target_list.Count, | ntargets: target_list.Count, | ||||
| run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
| status: status); | status: status); | ||||
| status.Check(true); | status.Check(true); | ||||
| var result = new NDArray[fetch_len]; | |||||
| var result = new NDArray[fetch_list.Length]; | |||||
| for (int i = 0; i < fetch_len; i++) | |||||
| for (int i = 0; i < fetch_list.Length; i++) | |||||
| result[i] = fetchValue(output_values[i]); | result[i] = fetchValue(output_values[i]); | ||||
| for (int i = 0; i < feed_dict.Length; i++) | for (int i = 0; i < feed_dict.Length; i++) | ||||
| @@ -237,191 +276,109 @@ namespace Tensorflow | |||||
| private unsafe NDArray fetchValue(IntPtr output) | private unsafe NDArray fetchValue(IntPtr output) | ||||
| { | { | ||||
| NDArray ret; | |||||
| using (var tensor = new Tensor(output)) | |||||
| { | |||||
| var ndims = tensor.shape; | |||||
| var srcAddress = c_api.TF_TensorData(output).ToInt64(); | |||||
| var tensor = new Tensor(output); | |||||
| NDArray nd = null; | |||||
| Type type = tensor.dtype.as_numpy_dtype(); | |||||
| var ndims = tensor.shape; | |||||
| var offset = c_api.TF_TensorData(output); | |||||
| if (ndims.Length == 0) | |||||
| if(ndims.Length == 0) | |||||
| { | |||||
| switch (tensor.dtype) | |||||
| { | { | ||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| ret = NDArray.Scalar(*(bool*) srcAddress); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // offset has to start from 9/ | |||||
| var str = Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| ret = NDArray.FromString(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| ret = NDArray.Scalar(*(byte*) srcAddress); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| ret = NDArray.Scalar(*(short*) srcAddress); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| ret = NDArray.Scalar(*(int*) srcAddress); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| ret = NDArray.Scalar(*(long*) srcAddress); | |||||
| break; | |||||
| case TF_DataType.TF_UINT16: | |||||
| ret = NDArray.Scalar(*(ushort*) srcAddress); | |||||
| break; | |||||
| case TF_DataType.TF_UINT32: | |||||
| ret = NDArray.Scalar(*(uint*) srcAddress); | |||||
| break; | |||||
| case TF_DataType.TF_UINT64: | |||||
| ret = NDArray.Scalar(*(ulong*) srcAddress); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| ret = NDArray.Scalar(*(float*) srcAddress); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| ret = NDArray.Scalar(*(double*) srcAddress); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | |||||
| } else | |||||
| case TF_DataType.TF_BOOL: | |||||
| nd = NDArray.Scalar(*(bool*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = NDArray.FromString(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| nd = NDArray.Scalar(*(byte*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| nd = NDArray.Scalar(*(short*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| nd = NDArray.Scalar(*(int*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| nd = NDArray.Scalar(*(long*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| nd = NDArray.Scalar(*(float*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| nd = NDArray.Scalar(*(double*)offset); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| switch (tensor.dtype) | |||||
| { | { | ||||
| //var size = (long) tensor.size; | |||||
| //var itemsize = (long) tensor.itemsize; | |||||
| var bytesize = (long) tensor.bytesize; | |||||
| var src = (void*) srcAddress; | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| %foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")% | |||||
| case TF_DataType.#3: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.#1, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| % | |||||
| case TF_DataType.TF_STRING: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Char, Shape.Vector((int) size), false); //TODO! Eli: when numsharp supports long size, remove (int) cast. | |||||
| //var bytes = tensor.BufferToArray(); | |||||
| //// wired, don't know why we have to start from offset 9. | |||||
| //// length in the begin | |||||
| //var str = Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| //ret = np.array(str); | |||||
| //TODO! Eli: this has to be unit-tested. | |||||
| var len = sizeof(char) * size; | |||||
| var dst = ret.Unsafe.Address; | |||||
| System.Buffer.MemoryCopy((byte*) src + 8, dst, len, len); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Boolean, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT8: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Byte, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_INT16: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Int16, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT16: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.UInt16, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_INT32: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Int32, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT32: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.UInt32, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_INT64: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Int64, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT64: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.UInt64, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Double, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_FLOAT: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Single, ndims, false); | |||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_STRING: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Char, Shape.Vector((int) (bytesize - 8) / sizeof(char)), false); //TODO! Eli: when numsharp supports long size, remove (int) cast. | |||||
| //TODO! Eli: this has to be unit-tested. | |||||
| var len = bytesize - 8; | |||||
| var dst = ret.Unsafe.Address; | |||||
| System.Buffer.MemoryCopy((byte*) src + 8, dst, len, len); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| case TF_DataType.TF_BOOL: | |||||
| var bools = new bool[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(bools).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = np.array(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| var _bytes = new byte[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(_bytes).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| var shorts = new short[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(shorts).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| var ints = new int[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(ints).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| var longs = new long[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(longs).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| var floats = new float[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(floats).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| var doubles = new double[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(doubles).reshape(ndims); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | } | ||||
| } | } | ||||
| tensor.Dispose(); | |||||
| return ret; | |||||
| return nd; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -435,7 +392,9 @@ namespace Tensorflow | |||||
| } | } | ||||
| private void _extend_graph() | private void _extend_graph() | ||||
| { } | |||||
| { | |||||
| } | |||||
| public void close() | public void close() | ||||
| { | { | ||||
| @@ -449,8 +408,6 @@ namespace Tensorflow | |||||
| c_api.TF_DeleteSession(handle, status); | c_api.TF_DeleteSession(handle, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| _options.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||