From 5b49f73eda0ef3e5e4661c624a1a476591ec8456 Mon Sep 17 00:00:00 2001 From: Meinrad Recheis Date: Fri, 10 May 2019 22:32:53 +0200 Subject: [PATCH] session.run: added overload with Hashtable as feed dict --- src/TensorFlowNET.Core/Sessions/BaseSession.cs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 022d378c..587aca5b 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -1,5 +1,6 @@ using NumSharp; using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; @@ -40,6 +41,13 @@ namespace Tensorflow return _run(fetches, feed_dict); } + public virtual NDArray run(ITensorOrOperation[] fetches, Hashtable feed_dict = null) + { + var feed_items = feed_dict == null ? new FeedItem[0] : + feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); + return _run(fetches, feed_items); + } + private NDArray _run(object fetches, FeedItem[] feed_dict = null) { var feed_dict_tensor = new Dictionary(); @@ -89,6 +97,12 @@ namespace Tensorflow case byte[] val: feed_dict_tensor[subfeed_t] = (NDArray)val; break; + case bool val: + feed_dict_tensor[subfeed_t] = (NDArray) val; + break; + case bool[] val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; default: Console.WriteLine($"can't handle data type of subfeed_val"); throw new NotImplementedException("_run subfeed");