diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 48521320..46e388a8 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -40,35 +40,35 @@ namespace Tensorflow } - public virtual object run(Tensor fetches, Dictionary feed_dict = null) + public virtual object run(Tensor fetches, FeedDict feed_dict = null) { var result = _run(fetches, feed_dict); return result; } - private unsafe object _run(Tensor fetches, Dictionary feed_dict = null) + private unsafe object _run(Tensor fetches, FeedDict feed_dict = null) { - var feed_dict_tensor = new Dictionary(); + var feed_dict_tensor = new FeedDict(); if (feed_dict != null) { NDArray np_val = null; - foreach (var feed in feed_dict) + foreach (FeedValue feed in feed_dict) { - switch (feed.Value) + switch (feed.feed_val) { case float value: np_val = np.asarray(value); break; } - feed_dict_tensor[feed.Key] = np_val; + feed_dict_tensor[feed.feed] = np_val; } } // Create a fetch handler to take care of the structure of fetches. - var fetch_handler = new _FetchHandler(_graph, fetches); + var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); // Run request and get response. // We need to keep the returned movers alive for the following _do_run(). @@ -80,19 +80,20 @@ namespace Tensorflow // We only want to really perform the run if fetches or targets are provided, // or if the call is a partial run that specifies feeds. - var results = _do_run(final_fetches); + var results = _do_run(final_fetches, feed_dict_tensor); return fetch_handler.build_results(null, results); } - private object[] _do_run(List fetch_list) + private object[] _do_run(List fetch_list, FeedDict feed_dict) { - var fetches = fetch_list.Select(x => (x as Tensor)._as_tf_output()).ToArray(); + var feeds = feed_dict.items().Select(x => new KeyValuePair(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); + var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); - return _call_tf_sessionrun(fetches); + return _call_tf_sessionrun(feeds, fetches); } - private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list) + private unsafe object[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list) { // Ensure any changes to the graph are reflected in the runtime. _extend_graph(); @@ -103,7 +104,7 @@ namespace Tensorflow c_api.TF_SessionRun(_session, run_options: IntPtr.Zero, - inputs: new TF_Output[] { }, + inputs: feed_dict.Select(f => f.Key).ToArray(), input_values: new IntPtr[] { }, ninputs: 0, outputs: fetch_list, diff --git a/src/TensorFlowNET.Core/Sessions/FeedDict.cs b/src/TensorFlowNET.Core/Sessions/FeedDict.cs new file mode 100644 index 00000000..7d36e899 --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/FeedDict.cs @@ -0,0 +1,59 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class FeedDict : IEnumerable + { + private Dictionary feed_dict; + + public FeedDict() + { + feed_dict = new Dictionary(); + } + + public object this[Tensor feed] + { + get + { + return feed_dict[feed]; + } + + set + { + feed_dict[feed] = value; + } + } + + public FeedDict Add(Tensor feed, object value) + { + feed_dict.Add(feed, value); + return this; + } + + public IEnumerator GetEnumerator() + { + foreach (KeyValuePair feed in feed_dict) + { + yield return new FeedValue + { + feed = feed.Key, + feed_val = feed.Value + }; + } + } + + public Dictionary items() + { + return feed_dict; + } + } + + public struct FeedValue + { + public Tensor feed { get; set; } + public object feed_val { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 0ec355d5..94c8b2ed 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -10,12 +10,12 @@ namespace Tensorflow public class _FetchHandler { private _ElementFetchMapper _fetch_mapper; - private List _fetches = new List(); + private List _fetches = new List(); private List _ops = new List(); - private List _final_fetches = new List(); + private List _final_fetches = new List(); private List _targets = new List(); - public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) + public _FetchHandler(Graph graph, Tensor fetches, FeedDict feeds = null, object feed_handles = null) { _fetch_mapper = new _FetchMapper().for_fetch(fetches); foreach(var fetch in _fetch_mapper.unique_fetches()) @@ -24,7 +24,7 @@ namespace Tensorflow { case Tensor val: _assert_fetchable(graph, val.op); - _fetches.Add(fetch); + _fetches.Add(val); _ops.Add(false); break; } @@ -47,7 +47,7 @@ namespace Tensorflow } } - public List fetches() + public List fetches() { return _final_fetches; } diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 49c32890..a3d07826 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -38,8 +38,8 @@ namespace Tensorflow /// /// /// - /// - /// + /// TF_Output + /// TF_Tensor /// /// /// diff --git a/src/TensorFlowNET.Core/c_api_util.cs b/src/TensorFlowNET.Core/c_api_util.cs index f6d54062..f4a918aa 100644 --- a/src/TensorFlowNET.Core/c_api_util.cs +++ b/src/TensorFlowNET.Core/c_api_util.cs @@ -8,6 +8,7 @@ namespace Tensorflow { public static TF_Output tf_output(IntPtr c_op, int index) { + var ret = new TF_Output(); ret.oper = c_op; ret.index = index; diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 47b849d9..01749733 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -9,12 +9,6 @@ namespace TensorFlowNET.UnitTest [TestClass] public class OperationsTest { - [TestMethod] - public void placeholder() - { - var x = tf.placeholder(tf.float32); - } - [TestMethod] public void addInPlaceholder() { @@ -24,9 +18,9 @@ namespace TensorFlowNET.UnitTest using(var sess = tf.Session()) { - var feed_dict = new Dictionary(); - feed_dict.Add(a, 3.0f); - feed_dict.Add(b, 2.0f); + var feed_dict = new FeedDict() + .Add(a, 3.0f) + .Add(b, 2.0f); var o = sess.run(c, feed_dict); } diff --git a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs new file mode 100644 index 00000000..d5413e8d --- /dev/null +++ b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs @@ -0,0 +1,18 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class PlaceholderTest + { + [TestMethod] + public void placeholder() + { + var x = tf.placeholder(tf.float32); + } + } +}