| @@ -15,8 +15,8 @@ namespace Tensorflow | |||
| public partial class Graph : IDisposable | |||
| { | |||
| private IntPtr _handle; | |||
| private Dictionary<int, Operation> _nodes_by_id; | |||
| private Dictionary<string, Operation> _nodes_by_name; | |||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||
| public Dictionary<string, ITensorOrOperation> _nodes_by_name; | |||
| private Dictionary<string, int> _names_in_use; | |||
| public int _version; | |||
| private int _next_id_counter; | |||
| @@ -35,13 +35,13 @@ namespace Tensorflow | |||
| { | |||
| _handle = c_api.TF_NewGraph(); | |||
| Status = new Status(); | |||
| _nodes_by_id = new Dictionary<int, Operation>(); | |||
| _nodes_by_name = new Dictionary<string, Operation>(); | |||
| _nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||
| _nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||
| _names_in_use = new Dictionary<string, int>(); | |||
| _graph_key = $"grap-key-{ops.uid()}/"; | |||
| } | |||
| public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | |||
| public ITensorOrOperation as_graph_element(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) | |||
| { | |||
| return _as_graph_element_locked(obj, allow_tensor, allow_operation); | |||
| } | |||
| @@ -54,7 +54,7 @@ namespace Tensorflow | |||
| return null; | |||
| } | |||
| private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) | |||
| private ITensorOrOperation _as_graph_element_locked(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) | |||
| { | |||
| string types_str = ""; | |||
| @@ -294,7 +294,7 @@ namespace Tensorflow | |||
| return c_api.TF_GraphOperationByName(_handle, operName); | |||
| } | |||
| public Operation[] get_operations() | |||
| public ITensorOrOperation[] get_operations() | |||
| { | |||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
| } | |||
| @@ -36,12 +36,12 @@ namespace Tensorflow | |||
| } | |||
| public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null) | |||
| public virtual NDArray run(object fetches, FeedItem[] feed_dict = null) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null) | |||
| private NDArray _run(object fetches, FeedItem[] feed_dict = null) | |||
| { | |||
| var feed_dict_tensor = new Dictionary<object, object>(); | |||
| @@ -49,7 +49,7 @@ namespace Tensorflow | |||
| feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); | |||
| // Create a fetch handler to take care of the structure of fetches. | |||
| var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor); | |||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
| // Run request and get response. | |||
| // We need to keep the returned movers alive for the following _do_run(). | |||
| @@ -8,20 +8,36 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Fetch mapper for singleton tensors and ops. | |||
| /// </summary> | |||
| public class _ElementFetchMapper<T> : _FetchMapper<T> | |||
| public class _ElementFetchMapper : _FetchMapper | |||
| { | |||
| private List<object> _unique_fetches = new List<object>(); | |||
| private Func<List<object>, object> _contraction_fn; | |||
| public _ElementFetchMapper(List<T> fetches, Func<List<object>, object> contraction_fn) | |||
| public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | |||
| { | |||
| var g = ops.get_default_graph(); | |||
| ITensorOrOperation el = null; | |||
| foreach(var fetch in fetches) | |||
| { | |||
| var g = ops.get_default_graph(); | |||
| var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); | |||
| _unique_fetches.Add(el); | |||
| switch(fetch) | |||
| { | |||
| case Tensor tensor: | |||
| el = g.as_graph_element(tensor, allow_tensor: true, allow_operation: true); | |||
| break; | |||
| case Operation op: | |||
| el = g.as_graph_element(op, allow_tensor: true, allow_operation: true); | |||
| break; | |||
| case String str: | |||
| // Looks like a Tensor name and can be a Tensor. | |||
| el = g._nodes_by_name[str]; | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("_ElementFetchMapper"); | |||
| } | |||
| } | |||
| _unique_fetches.Add(el); | |||
| _contraction_fn = contraction_fn; | |||
| } | |||
| @@ -8,24 +8,24 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Handler for structured fetches. | |||
| /// </summary> | |||
| public class _FetchHandler<T> | |||
| public class _FetchHandler | |||
| { | |||
| private _ElementFetchMapper<T> _fetch_mapper; | |||
| private _ElementFetchMapper _fetch_mapper; | |||
| private List<Tensor> _fetches = new List<Tensor>(); | |||
| private List<bool> _ops = new List<bool>(); | |||
| private List<Tensor> _final_fetches = new List<Tensor>(); | |||
| private List<T> _targets = new List<T>(); | |||
| private List<object> _targets = new List<object>(); | |||
| public _FetchHandler(Graph graph, T fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||
| public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||
| { | |||
| _fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | |||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | |||
| { | |||
| switch (fetch) | |||
| { | |||
| case Operation val: | |||
| _assert_fetchable(graph, val); | |||
| _targets.Add((T)(object)val); | |||
| _targets.Add(val); | |||
| _ops.Add(true); | |||
| break; | |||
| case Tensor val: | |||
| @@ -82,7 +82,7 @@ namespace Tensorflow | |||
| return _final_fetches; | |||
| } | |||
| public List<T> targets() | |||
| public List<object> targets() | |||
| { | |||
| return _targets; | |||
| } | |||
| @@ -4,13 +4,13 @@ using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class _FetchMapper<T> | |||
| public class _FetchMapper | |||
| { | |||
| public _ElementFetchMapper<T> for_fetch(T fetch) | |||
| public _ElementFetchMapper for_fetch(object fetch) | |||
| { | |||
| var fetches = new List<T> { fetch }; | |||
| var fetches = new object[] { fetch }; | |||
| return new _ElementFetchMapper<T>(fetches, (List<object> fetched_vals) => | |||
| return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => | |||
| { | |||
| return fetched_vals[0]; | |||
| }); | |||
| @@ -47,6 +47,11 @@ namespace Tensorflow | |||
| return g; | |||
| } | |||
| public static void ResetGraph() | |||
| { | |||
| g = new Graph(); | |||
| } | |||
| public static Session Session() | |||
| { | |||
| defaultSession = new Session(); | |||
| @@ -12,6 +12,8 @@ namespace TensorFlowNET.UnitTest | |||
| [TestMethod] | |||
| public void Gradients() | |||
| { | |||
| tf.ResetGraph(); | |||
| var a = tf.constant(0.0); | |||
| var b = 2.0 * a; | |||
| Assert.AreEqual(b.name, "mul:0"); | |||