Browse Source

BaseSession: Reverted all changes

tags/v0.12
Eli Belash 6 years ago
parent
commit
6b2fa402c4
1 changed files with 211 additions and 254 deletions
  1. +211
    -254
      src/TensorFlowNET.Core/Sessions/BaseSession.cs

+ 211
- 254
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -21,16 +21,12 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Numerics; using System.Numerics;
using System.Text; using System.Text;
using System.Threading.Tasks;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;


namespace Tensorflow namespace Tensorflow
{ {
public class BaseSession : DisposableObject public class BaseSession : DisposableObject
{ {
protected Graph _graph; protected Graph _graph;
protected SessionOptions _options;
protected bool _opened; protected bool _opened;
protected bool _closed; protected bool _closed;
protected int _current_version; protected int _current_version;
@@ -39,13 +35,21 @@ namespace Tensorflow


public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
{ {
_graph = g ?? ops.get_default_graph();
_graph = g is null ? ops.get_default_graph() : g;
_graph.as_default(); _graph.as_default();
_target = Encoding.UTF8.GetBytes(target);
_options = opts = opts ?? new SessionOptions();
_target = UTF8Encoding.UTF8.GetBytes(target);

SessionOptions newOpts = null;
if (opts == null)
newOpts = new SessionOptions();

var status = new Status(); var status = new Status();


_handle = c_api.TF_NewSession(_graph, opts, status);
_handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status);

// dispose newOpts
if (opts == null)
newOpts.Dispose();


status.Check(true); status.Check(true);
} }
@@ -67,19 +71,19 @@ namespace Tensorflow


public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{ {
var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4}, feed_dict);
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
return (results[0], results[1], results[2], results[3]); return (results[0], results[1], results[2], results[3]);
} }


public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{ {
var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3}, feed_dict);
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
return (results[0], results[1], results[2]); return (results[0], results[1], results[2]);
} }


public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{ {
var results = _run(new object[] {fetches.Item1, fetches.Item2}, feed_dict);
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
return (results[0], results[1]); return (results[0], results[1]);
} }


@@ -90,7 +94,8 @@ namespace Tensorflow


public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) public virtual NDArray[] run(object fetches, Hashtable feed_dict = null)
{ {
var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
var feed_items = feed_dict == null ? new FeedItem[0] :
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
return _run(fetches, feed_items); return _run(fetches, feed_items);
} }


@@ -99,15 +104,23 @@ namespace Tensorflow
var feed_dict_tensor = new Dictionary<object, object>(); var feed_dict_tensor = new Dictionary<object, object>();
var feed_map = new Dictionary<object, object>(); var feed_map = new Dictionary<object, object>();


Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) =>
{
return new (object, object)[] { (item.Key, item.Value) };
};

// Validate and process feed_dict. // Validate and process feed_dict.
if (feed_dict != null && feed_dict.Length > 0)
if (feed_dict != null)
{ {
foreach (var subfeed in feed_dict)
foreach (var feed in feed_dict)
{ {
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed.Value;
feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
foreach (var (subfeed, subfeed_val) in feed_fn(feed))
{
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed_val;
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
}
} }
} }


@@ -124,7 +137,7 @@ namespace Tensorflow


// We only want to really perform the run if fetches or targets are provided, // We only want to really perform the run if fetches or targets are provided,
// or if the call is a partial run that specifies feeds. // or if the call is a partial run that specifies feeds.
var results = _do_run(final_targets.Select(x => (Operation) x).ToList(), final_fetches, feed_dict_tensor);
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor);


return fetch_handler.build_results(this, results); return fetch_handler.build_results(this, results);
} }
@@ -144,58 +157,84 @@ namespace Tensorflow
/// </returns> /// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{ {
var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count];
int i = 0;
foreach (var x in feed_dict)
var feeds = feed_dict.Select(x =>
{ {
var tensor = (Tensor) x.Key;
switch (x.Value)
if (x.Key is Tensor tensor)
{ {
switch (x.Value)
{
#if _REGEN #if _REGEN
%types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
%
%types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case #1[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
%
#else #else
case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case sbyte v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case sbyte[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case byte v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case byte[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case short v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case short[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ushort v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ushort[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case int v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case int[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case uint v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case uint[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case long v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case long[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ulong v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ulong[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case float v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case float[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case double v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case double[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Complex v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Complex[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
#endif #endif
case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break;
case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case Tensor v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break;
case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break;
default:
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
case bool v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL));
case string v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case IntPtr v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Tensor v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
case NDArray v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
default:
throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}");
}
} }
}

var fetches = new TF_Output[fetch_list.Count];
for (i = 0; i < fetch_list.Count; i++)
fetches[i] = fetch_list[i]._as_tf_output();

//var targets = target_list;
throw new NotImplementedException("_do_run.feed_dict");
}).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
var targets = target_list;


return _call_tf_sessionrun(feeds, fetches, target_list); return _call_tf_sessionrun(feeds, fetches, target_list);
} }
@@ -206,27 +245,27 @@ namespace Tensorflow
_extend_graph(); _extend_graph();


var status = new Status(); var status = new Status();
var fetch_len = fetch_list.Length;
var output_values = new IntPtr[fetch_len];
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();


c_api.TF_SessionRun(_handle, c_api.TF_SessionRun(_handle,
run_options: null, run_options: null,
inputs: feed_dict.Select(f => f.Key).ToArray(), inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: feed_dict.Select(f => (IntPtr) f.Value).ToArray(),
input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
ninputs: feed_dict.Length, ninputs: feed_dict.Length,
outputs: fetch_list, outputs: fetch_list,
output_values: output_values, output_values: output_values,
noutputs: fetch_len,
target_opers: target_list.Select(f => (IntPtr) f).ToArray(),
noutputs: fetch_list.Length,
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
ntargets: target_list.Count, ntargets: target_list.Count,
run_metadata: IntPtr.Zero, run_metadata: IntPtr.Zero,
status: status); status: status);


status.Check(true); status.Check(true);


var result = new NDArray[fetch_len];
var result = new NDArray[fetch_list.Length];


for (int i = 0; i < fetch_len; i++)
for (int i = 0; i < fetch_list.Length; i++)
result[i] = fetchValue(output_values[i]); result[i] = fetchValue(output_values[i]);


for (int i = 0; i < feed_dict.Length; i++) for (int i = 0; i < feed_dict.Length; i++)
@@ -237,191 +276,109 @@ namespace Tensorflow


private unsafe NDArray fetchValue(IntPtr output) private unsafe NDArray fetchValue(IntPtr output)
{ {
NDArray ret;
using (var tensor = new Tensor(output))
{
var ndims = tensor.shape;
var srcAddress = c_api.TF_TensorData(output).ToInt64();
var tensor = new Tensor(output);
NDArray nd = null;
Type type = tensor.dtype.as_numpy_dtype();
var ndims = tensor.shape;
var offset = c_api.TF_TensorData(output);


if (ndims.Length == 0)
if(ndims.Length == 0)
{
switch (tensor.dtype)
{ {
switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
ret = NDArray.Scalar(*(bool*) srcAddress);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.BufferToArray();
// offset has to start from 9/
var str = Encoding.Default.GetString(bytes, 9, bytes[8]);
ret = NDArray.FromString(str);
break;
case TF_DataType.TF_UINT8:
ret = NDArray.Scalar(*(byte*) srcAddress);
break;
case TF_DataType.TF_INT16:
ret = NDArray.Scalar(*(short*) srcAddress);
break;
case TF_DataType.TF_INT32:
ret = NDArray.Scalar(*(int*) srcAddress);
break;
case TF_DataType.TF_INT64:
ret = NDArray.Scalar(*(long*) srcAddress);
break;
case TF_DataType.TF_UINT16:
ret = NDArray.Scalar(*(ushort*) srcAddress);
break;
case TF_DataType.TF_UINT32:
ret = NDArray.Scalar(*(uint*) srcAddress);
break;
case TF_DataType.TF_UINT64:
ret = NDArray.Scalar(*(ulong*) srcAddress);
break;
case TF_DataType.TF_FLOAT:
ret = NDArray.Scalar(*(float*) srcAddress);
break;
case TF_DataType.TF_DOUBLE:
ret = NDArray.Scalar(*(double*) srcAddress);
break;
default:
throw new NotImplementedException("can't fetch output");
}
} else
case TF_DataType.TF_BOOL:
nd = NDArray.Scalar(*(bool*)offset);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = NDArray.FromString(str);
break;
case TF_DataType.TF_UINT8:
nd = NDArray.Scalar(*(byte*)offset);
break;
case TF_DataType.TF_INT16:
nd = NDArray.Scalar(*(short*)offset);
break;
case TF_DataType.TF_INT32:
nd = NDArray.Scalar(*(int*)offset);
break;
case TF_DataType.TF_INT64:
nd = NDArray.Scalar(*(long*)offset);
break;
case TF_DataType.TF_FLOAT:
nd = NDArray.Scalar(*(float*)offset);
break;
case TF_DataType.TF_DOUBLE:
nd = NDArray.Scalar(*(double*)offset);
break;
default:
throw new NotImplementedException("can't fetch output");
}
}
else
{
switch (tensor.dtype)
{ {
//var size = (long) tensor.size;
//var itemsize = (long) tensor.itemsize;
var bytesize = (long) tensor.bytesize;
var src = (void*) srcAddress;

#if _REGEN
#region Compute
switch (tensor.dtype)
{
%foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")%
case TF_DataType.#3:
{
ret = new NDArray(NPTypeCode.#1, ndims, false);
System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize);
break;
}
%
case TF_DataType.TF_STRING:
{
ret = new NDArray(NPTypeCode.Char, Shape.Vector((int) size), false); //TODO! Eli: when numsharp supports long size, remove (int) cast.
//var bytes = tensor.BufferToArray();
//// wired, don't know why we have to start from offset 9.
//// length in the begin
//var str = Encoding.Default.GetString(bytes, 9, bytes[8]);
//ret = np.array(str);

//TODO! Eli: this has to be unit-tested.
var len = sizeof(char) * size;
var dst = ret.Unsafe.Address;
System.Buffer.MemoryCopy((byte*) src + 8, dst, len, len);
break;
}
default:
throw new NotSupportedException();
}
#endregion
#else

#region Compute

switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
{
ret = new NDArray(NPTypeCode.Boolean, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_UINT8:
{
ret = new NDArray(NPTypeCode.Byte, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_INT16:
{
ret = new NDArray(NPTypeCode.Int16, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_UINT16:
{
ret = new NDArray(NPTypeCode.UInt16, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_INT32:
{
ret = new NDArray(NPTypeCode.Int32, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_UINT32:
{
ret = new NDArray(NPTypeCode.UInt32, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_INT64:
{
ret = new NDArray(NPTypeCode.Int64, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_UINT64:
{
ret = new NDArray(NPTypeCode.UInt64, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_DOUBLE:
{
ret = new NDArray(NPTypeCode.Double, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_FLOAT:
{
ret = new NDArray(NPTypeCode.Single, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_STRING:
{
ret = new NDArray(NPTypeCode.Char, Shape.Vector((int) (bytesize - 8) / sizeof(char)), false); //TODO! Eli: when numsharp supports long size, remove (int) cast.

//TODO! Eli: this has to be unit-tested.
var len = bytesize - 8;
var dst = ret.Unsafe.Address;
System.Buffer.MemoryCopy((byte*) src + 8, dst, len, len);
break;
}

default:
throw new NotSupportedException();
}

#endregion

#endif
case TF_DataType.TF_BOOL:
var bools = new bool[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i));
nd = np.array(bools).reshape(ndims);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str);
break;
case TF_DataType.TF_UINT8:
var _bytes = new byte[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
_bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
nd = np.array(_bytes).reshape(ndims);
break;
case TF_DataType.TF_INT16:
var shorts = new short[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i));
nd = np.array(shorts).reshape(ndims);
break;
case TF_DataType.TF_INT32:
var ints = new int[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
ints[i] = *(int*)(offset + (int)(tensor.itemsize * i));
nd = np.array(ints).reshape(ndims);
break;
case TF_DataType.TF_INT64:
var longs = new long[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
longs[i] = *(long*)(offset + (int)(tensor.itemsize * i));
nd = np.array(longs).reshape(ndims);
break;
case TF_DataType.TF_FLOAT:
var floats = new float[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
floats[i] = *(float*)(offset + (int)(tensor.itemsize * i));
nd = np.array(floats).reshape(ndims);
break;
case TF_DataType.TF_DOUBLE:
var doubles = new double[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i));
nd = np.array(doubles).reshape(ndims);
break;
default:
throw new NotImplementedException("can't fetch output");
} }
} }
tensor.Dispose();


return ret;
return nd;
} }


/// <summary> /// <summary>
@@ -435,7 +392,9 @@ namespace Tensorflow
} }


private void _extend_graph() private void _extend_graph()
{ }
{

}


public void close() public void close()
{ {
@@ -449,8 +408,6 @@ namespace Tensorflow
c_api.TF_DeleteSession(handle, status); c_api.TF_DeleteSession(handle, status);
status.Check(true); status.Check(true);
} }

_options.Dispose();
} }
} }
} }

Loading…
Cancel
Save