| @@ -1,5 +1,6 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| @@ -40,6 +41,13 @@ namespace Tensorflow | |||||
| return _run(fetches, feed_dict); | return _run(fetches, feed_dict); | ||||
| } | } | ||||
| public virtual NDArray run(ITensorOrOperation[] fetches, Hashtable feed_dict = null) | |||||
| { | |||||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
| return _run(fetches, feed_items); | |||||
| } | |||||
| private NDArray _run(object fetches, FeedItem[] feed_dict = null) | private NDArray _run(object fetches, FeedItem[] feed_dict = null) | ||||
| { | { | ||||
| var feed_dict_tensor = new Dictionary<object, object>(); | var feed_dict_tensor = new Dictionary<object, object>(); | ||||
| @@ -89,6 +97,12 @@ namespace Tensorflow | |||||
| case byte[] val: | case byte[] val: | ||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | feed_dict_tensor[subfeed_t] = (NDArray)val; | ||||
| break; | break; | ||||
| case bool val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray) val; | |||||
| break; | |||||
| case bool[] val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| default: | default: | ||||
| Console.WriteLine($"can't handle data type of subfeed_val"); | Console.WriteLine($"can't handle data type of subfeed_val"); | ||||
| throw new NotImplementedException("_run subfeed"); | throw new NotImplementedException("_run subfeed"); | ||||