| @@ -40,35 +40,35 @@ namespace Tensorflow | |||
| } | |||
| public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||
| public virtual object run(Tensor fetches, FeedDict feed_dict = null) | |||
| { | |||
| var result = _run(fetches, feed_dict); | |||
| return result; | |||
| } | |||
| private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||
| private unsafe object _run(Tensor fetches, FeedDict feed_dict = null) | |||
| { | |||
| var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | |||
| var feed_dict_tensor = new FeedDict(); | |||
| if (feed_dict != null) | |||
| { | |||
| NDArray np_val = null; | |||
| foreach (var feed in feed_dict) | |||
| foreach (FeedValue feed in feed_dict) | |||
| { | |||
| switch (feed.Value) | |||
| switch (feed.feed_val) | |||
| { | |||
| case float value: | |||
| np_val = np.asarray(value); | |||
| break; | |||
| } | |||
| feed_dict_tensor[feed.Key] = np_val; | |||
| feed_dict_tensor[feed.feed] = np_val; | |||
| } | |||
| } | |||
| // Create a fetch handler to take care of the structure of fetches. | |||
| var fetch_handler = new _FetchHandler(_graph, fetches); | |||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
| // Run request and get response. | |||
| // We need to keep the returned movers alive for the following _do_run(). | |||
| @@ -80,19 +80,20 @@ namespace Tensorflow | |||
| // We only want to really perform the run if fetches or targets are provided, | |||
| // or if the call is a partial run that specifies feeds. | |||
| var results = _do_run(final_fetches); | |||
| var results = _do_run(final_fetches, feed_dict_tensor); | |||
| return fetch_handler.build_results(null, results); | |||
| } | |||
| private object[] _do_run(List<object> fetch_list) | |||
| private object[] _do_run(List<Tensor> fetch_list, FeedDict feed_dict) | |||
| { | |||
| var fetches = fetch_list.Select(x => (x as Tensor)._as_tf_output()).ToArray(); | |||
| var feeds = feed_dict.items().Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); | |||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
| return _call_tf_sessionrun(fetches); | |||
| return _call_tf_sessionrun(feeds, fetches); | |||
| } | |||
| private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list) | |||
| private unsafe object[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list) | |||
| { | |||
| // Ensure any changes to the graph are reflected in the runtime. | |||
| _extend_graph(); | |||
| @@ -103,7 +104,7 @@ namespace Tensorflow | |||
| c_api.TF_SessionRun(_session, | |||
| run_options: IntPtr.Zero, | |||
| inputs: new TF_Output[] { }, | |||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | |||
| input_values: new IntPtr[] { }, | |||
| ninputs: 0, | |||
| outputs: fetch_list, | |||
| @@ -0,0 +1,59 @@ | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class FeedDict : IEnumerable | |||
| { | |||
| private Dictionary<Tensor, object> feed_dict; | |||
| public FeedDict() | |||
| { | |||
| feed_dict = new Dictionary<Tensor, object>(); | |||
| } | |||
| public object this[Tensor feed] | |||
| { | |||
| get | |||
| { | |||
| return feed_dict[feed]; | |||
| } | |||
| set | |||
| { | |||
| feed_dict[feed] = value; | |||
| } | |||
| } | |||
| public FeedDict Add(Tensor feed, object value) | |||
| { | |||
| feed_dict.Add(feed, value); | |||
| return this; | |||
| } | |||
| public IEnumerator GetEnumerator() | |||
| { | |||
| foreach (KeyValuePair<Tensor, object> feed in feed_dict) | |||
| { | |||
| yield return new FeedValue | |||
| { | |||
| feed = feed.Key, | |||
| feed_val = feed.Value | |||
| }; | |||
| } | |||
| } | |||
| public Dictionary<Tensor, object> items() | |||
| { | |||
| return feed_dict; | |||
| } | |||
| } | |||
| public struct FeedValue | |||
| { | |||
| public Tensor feed { get; set; } | |||
| public object feed_val { get; set; } | |||
| } | |||
| } | |||
| @@ -10,12 +10,12 @@ namespace Tensorflow | |||
| public class _FetchHandler | |||
| { | |||
| private _ElementFetchMapper _fetch_mapper; | |||
| private List<object> _fetches = new List<object>(); | |||
| private List<Tensor> _fetches = new List<Tensor>(); | |||
| private List<bool> _ops = new List<bool>(); | |||
| private List<object> _final_fetches = new List<object>(); | |||
| private List<Tensor> _final_fetches = new List<Tensor>(); | |||
| private List<object> _targets = new List<object>(); | |||
| public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | |||
| public _FetchHandler(Graph graph, Tensor fetches, FeedDict feeds = null, object feed_handles = null) | |||
| { | |||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | |||
| @@ -24,7 +24,7 @@ namespace Tensorflow | |||
| { | |||
| case Tensor val: | |||
| _assert_fetchable(graph, val.op); | |||
| _fetches.Add(fetch); | |||
| _fetches.Add(val); | |||
| _ops.Add(false); | |||
| break; | |||
| } | |||
| @@ -47,7 +47,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public List<Object> fetches() | |||
| public List<Tensor> fetches() | |||
| { | |||
| return _final_fetches; | |||
| } | |||
| @@ -38,8 +38,8 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="session"></param> | |||
| /// <param name="run_options"></param> | |||
| /// <param name="inputs"></param> | |||
| /// <param name="input_values"></param> | |||
| /// <param name="inputs">TF_Output</param> | |||
| /// <param name="input_values">TF_Tensor</param> | |||
| /// <param name="ninputs"></param> | |||
| /// <param name="outputs"></param> | |||
| /// <param name="output_values"></param> | |||
| @@ -8,6 +8,7 @@ namespace Tensorflow | |||
| { | |||
| public static TF_Output tf_output(IntPtr c_op, int index) | |||
| { | |||
| var ret = new TF_Output(); | |||
| ret.oper = c_op; | |||
| ret.index = index; | |||
| @@ -9,12 +9,6 @@ namespace TensorFlowNET.UnitTest | |||
| [TestClass] | |||
| public class OperationsTest | |||
| { | |||
| [TestMethod] | |||
| public void placeholder() | |||
| { | |||
| var x = tf.placeholder(tf.float32); | |||
| } | |||
| [TestMethod] | |||
| public void addInPlaceholder() | |||
| { | |||
| @@ -24,9 +18,9 @@ namespace TensorFlowNET.UnitTest | |||
| using(var sess = tf.Session()) | |||
| { | |||
| var feed_dict = new Dictionary<Tensor, object>(); | |||
| feed_dict.Add(a, 3.0f); | |||
| feed_dict.Add(b, 2.0f); | |||
| var feed_dict = new FeedDict() | |||
| .Add(a, 3.0f) | |||
| .Add(b, 2.0f); | |||
| var o = sess.run(c, feed_dict); | |||
| } | |||
| @@ -0,0 +1,18 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class PlaceholderTest | |||
| { | |||
| [TestMethod] | |||
| public void placeholder() | |||
| { | |||
| var x = tf.placeholder(tf.float32); | |||
| } | |||
| } | |||
| } | |||