|
|
|
@@ -43,7 +43,7 @@ namespace Tensorflow |
|
|
|
|
|
|
|
private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null) |
|
|
|
{ |
|
|
|
var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); |
|
|
|
var feed_dict_tensor = new Dictionary<object, object>(); |
|
|
|
|
|
|
|
if (feed_dict != null) |
|
|
|
feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); |
|
|
|
@@ -79,9 +79,30 @@ namespace Tensorflow |
|
|
|
/// name of an operation, the first Tensor output of that operation |
|
|
|
/// will be returned for that element. |
|
|
|
/// </returns> |
|
|
|
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) |
|
|
|
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) |
|
|
|
{ |
|
|
|
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); |
|
|
|
var feeds = feed_dict.Select(x => |
|
|
|
{ |
|
|
|
if(x.Key is Tensor tensor) |
|
|
|
{ |
|
|
|
switch (x.Value) |
|
|
|
{ |
|
|
|
case Tensor t1: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1); |
|
|
|
case NDArray nd: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd)); |
|
|
|
case int intVal: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal)); |
|
|
|
case float floatVal: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal)); |
|
|
|
case double doubleVal: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal)); |
|
|
|
default: |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
throw new NotImplementedException("_do_run.feed_dict"); |
|
|
|
}).ToArray(); |
|
|
|
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); |
|
|
|
var targets = target_list; |
|
|
|
|
|
|
|
|