@@ -21,16 +21,12 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Numerics;
 using System.Text;
-using System.Threading.Tasks;
-using NumSharp.Backends;
-using NumSharp.Backends.Unmanaged;
 
 namespace Tensorflow
 {
     public class BaseSession : DisposableObject
     {
         protected Graph _graph;
-        protected SessionOptions _options;
         protected bool _opened;
         protected bool _closed;
         protected int _current_version;
@@ -39,13 +35,21 @@ namespace Tensorflow
         public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
         {
-            _graph = g ?? ops.get_default_graph();
+            _graph = g is null ? ops.get_default_graph() : g;
             _graph.as_default();
-            _target = Encoding.UTF8.GetBytes(target);
-            _options = opts = opts ?? new SessionOptions();
+            _target = UTF8Encoding.UTF8.GetBytes(target);
+
+            SessionOptions newOpts = null;
+            if (opts == null)
+                newOpts = new SessionOptions();
+
             var status = new Status();
-            _handle = c_api.TF_NewSession(_graph, opts, status);
+            _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status);
+
+            // dispose newOpts only if we created it here
+            if (opts == null)
+                newOpts.Dispose();
+
             status.Check(true);
         }
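The constructor change replaces the disposable `_options` field with a locally owned `SessionOptions`: options passed by the caller are borrowed, otherwise a temporary instance is created and disposed as soon as `TF_NewSession` has consumed it. A minimal stand-alone sketch of that "borrow or own" pattern (the `NativeOptions` type is a hypothetical stand-in, not part of the library):

```csharp
using System;

// Hypothetical stand-in for SessionOptions.
class NativeOptions : IDisposable
{
    public void Dispose() => Console.WriteLine("options disposed");
}

class SessionLike
{
    public SessionLike(NativeOptions opts = null)
    {
        NativeOptions owned = null;
        if (opts == null)
            owned = new NativeOptions();   // created here, so we own it

        var effective = opts ?? owned;     // borrowed from caller, or owned
        // ... hand `effective` to the native API here ...

        if (opts == null)
            owned.Dispose();               // dispose only what we created
    }
}
```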
@@ -67,19 +71,19 @@ namespace Tensorflow
         public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
         {
-            var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4}, feed_dict);
+            var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
             return (results[0], results[1], results[2], results[3]);
         }
 
         public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
         {
-            var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3}, feed_dict);
+            var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
             return (results[0], results[1], results[2]);
         }
 
         public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
         {
-            var results = _run(new object[] {fetches.Item1, fetches.Item2}, feed_dict);
+            var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
             return (results[0], results[1]);
         }
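These tuple overloads let a caller fetch several tensors in one call and destructure the results directly. A usage sketch, assuming the usual TensorFlow.NET `tf` static API (exact bootstrap directives depend on the library version):

```csharp
// Hypothetical usage sketch; graph-building helpers assumed from TensorFlow.NET.
var a = tf.constant(2f);
var b = tf.constant(3f);
var sum = a + b;
var prod = a * b;

using (var sess = tf.Session())
{
    // The 2-tuple overload returns a matching tuple of NDArrays.
    var (sumVal, prodVal) = sess.run((sum, prod));
}
```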
@@ -90,7 +94,8 @@ namespace Tensorflow
         public virtual NDArray[] run(object fetches, Hashtable feed_dict = null)
         {
-            var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
+            var feed_items = feed_dict == null ? new FeedItem[0] :
+                feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
             return _run(fetches, feed_items);
         }
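The `Hashtable` overload is purely an adapter; the conversion it performs can be sketched in isolation (a condensed restatement of the line above, not a new API):

```csharp
using System.Collections;
using System.Linq;

// Stand-alone sketch of the Hashtable -> FeedItem[] adaptation.
static FeedItem[] ToFeedItems(Hashtable feed_dict) =>
    feed_dict == null
        ? new FeedItem[0]
        : feed_dict.Keys.OfType<object>()
                   .Select(key => new FeedItem(key, feed_dict[key]))
                   .ToArray();
```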
@@ -99,15 +104,23 @@ namespace Tensorflow
             var feed_dict_tensor = new Dictionary<object, object>();
             var feed_map = new Dictionary<object, object>();
 
+            Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) =>
+            {
+                return new (object, object)[] { (item.Key, item.Value) };
+            };
+
             // Validate and process feed_dict.
-            if (feed_dict != null && feed_dict.Length > 0)
+            if (feed_dict != null)
             {
-                foreach (var subfeed in feed_dict)
+                foreach (var feed in feed_dict)
                 {
-                    var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
-                    //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
-                    feed_dict_tensor[subfeed_t] = subfeed.Value;
-                    feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
+                    foreach (var (subfeed, subfeed_val) in feed_fn(feed))
+                    {
+                        var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
+                        //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
+                        feed_dict_tensor[subfeed_t] = subfeed_val;
+                        feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
+                    }
                 }
             }
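The new `feed_fn` indirection flattens each `FeedItem` into `(key, value)` pairs. Today it yields exactly one pair per item, but the extra hop presumably mirrors the Python implementation, where composite feeds (sparse tensors, for example) expand into several subfeeds. A condensed sketch of the flattening step, using only names from the diff:

```csharp
// Condensed sketch of the feed-flattening loop above.
Func<FeedItem, IEnumerable<(object, object)>> feed_fn =
    item => new (object, object)[] { (item.Key, item.Value) };

foreach (var feed in feed_dict)
    foreach (var (subfeed, subfeed_val) in feed_fn(feed))
    {
        // resolve `subfeed` against the graph, then record `subfeed_val` for it
    }
```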
@@ -124,7 +137,7 @@ namespace Tensorflow
             // We only want to really perform the run if fetches or targets are provided,
             // or if the call is a partial run that specifies feeds.
-            var results = _do_run(final_targets.Select(x => (Operation) x).ToList(), final_fetches, feed_dict_tensor);
+            var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor);
 
             return fetch_handler.build_results(this, results);
         }
@@ -144,58 +157,84 @@ namespace Tensorflow
         /// </returns>
         private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
         {
-            var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count];
-            int i = 0;
-            foreach (var x in feed_dict)
+            var feeds = feed_dict.Select(x =>
             {
-                var tensor = (Tensor) x.Key;
-                switch (x.Value)
+                if (x.Key is Tensor tensor)
                 {
+                    switch (x.Value)
+                    {
#if _REGEN
-                    %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
-                    %foreach types%
-                    case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    %
+                        %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
+                        %foreach types%
+                        case #1 v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case #1[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        %
#else
-                    case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
+                        case sbyte v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case sbyte[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case byte v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case byte[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case short v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case short[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case ushort v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case ushort[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case int v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case int[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case uint v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case uint[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case long v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case long[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case ulong v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case ulong[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case float v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case float[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case double v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case double[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case Complex v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case Complex[] v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
#endif
-                    case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break;
-                    case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
-                    case Tensor v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break;
-                    case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break;
-                    default:
-                        throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
+                        case bool v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL));
+                        case string v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case IntPtr v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
+                        case Tensor v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
+                        case NDArray v:
+                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
+                        default:
+                            throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}");
+                    }
                 }
-            }
 
-            var fetches = new TF_Output[fetch_list.Count];
-            for (i = 0; i < fetch_list.Count; i++)
-                fetches[i] = fetch_list[i]._as_tf_output();
-
-            //var targets = target_list;
+                throw new NotImplementedException("_do_run.feed_dict");
+            }).ToArray();
+
+            var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
+            var targets = target_list;
 
             return _call_tf_sessionrun(feeds, fetches, target_list);
         }
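The long switch above is the entire feed-conversion story: each supported C# value is boxed into a `Tensor` and paired with the key tensor's `TF_Output`. On C# 8 or later the same dispatch could be condensed into a switch expression; a partial sketch, using only constructors and members that appear in the diff (a full version would enumerate every primitive and array case the switch covers):

```csharp
// Partial sketch of the value-to-Tensor dispatch (C# 8 switch expression).
static KeyValuePair<TF_Output, Tensor> ToFeed(Tensor key, object value)
    => new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), value switch
    {
        Tensor t   => t,
        NDArray nd => new Tensor(nd, key.dtype),
        bool b     => new Tensor((byte)(b ? 1 : 0), TF_DataType.TF_BOOL),
        int i      => new Tensor(i),
        float f    => new Tensor(f),
        double d   => new Tensor(d),
        string s   => new Tensor(s),
        _ => throw new NotImplementedException(
                 $"feed_dict data type {value?.GetType().Name ?? "<null>"}")
    });
```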
@@ -206,27 +245,27 @@ namespace Tensorflow
             _extend_graph();
             var status = new Status();
-            var fetch_len = fetch_list.Length;
-            var output_values = new IntPtr[fetch_len];
+            var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();
 
             c_api.TF_SessionRun(_handle,
                 run_options: null,
                 inputs: feed_dict.Select(f => f.Key).ToArray(),
-                input_values: feed_dict.Select(f => (IntPtr) f.Value).ToArray(),
+                input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
                 ninputs: feed_dict.Length,
                 outputs: fetch_list,
                 output_values: output_values,
-                noutputs: fetch_len,
-                target_opers: target_list.Select(f => (IntPtr) f).ToArray(),
+                noutputs: fetch_list.Length,
+                target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
                 ntargets: target_list.Count,
                 run_metadata: IntPtr.Zero,
                 status: status);
 
             status.Check(true);
 
-            var result = new NDArray[fetch_len];
+            var result = new NDArray[fetch_list.Length];
 
-            for (int i = 0; i < fetch_len; i++)
+            for (int i = 0; i < fetch_list.Length; i++)
                 result[i] = fetchValue(output_values[i]);
 
             for (int i = 0; i < feed_dict.Length; i++)
@@ -237,191 +276,109 @@ namespace Tensorflow
         private unsafe NDArray fetchValue(IntPtr output)
         {
-            NDArray ret;
-            using (var tensor = new Tensor(output))
-            {
-                var ndims = tensor.shape;
-                var srcAddress = c_api.TF_TensorData(output).ToInt64();
-
-                if (ndims.Length == 0)
-                {
-                    switch (tensor.dtype)
-                    {
-                        case TF_DataType.TF_BOOL:
-                            ret = NDArray.Scalar(*(bool*) srcAddress);
-                            break;
-                        case TF_DataType.TF_STRING:
-                            var bytes = tensor.BufferToArray();
-                            // offset has to start from 9/
-                            var str = Encoding.Default.GetString(bytes, 9, bytes[8]);
-                            ret = NDArray.FromString(str);
-                            break;
-                        case TF_DataType.TF_UINT8:
-                            ret = NDArray.Scalar(*(byte*) srcAddress);
-                            break;
-                        case TF_DataType.TF_INT16:
-                            ret = NDArray.Scalar(*(short*) srcAddress);
-                            break;
-                        case TF_DataType.TF_INT32:
-                            ret = NDArray.Scalar(*(int*) srcAddress);
-                            break;
-                        case TF_DataType.TF_INT64:
-                            ret = NDArray.Scalar(*(long*) srcAddress);
-                            break;
-                        case TF_DataType.TF_UINT16:
-                            ret = NDArray.Scalar(*(ushort*) srcAddress);
-                            break;
-                        case TF_DataType.TF_UINT32:
-                            ret = NDArray.Scalar(*(uint*) srcAddress);
-                            break;
-                        case TF_DataType.TF_UINT64:
-                            ret = NDArray.Scalar(*(ulong*) srcAddress);
-                            break;
-                        case TF_DataType.TF_FLOAT:
-                            ret = NDArray.Scalar(*(float*) srcAddress);
-                            break;
-                        case TF_DataType.TF_DOUBLE:
-                            ret = NDArray.Scalar(*(double*) srcAddress);
-                            break;
-                        default:
-                            throw new NotImplementedException("can't fetch output");
-                    }
-                } else
-                {
-                    //var size = (long) tensor.size;
-                    //var itemsize = (long) tensor.itemsize;
-                    var bytesize = (long) tensor.bytesize;
-                    var src = (void*) srcAddress;
-
-#if _REGEN
-                    #region Compute
-                    switch (tensor.dtype)
-                    {
-                        %foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")%
-                        case TF_DataType.#3:
-                        {
-                            ret = new NDArray(NPTypeCode.#1, ndims, false);
-                            System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize);
-                            break;
-                        }
-                        %
-                        case TF_DataType.TF_STRING:
-                        {
-                            ret = new NDArray(NPTypeCode.Char, Shape.Vector((int) size), false); //TODO! Eli: when numsharp supports long size, remove (int) cast.
-                            //var bytes = tensor.BufferToArray();
-                            //// wired, don't know why we have to start from offset 9.
-                            //// length in the begin
-                            //var str = Encoding.Default.GetString(bytes, 9, bytes[8]);
-                            //ret = np.array(str);
-                            //TODO! Eli: this has to be unit-tested.
-                            var len = sizeof(char) * size;
-                            var dst = ret.Unsafe.Address;
-                            System.Buffer.MemoryCopy((byte*) src + 8, dst, len, len);
-                            break;
-                        }
-                        default:
-                            throw new NotSupportedException();
-                    }
-                    #endregion
-#else
-                    #region Compute
-                    switch (tensor.dtype)
-                    {
-                        case TF_DataType.TF_BOOL:
-                        {
-                            ret = new NDArray(NPTypeCode.Boolean, ndims, false);
-                            System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
-                            break;
-                        }
-                        case TF_DataType.TF_UINT8:
-                        {
-                            ret = new NDArray(NPTypeCode.Byte, ndims, false);
-                            System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
-                            break;
-                        }
-                        case TF_DataType.TF_INT16:
-                        {
-                            ret = new NDArray(NPTypeCode.Int16, ndims, false);
-                            System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
-                            break;
-                        }
-                        case TF_DataType.TF_UINT16:
-                        {
-                            ret = new NDArray(NPTypeCode.UInt16, ndims, false);
-                            System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
-                            break;
-                        }
-                        case TF_DataType.TF_INT32:
-                        {
-                            ret = new NDArray(NPTypeCode.Int32, ndims, false);
-                            System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
-                            break;
-                        }
-                        case TF_DataType.TF_UINT32:
-                        {
-                            ret = new NDArray(NPTypeCode.UInt32, ndims, false);
-                            System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
-                            break;
-                        }
-                        case TF_DataType.TF_INT64:
-                        {
-                            ret = new NDArray(NPTypeCode.Int64, ndims, false);
-                            System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
-                            break;
-                        }
-                        case TF_DataType.TF_UINT64:
-                        {
-                            ret = new NDArray(NPTypeCode.UInt64, ndims, false);
-                            System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
-                            break;
-                        }
-                        case TF_DataType.TF_DOUBLE:
-                        {
-                            ret = new NDArray(NPTypeCode.Double, ndims, false);
-                            System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
-                            break;
-                        }
-                        case TF_DataType.TF_FLOAT:
-                        {
-                            ret = new NDArray(NPTypeCode.Single, ndims, false);
-                            System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
-                            break;
-                        }
-                        case TF_DataType.TF_STRING:
-                        {
-                            ret = new NDArray(NPTypeCode.Char, Shape.Vector((int) (bytesize - 8) / sizeof(char)), false); //TODO! Eli: when numsharp supports long size, remove (int) cast.
-                            //TODO! Eli: this has to be unit-tested.
-                            var len = bytesize - 8;
-                            var dst = ret.Unsafe.Address;
-                            System.Buffer.MemoryCopy((byte*) src + 8, dst, len, len);
-                            break;
-                        }
-                        default:
-                            throw new NotSupportedException();
-                    }
-                    #endregion
-#endif
-                }
-            }
-
-            return ret;
+            var tensor = new Tensor(output);
+            NDArray nd = null;
+            Type type = tensor.dtype.as_numpy_dtype();
+            var ndims = tensor.shape;
+            var offset = c_api.TF_TensorData(output);
+
+            if(ndims.Length == 0)
+            {
+                switch (tensor.dtype)
+                {
+                    case TF_DataType.TF_BOOL:
+                        nd = NDArray.Scalar(*(bool*)offset);
+                        break;
+                    case TF_DataType.TF_STRING:
+                        var bytes = tensor.BufferToArray();
+                        // the buffer opens with an 8-byte offset entry, then the
+                        // varint-encoded length, so the string data starts at offset 9
+                        var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
+                        nd = NDArray.FromString(str);
+                        break;
+                    case TF_DataType.TF_UINT8:
+                        nd = NDArray.Scalar(*(byte*)offset);
+                        break;
+                    case TF_DataType.TF_INT16:
+                        nd = NDArray.Scalar(*(short*)offset);
+                        break;
+                    case TF_DataType.TF_INT32:
+                        nd = NDArray.Scalar(*(int*)offset);
+                        break;
+                    case TF_DataType.TF_INT64:
+                        nd = NDArray.Scalar(*(long*)offset);
+                        break;
+                    case TF_DataType.TF_FLOAT:
+                        nd = NDArray.Scalar(*(float*)offset);
+                        break;
+                    case TF_DataType.TF_DOUBLE:
+                        nd = NDArray.Scalar(*(double*)offset);
+                        break;
+                    default:
+                        throw new NotImplementedException("can't fetch output");
+                }
+            }
+            else
+            {
+                switch (tensor.dtype)
+                {
+                    case TF_DataType.TF_BOOL:
+                        var bools = new bool[tensor.size];
+                        for (ulong i = 0; i < tensor.size; i++)
+                            bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i));
+                        nd = np.array(bools).reshape(ndims);
+                        break;
+                    case TF_DataType.TF_STRING:
+                        var bytes = tensor.BufferToArray();
+                        // same TF_STRING layout as above: 8-byte offset entry,
+                        // varint length at byte 8, string data from offset 9
+                        var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
+                        nd = np.array(str);
+                        break;
+                    case TF_DataType.TF_UINT8:
+                        var _bytes = new byte[tensor.size];
+                        for (ulong i = 0; i < tensor.size; i++)
+                            _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
+                        nd = np.array(_bytes).reshape(ndims);
+                        break;
+                    case TF_DataType.TF_INT16:
+                        var shorts = new short[tensor.size];
+                        for (ulong i = 0; i < tensor.size; i++)
+                            shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i));
+                        nd = np.array(shorts).reshape(ndims);
+                        break;
+                    case TF_DataType.TF_INT32:
+                        var ints = new int[tensor.size];
+                        for (ulong i = 0; i < tensor.size; i++)
+                            ints[i] = *(int*)(offset + (int)(tensor.itemsize * i));
+                        nd = np.array(ints).reshape(ndims);
+                        break;
+                    case TF_DataType.TF_INT64:
+                        var longs = new long[tensor.size];
+                        for (ulong i = 0; i < tensor.size; i++)
+                            longs[i] = *(long*)(offset + (int)(tensor.itemsize * i));
+                        nd = np.array(longs).reshape(ndims);
+                        break;
+                    case TF_DataType.TF_FLOAT:
+                        var floats = new float[tensor.size];
+                        for (ulong i = 0; i < tensor.size; i++)
+                            floats[i] = *(float*)(offset + (int)(tensor.itemsize * i));
+                        nd = np.array(floats).reshape(ndims);
+                        break;
+                    case TF_DataType.TF_DOUBLE:
+                        var doubles = new double[tensor.size];
+                        for (ulong i = 0; i < tensor.size; i++)
+                            doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i));
+                        nd = np.array(doubles).reshape(ndims);
+                        break;
+                    default:
+                        throw new NotImplementedException("can't fetch output");
+                }
+            }
+
+            tensor.Dispose();
+            return nd;
         }
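The magic offset 9 in the string cases follows from the TF1 C API's string encoding: a scalar `TF_STRING` buffer begins with an 8-byte offset-table entry (zero for a scalar), followed by the varint-encoded length, then the bytes themselves. For strings shorter than 128 bytes the varint is a single byte, so `bytes[8]` is the length and the payload starts at index 9. A sketch of the decode, valid only under that length assumption:

```csharp
using System.Text;

// Decode a scalar TF_STRING buffer (TF1 layout); assumes the string is
// shorter than 128 bytes so the varint length fits in the byte at index 8.
static string DecodeScalarString(byte[] buffer)
{
    int length = buffer[8];                            // 1-byte varint length
    return Encoding.UTF8.GetString(buffer, 9, length); // data starts at offset 9
}
```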
 
         /// <summary>
@@ -435,7 +392,9 @@ namespace Tensorflow
         }
 
         private void _extend_graph()
-        { }
+        {
+
+        }
 
         public void close()
         {
@@ -449,8 +408,6 @@ namespace Tensorflow
                 c_api.TF_DeleteSession(handle, status);
                 status.Check(true);
             }
-
-            _options.Dispose();
         }
     }
}