| @@ -40,35 +40,35 @@ namespace Tensorflow | |||||
| } | } | ||||
| public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| public virtual object run(Tensor fetches, FeedDict feed_dict = null) | |||||
| { | { | ||||
| var result = _run(fetches, feed_dict); | var result = _run(fetches, feed_dict); | ||||
| return result; | return result; | ||||
| } | } | ||||
| private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| private unsafe object _run(Tensor fetches, FeedDict feed_dict = null) | |||||
| { | { | ||||
| var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | |||||
| var feed_dict_tensor = new FeedDict(); | |||||
| if (feed_dict != null) | if (feed_dict != null) | ||||
| { | { | ||||
| NDArray np_val = null; | NDArray np_val = null; | ||||
| foreach (var feed in feed_dict) | |||||
| foreach (FeedValue feed in feed_dict) | |||||
| { | { | ||||
| switch (feed.Value) | |||||
| switch (feed.feed_val) | |||||
| { | { | ||||
| case float value: | case float value: | ||||
| np_val = np.asarray(value); | np_val = np.asarray(value); | ||||
| break; | break; | ||||
| } | } | ||||
| feed_dict_tensor[feed.Key] = np_val; | |||||
| feed_dict_tensor[feed.feed] = np_val; | |||||
| } | } | ||||
| } | } | ||||
| // Create a fetch handler to take care of the structure of fetches. | // Create a fetch handler to take care of the structure of fetches. | ||||
| var fetch_handler = new _FetchHandler(_graph, fetches); | |||||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||||
| // Run request and get response. | // Run request and get response. | ||||
| // We need to keep the returned movers alive for the following _do_run(). | // We need to keep the returned movers alive for the following _do_run(). | ||||
| @@ -80,19 +80,20 @@ namespace Tensorflow | |||||
| // We only want to really perform the run if fetches or targets are provided, | // We only want to really perform the run if fetches or targets are provided, | ||||
| // or if the call is a partial run that specifies feeds. | // or if the call is a partial run that specifies feeds. | ||||
| var results = _do_run(final_fetches); | |||||
| var results = _do_run(final_fetches, feed_dict_tensor); | |||||
| return fetch_handler.build_results(null, results); | return fetch_handler.build_results(null, results); | ||||
| } | } | ||||
| private object[] _do_run(List<object> fetch_list) | |||||
| private object[] _do_run(List<Tensor> fetch_list, FeedDict feed_dict) | |||||
| { | { | ||||
| var fetches = fetch_list.Select(x => (x as Tensor)._as_tf_output()).ToArray(); | |||||
| var feeds = feed_dict.items().Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||||
| return _call_tf_sessionrun(fetches); | |||||
| return _call_tf_sessionrun(feeds, fetches); | |||||
| } | } | ||||
| private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list) | |||||
| private unsafe object[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list) | |||||
| { | { | ||||
| // Ensure any changes to the graph are reflected in the runtime. | // Ensure any changes to the graph are reflected in the runtime. | ||||
| _extend_graph(); | _extend_graph(); | ||||
| @@ -103,7 +104,7 @@ namespace Tensorflow | |||||
| c_api.TF_SessionRun(_session, | c_api.TF_SessionRun(_session, | ||||
| run_options: IntPtr.Zero, | run_options: IntPtr.Zero, | ||||
| inputs: new TF_Output[] { }, | |||||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | |||||
| input_values: new IntPtr[] { }, | input_values: new IntPtr[] { }, | ||||
| ninputs: 0, | ninputs: 0, | ||||
| outputs: fetch_list, | outputs: fetch_list, | ||||
| @@ -0,0 +1,59 @@ | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class FeedDict : IEnumerable | |||||
| { | |||||
| private Dictionary<Tensor, object> feed_dict; | |||||
| public FeedDict() | |||||
| { | |||||
| feed_dict = new Dictionary<Tensor, object>(); | |||||
| } | |||||
| public object this[Tensor feed] | |||||
| { | |||||
| get | |||||
| { | |||||
| return feed_dict[feed]; | |||||
| } | |||||
| set | |||||
| { | |||||
| feed_dict[feed] = value; | |||||
| } | |||||
| } | |||||
| public FeedDict Add(Tensor feed, object value) | |||||
| { | |||||
| feed_dict.Add(feed, value); | |||||
| return this; | |||||
| } | |||||
| public IEnumerator GetEnumerator() | |||||
| { | |||||
| foreach (KeyValuePair<Tensor, object> feed in feed_dict) | |||||
| { | |||||
| yield return new FeedValue | |||||
| { | |||||
| feed = feed.Key, | |||||
| feed_val = feed.Value | |||||
| }; | |||||
| } | |||||
| } | |||||
| public Dictionary<Tensor, object> items() | |||||
| { | |||||
| return feed_dict; | |||||
| } | |||||
| } | |||||
| public struct FeedValue | |||||
| { | |||||
| public Tensor feed { get; set; } | |||||
| public object feed_val { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -10,12 +10,12 @@ namespace Tensorflow | |||||
| public class _FetchHandler | public class _FetchHandler | ||||
| { | { | ||||
| private _ElementFetchMapper _fetch_mapper; | private _ElementFetchMapper _fetch_mapper; | ||||
| private List<object> _fetches = new List<object>(); | |||||
| private List<Tensor> _fetches = new List<Tensor>(); | |||||
| private List<bool> _ops = new List<bool>(); | private List<bool> _ops = new List<bool>(); | ||||
| private List<object> _final_fetches = new List<object>(); | |||||
| private List<Tensor> _final_fetches = new List<Tensor>(); | |||||
| private List<object> _targets = new List<object>(); | private List<object> _targets = new List<object>(); | ||||
| public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | |||||
| public _FetchHandler(Graph graph, Tensor fetches, FeedDict feeds = null, object feed_handles = null) | |||||
| { | { | ||||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | _fetch_mapper = new _FetchMapper().for_fetch(fetches); | ||||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
| @@ -24,7 +24,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| case Tensor val: | case Tensor val: | ||||
| _assert_fetchable(graph, val.op); | _assert_fetchable(graph, val.op); | ||||
| _fetches.Add(fetch); | |||||
| _fetches.Add(val); | |||||
| _ops.Add(false); | _ops.Add(false); | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -47,7 +47,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public List<Object> fetches() | |||||
| public List<Tensor> fetches() | |||||
| { | { | ||||
| return _final_fetches; | return _final_fetches; | ||||
| } | } | ||||
| @@ -38,8 +38,8 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="session"></param> | /// <param name="session"></param> | ||||
| /// <param name="run_options"></param> | /// <param name="run_options"></param> | ||||
| /// <param name="inputs"></param> | |||||
| /// <param name="input_values"></param> | |||||
| /// <param name="inputs">TF_Output</param> | |||||
| /// <param name="input_values">TF_Tensor</param> | |||||
| /// <param name="ninputs"></param> | /// <param name="ninputs"></param> | ||||
| /// <param name="outputs"></param> | /// <param name="outputs"></param> | ||||
| /// <param name="output_values"></param> | /// <param name="output_values"></param> | ||||
| @@ -8,6 +8,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static TF_Output tf_output(IntPtr c_op, int index) | public static TF_Output tf_output(IntPtr c_op, int index) | ||||
| { | { | ||||
| var ret = new TF_Output(); | var ret = new TF_Output(); | ||||
| ret.oper = c_op; | ret.oper = c_op; | ||||
| ret.index = index; | ret.index = index; | ||||
| @@ -9,12 +9,6 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestClass] | [TestClass] | ||||
| public class OperationsTest | public class OperationsTest | ||||
| { | { | ||||
| [TestMethod] | |||||
| public void placeholder() | |||||
| { | |||||
| var x = tf.placeholder(tf.float32); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void addInPlaceholder() | public void addInPlaceholder() | ||||
| { | { | ||||
| @@ -24,9 +18,9 @@ namespace TensorFlowNET.UnitTest | |||||
| using(var sess = tf.Session()) | using(var sess = tf.Session()) | ||||
| { | { | ||||
| var feed_dict = new Dictionary<Tensor, object>(); | |||||
| feed_dict.Add(a, 3.0f); | |||||
| feed_dict.Add(b, 2.0f); | |||||
| var feed_dict = new FeedDict() | |||||
| .Add(a, 3.0f) | |||||
| .Add(b, 2.0f); | |||||
| var o = sess.run(c, feed_dict); | var o = sess.run(c, feed_dict); | ||||
| } | } | ||||
| @@ -0,0 +1,18 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| [TestClass] | |||||
| public class PlaceholderTest | |||||
| { | |||||
| [TestMethod] | |||||
| public void placeholder() | |||||
| { | |||||
| var x = tf.placeholder(tf.float32); | |||||
| } | |||||
| } | |||||
| } | |||||