Browse Source

BaseSession: Perf-op

tags/v0.12
Eli Belash 6 years ago
parent
commit
98c383ccd3
1 changed files with 62 additions and 90 deletions
  1. +62
    -90
      src/TensorFlowNET.Core/Sessions/BaseSession.cs

+ 62
- 90
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -39,15 +39,13 @@ namespace Tensorflow
_graph.as_default(); _graph.as_default();
_target = UTF8Encoding.UTF8.GetBytes(target); _target = UTF8Encoding.UTF8.GetBytes(target);


SessionOptions newOpts = null;
if (opts == null)
newOpts = new SessionOptions();
SessionOptions newOpts = opts ?? new SessionOptions();


var status = new Status(); var status = new Status();


_handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status);


// dispose newOpts
// dispose opts only if not provided externally.
if (opts == null) if (opts == null)
newOpts.Dispose(); newOpts.Dispose();


@@ -102,25 +100,17 @@ namespace Tensorflow
private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
{ {
var feed_dict_tensor = new Dictionary<object, object>(); var feed_dict_tensor = new Dictionary<object, object>();
var feed_map = new Dictionary<object, object>();

Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) =>
{
return new (object, object)[] { (item.Key, item.Value) };
};
//var feed_map = new Dictionary<object, object>();


// Validate and process feed_dict. // Validate and process feed_dict.
if (feed_dict != null) if (feed_dict != null)
{ {
foreach (var feed in feed_dict)
foreach (var subfeed in feed_dict)
{ {
foreach (var (subfeed, subfeed_val) in feed_fn(feed))
{
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed_val;
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
}
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed.Value;
//feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
} }
} }


@@ -157,86 +147,71 @@ namespace Tensorflow
/// </returns> /// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{ {
var feeds = feed_dict.Select(x =>

var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count];
var ignoreDispose = new bool[feed_dict.Count];
int i = 0;
foreach (var x in feed_dict)
{ {
if (x.Key is Tensor tensor) if (x.Key is Tensor tensor)
{ {
switch (x.Value) switch (x.Value)
{ {
case Tensor v: ignoreDispose[i] = true; feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break;
case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break;
#if _REGEN #if _REGEN
%types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case #1[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
%
%types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
%
#else #else
case sbyte v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case sbyte[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case byte v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case byte[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case short v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case short[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ushort v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ushort[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case int v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case int[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case uint v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case uint[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case long v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case long[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ulong v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ulong[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case float v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case float[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case double v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case double[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Complex v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Complex[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
#endif #endif
case bool v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL));
case string v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case IntPtr v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Tensor v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
case NDArray v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break;
case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
default: default:
throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}");
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
} }
} }
throw new NotImplementedException("_do_run.feed_dict");
}).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
var targets = target_list;
}


return _call_tf_sessionrun(feeds, fetches, target_list);
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
//var targets = target_list;
try
{
return _call_tf_sessionrun(feeds, fetches, target_list);
} finally
{
for (var idx = 0; idx < feeds.Length; idx++)
{
if (ignoreDispose[idx])
continue;
feeds[idx].Value.Dispose();
}
}
} }


private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
@@ -268,9 +243,6 @@ namespace Tensorflow
for (int i = 0; i < fetch_list.Length; i++) for (int i = 0; i < fetch_list.Length; i++)
result[i] = fetchValue(output_values[i]); result[i] = fetchValue(output_values[i]);


for (int i = 0; i < feed_dict.Length; i++)
feed_dict[i].Value.Dispose();

return result; return result;
} }


@@ -280,7 +252,7 @@ namespace Tensorflow
NDArray nd = null; NDArray nd = null;
Type type = tensor.dtype.as_numpy_dtype(); Type type = tensor.dtype.as_numpy_dtype();
var ndims = tensor.shape; var ndims = tensor.shape;
var offset = c_api.TF_TensorData(output);
var offset = (byte*) c_api.TF_TensorData(output);


if(ndims.Length == 0) if(ndims.Length == 0)
{ {


Loading…
Cancel
Save