| @@ -1,6 +1,7 @@ | |||||
| using NumSharp.Core; | using NumSharp.Core; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -38,12 +39,14 @@ namespace Tensorflow | |||||
| } | } | ||||
| public virtual byte[] run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| { | { | ||||
| return _run(fetches, feed_dict); | |||||
| var result = _run(fetches, feed_dict); | |||||
| return result; | |||||
| } | } | ||||
| private unsafe byte[] _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| { | { | ||||
| var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | ||||
| @@ -66,22 +69,71 @@ namespace Tensorflow | |||||
| // Create a fetch handler to take care of the structure of fetches. | // Create a fetch handler to take care of the structure of fetches. | ||||
| var fetch_handler = new _FetchHandler(_graph, fetches); | var fetch_handler = new _FetchHandler(_graph, fetches); | ||||
| // Run request and get response. | |||||
| // We need to keep the returned movers alive for the following _do_run(). | |||||
| // These movers are no longer needed when _do_run() completes, and | |||||
| // are deleted when `movers` goes out of scope when this _run() ends. | |||||
| var _ = _update_with_movers(); | |||||
| var final_fetches = fetch_handler.fetches(); | |||||
| var final_targets = fetch_handler.targets(); | |||||
| // We only want to really perform the run if fetches or targets are provided, | |||||
| // or if the call is a partial run that specifies feeds. | |||||
| var results = _do_run(final_fetches); | |||||
| return fetch_handler.build_results(null, results); | |||||
| } | |||||
| private object[] _do_run(List<object> fetch_list) | |||||
| { | |||||
| var fetches = fetch_list.Select(x => (x as Tensor)._as_tf_output()).ToArray(); | |||||
| return _call_tf_sessionrun(fetches); | |||||
| } | |||||
| private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list) | |||||
| { | |||||
| // Ensure any changes to the graph are reflected in the runtime. | |||||
| _extend_graph(); | |||||
| var status = new Status(); | var status = new Status(); | ||||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||||
| c_api.TF_SessionRun(_session, | c_api.TF_SessionRun(_session, | ||||
| run_options: IntPtr.Zero, | run_options: IntPtr.Zero, | ||||
| inputs: new TF_Output[] { }, | inputs: new TF_Output[] { }, | ||||
| input_values: new IntPtr[] { }, | input_values: new IntPtr[] { }, | ||||
| ninputs: 0, | ninputs: 0, | ||||
| outputs: new TF_Output[] { new TF_Output() }, | |||||
| output_values: new IntPtr[] { }, | |||||
| noutputs: 1, | |||||
| outputs: fetch_list, | |||||
| output_values: output_values, | |||||
| noutputs: fetch_list.Length, | |||||
| target_opers: new IntPtr[] { }, | target_opers: new IntPtr[] { }, | ||||
| ntargets: 1, | |||||
| ntargets: 0, | |||||
| run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
| status: status.Handle); | status: status.Handle); | ||||
| return null; | |||||
| var result = output_values.Select(x => new Tensor(x).buffer).Select(x => | |||||
| { | |||||
| return (object)*(float*)x; | |||||
| }).ToArray(); | |||||
| return result; | |||||
| } | |||||
| /// <summary> | |||||
| /// If a tensor handle that is fed to a device incompatible placeholder, | |||||
| /// we move the tensor to the right device, generate a new tensor handle, | |||||
| /// and update feed_dict to use the new handle. | |||||
| /// </summary> | |||||
| private List<object> _update_with_movers() | |||||
| { | |||||
| return new List<object> { }; | |||||
| } | |||||
| private void _extend_graph() | |||||
| { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -21,6 +21,11 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public object build_results(object[] values) | |||||
| { | |||||
| return values[0]; | |||||
| } | |||||
| public List<Object> unique_fetches() | public List<Object> unique_fetches() | ||||
| { | { | ||||
| return _unique_fetches; | return _unique_fetches; | ||||
| @@ -13,6 +13,7 @@ namespace Tensorflow | |||||
| private List<object> _fetches = new List<object>(); | private List<object> _fetches = new List<object>(); | ||||
| private List<bool> _ops = new List<bool>(); | private List<bool> _ops = new List<bool>(); | ||||
| private List<object> _final_fetches = new List<object>(); | private List<object> _final_fetches = new List<object>(); | ||||
| private List<object> _targets = new List<object>(); | |||||
| public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | ||||
| { | { | ||||
| @@ -33,6 +34,11 @@ namespace Tensorflow | |||||
| _final_fetches = _fetches; | _final_fetches = _fetches; | ||||
| } | } | ||||
| public object build_results(Session session, object[] results) | |||||
| { | |||||
| return _fetch_mapper.build_results(results); | |||||
| } | |||||
| private void _assert_fetchable(Graph graph, Operation op) | private void _assert_fetchable(Graph graph, Operation op) | ||||
| { | { | ||||
| if (!graph.is_fetchable(op)) | if (!graph.is_fetchable(op)) | ||||
| @@ -40,5 +46,15 @@ namespace Tensorflow | |||||
| throw new Exception($"Operation {op.name} has been marked as not fetchable."); | throw new Exception($"Operation {op.name} has been marked as not fetchable."); | ||||
| } | } | ||||
| } | } | ||||
| public List<Object> fetches() | |||||
| { | |||||
| return _final_fetches; | |||||
| } | |||||
| public List<Object> targets() | |||||
| { | |||||
| return _targets; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -17,6 +17,15 @@ namespace Tensorflow | |||||
| public string name; | public string name; | ||||
| private readonly IntPtr _handle; | |||||
| public IntPtr handle => _handle; | |||||
| public IntPtr buffer => c_api.TF_TensorData(_handle); | |||||
| public Tensor(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| } | |||||
| public Tensor(Operation op, int value_index, DataType dtype) | public Tensor(Operation op, int value_index, DataType dtype) | ||||
| { | { | ||||
| _op = op; | _op = op; | ||||
| @@ -77,6 +77,9 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern unsafe IntPtr TF_TensorData(TF_Tensor tensor); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); | public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); | ||||
| @@ -12,13 +12,7 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void constant() | public void constant() | ||||
| { | { | ||||
| var a = tf.constant(4.0f); | |||||
| var b = tf.constant(5.0f); | |||||
| var c = tf.add(a, b); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var o = sess.run(c); | |||||
| } | |||||
| var x = tf.constant(4.0f); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -28,7 +22,7 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| public void add() | |||||
| public void addInPlaceholder() | |||||
| { | { | ||||
| var a = tf.placeholder(tf.float32); | var a = tf.placeholder(tf.float32); | ||||
| var b = tf.placeholder(tf.float32); | var b = tf.placeholder(tf.float32); | ||||
| @@ -43,5 +37,19 @@ namespace TensorFlowNET.UnitTest | |||||
| var o = sess.run(c, feed_dict); | var o = sess.run(c, feed_dict); | ||||
| } | } | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void addInConstant() | |||||
| { | |||||
| var a = tf.constant(4.0f); | |||||
| var b = tf.constant(5.0f); | |||||
| var c = tf.add(a, b); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var o = sess.run(c); | |||||
| Assert.AreEqual(o, 9.0f); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||