| @@ -15,8 +15,8 @@ namespace Tensorflow | |||||
| public partial class Graph : IDisposable | public partial class Graph : IDisposable | ||||
| { | { | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| private Dictionary<int, Operation> _nodes_by_id; | |||||
| private Dictionary<string, Operation> _nodes_by_name; | |||||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||||
| public Dictionary<string, ITensorOrOperation> _nodes_by_name; | |||||
| private Dictionary<string, int> _names_in_use; | private Dictionary<string, int> _names_in_use; | ||||
| public int _version; | public int _version; | ||||
| private int _next_id_counter; | private int _next_id_counter; | ||||
| @@ -35,13 +35,13 @@ namespace Tensorflow | |||||
| { | { | ||||
| _handle = c_api.TF_NewGraph(); | _handle = c_api.TF_NewGraph(); | ||||
| Status = new Status(); | Status = new Status(); | ||||
| _nodes_by_id = new Dictionary<int, Operation>(); | |||||
| _nodes_by_name = new Dictionary<string, Operation>(); | |||||
| _nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||||
| _nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||||
| _names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
| _graph_key = $"grap-key-{ops.uid()}/"; | _graph_key = $"grap-key-{ops.uid()}/"; | ||||
| } | } | ||||
| public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | |||||
| public ITensorOrOperation as_graph_element(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) | |||||
| { | { | ||||
| return _as_graph_element_locked(obj, allow_tensor, allow_operation); | return _as_graph_element_locked(obj, allow_tensor, allow_operation); | ||||
| } | } | ||||
| @@ -54,7 +54,7 @@ namespace Tensorflow | |||||
| return null; | return null; | ||||
| } | } | ||||
| private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) | |||||
| private ITensorOrOperation _as_graph_element_locked(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) | |||||
| { | { | ||||
| string types_str = ""; | string types_str = ""; | ||||
| @@ -294,7 +294,7 @@ namespace Tensorflow | |||||
| return c_api.TF_GraphOperationByName(_handle, operName); | return c_api.TF_GraphOperationByName(_handle, operName); | ||||
| } | } | ||||
| public Operation[] get_operations() | |||||
| public ITensorOrOperation[] get_operations() | |||||
| { | { | ||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
| } | } | ||||
| @@ -36,12 +36,12 @@ namespace Tensorflow | |||||
| } | } | ||||
| public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null) | |||||
| public virtual NDArray run(object fetches, FeedItem[] feed_dict = null) | |||||
| { | { | ||||
| return _run(fetches, feed_dict); | return _run(fetches, feed_dict); | ||||
| } | } | ||||
| private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null) | |||||
| private NDArray _run(object fetches, FeedItem[] feed_dict = null) | |||||
| { | { | ||||
| var feed_dict_tensor = new Dictionary<object, object>(); | var feed_dict_tensor = new Dictionary<object, object>(); | ||||
| @@ -49,7 +49,7 @@ namespace Tensorflow | |||||
| feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); | feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); | ||||
| // Create a fetch handler to take care of the structure of fetches. | // Create a fetch handler to take care of the structure of fetches. | ||||
| var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor); | |||||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||||
| // Run request and get response. | // Run request and get response. | ||||
| // We need to keep the returned movers alive for the following _do_run(). | // We need to keep the returned movers alive for the following _do_run(). | ||||
| @@ -8,20 +8,36 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Fetch mapper for singleton tensors and ops. | /// Fetch mapper for singleton tensors and ops. | ||||
| /// </summary> | /// </summary> | ||||
| public class _ElementFetchMapper<T> : _FetchMapper<T> | |||||
| public class _ElementFetchMapper : _FetchMapper | |||||
| { | { | ||||
| private List<object> _unique_fetches = new List<object>(); | private List<object> _unique_fetches = new List<object>(); | ||||
| private Func<List<object>, object> _contraction_fn; | private Func<List<object>, object> _contraction_fn; | ||||
| public _ElementFetchMapper(List<T> fetches, Func<List<object>, object> contraction_fn) | |||||
| public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | |||||
| { | { | ||||
| var g = ops.get_default_graph(); | |||||
| ITensorOrOperation el = null; | |||||
| foreach(var fetch in fetches) | foreach(var fetch in fetches) | ||||
| { | { | ||||
| var g = ops.get_default_graph(); | |||||
| var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); | |||||
| _unique_fetches.Add(el); | |||||
| switch(fetch) | |||||
| { | |||||
| case Tensor tensor: | |||||
| el = g.as_graph_element(tensor, allow_tensor: true, allow_operation: true); | |||||
| break; | |||||
| case Operation op: | |||||
| el = g.as_graph_element(op, allow_tensor: true, allow_operation: true); | |||||
| break; | |||||
| case String str: | |||||
| // Looks like a Tensor name and can be a Tensor. | |||||
| el = g._nodes_by_name[str]; | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("_ElementFetchMapper"); | |||||
| } | |||||
| } | } | ||||
| _unique_fetches.Add(el); | |||||
| _contraction_fn = contraction_fn; | _contraction_fn = contraction_fn; | ||||
| } | } | ||||
| @@ -8,24 +8,24 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Handler for structured fetches. | /// Handler for structured fetches. | ||||
| /// </summary> | /// </summary> | ||||
| public class _FetchHandler<T> | |||||
| public class _FetchHandler | |||||
| { | { | ||||
| private _ElementFetchMapper<T> _fetch_mapper; | |||||
| private _ElementFetchMapper _fetch_mapper; | |||||
| private List<Tensor> _fetches = new List<Tensor>(); | private List<Tensor> _fetches = new List<Tensor>(); | ||||
| private List<bool> _ops = new List<bool>(); | private List<bool> _ops = new List<bool>(); | ||||
| private List<Tensor> _final_fetches = new List<Tensor>(); | private List<Tensor> _final_fetches = new List<Tensor>(); | ||||
| private List<T> _targets = new List<T>(); | |||||
| private List<object> _targets = new List<object>(); | |||||
| public _FetchHandler(Graph graph, T fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||||
| public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||||
| { | { | ||||
| _fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | |||||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
| { | { | ||||
| switch (fetch) | switch (fetch) | ||||
| { | { | ||||
| case Operation val: | case Operation val: | ||||
| _assert_fetchable(graph, val); | _assert_fetchable(graph, val); | ||||
| _targets.Add((T)(object)val); | |||||
| _targets.Add(val); | |||||
| _ops.Add(true); | _ops.Add(true); | ||||
| break; | break; | ||||
| case Tensor val: | case Tensor val: | ||||
| @@ -82,7 +82,7 @@ namespace Tensorflow | |||||
| return _final_fetches; | return _final_fetches; | ||||
| } | } | ||||
| public List<T> targets() | |||||
| public List<object> targets() | |||||
| { | { | ||||
| return _targets; | return _targets; | ||||
| } | } | ||||
| @@ -4,13 +4,13 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class _FetchMapper<T> | |||||
| public class _FetchMapper | |||||
| { | { | ||||
| public _ElementFetchMapper<T> for_fetch(T fetch) | |||||
| public _ElementFetchMapper for_fetch(object fetch) | |||||
| { | { | ||||
| var fetches = new List<T> { fetch }; | |||||
| var fetches = new object[] { fetch }; | |||||
| return new _ElementFetchMapper<T>(fetches, (List<object> fetched_vals) => | |||||
| return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => | |||||
| { | { | ||||
| return fetched_vals[0]; | return fetched_vals[0]; | ||||
| }); | }); | ||||
| @@ -47,6 +47,11 @@ namespace Tensorflow | |||||
| return g; | return g; | ||||
| } | } | ||||
| public static void ResetGraph() | |||||
| { | |||||
| g = new Graph(); | |||||
| } | |||||
| public static Session Session() | public static Session Session() | ||||
| { | { | ||||
| defaultSession = new Session(); | defaultSession = new Session(); | ||||
| @@ -12,6 +12,8 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Gradients() | public void Gradients() | ||||
| { | { | ||||
| tf.ResetGraph(); | |||||
| var a = tf.constant(0.0); | var a = tf.constant(0.0); | ||||
| var b = 2.0 * a; | var b = 2.0 * a; | ||||
| Assert.AreEqual(b.name, "mul:0"); | Assert.AreEqual(b.name, "mul:0"); | ||||