Browse Source

BaseSession.run: revamped validate and process feed_dict.

tags/v0.12
Eli Belash 6 years ago
parent
commit
b5d1021f07
1 changed files with 6 additions and 11 deletions
  1. +6
    -11
      src/TensorFlowNET.Core/Sessions/BaseSession.cs

+ 6
- 11
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -99,20 +99,15 @@ namespace Tensorflow
var feed_dict_tensor = new Dictionary<object, object>();
var feed_map = new Dictionary<object, object>();

Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) => { return new (object, object)[] {(item.Key, item.Value)}; };

// Validate and process feed_dict.
if (feed_dict != null)
if (feed_dict != null && feed_dict.Length > 0)
{
foreach (var feed in feed_dict)
foreach (var subfeed in feed_dict)
{
foreach (var (subfeed, subfeed_val) in feed_fn(feed))
{
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed_val;
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
}
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed.Value;
feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
}
}



Loading…
Cancel
Save