| @@ -77,9 +77,9 @@ namespace Tensorflow | |||
| var temp_obj = _as_graph_element(obj); | |||
| if(obj is Tensor && allow_tensor) | |||
| if (obj is Tensor tensor && allow_tensor) | |||
| { | |||
| if ((obj as Tensor).Graph.Equals(this)) | |||
| if (tensor.Graph.Equals(this)) | |||
| { | |||
| return obj; | |||
| } | |||
| @@ -88,6 +88,17 @@ namespace Tensorflow | |||
| throw new Exception($"Tensor {obj} is not an element of this graph."); | |||
| } | |||
| } | |||
| else if (obj is Operation op && allow_operation) | |||
| { | |||
| if (op.Graph.Equals(this)) | |||
| { | |||
| return obj; | |||
| } | |||
| else | |||
| { | |||
| throw new Exception($"Operation {obj} is not an element of this graph."); | |||
| } | |||
| } | |||
| throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | |||
| } | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| @@ -10,14 +11,35 @@ namespace Tensorflow | |||
| { | |||
| using(var namescope = new ops.name_scope<Operation>(name, "group_deps", inputs)) | |||
| { | |||
| name = namescope; | |||
| var ops_on_device = new Dictionary<string, Operation[]>(); | |||
| // Sorts *inputs according to their devices. | |||
| foreach (var inp in inputs) | |||
| { | |||
| ops_on_device[inp.Device] = new Operation[] { inp }; | |||
| } | |||
| // 1-level tree. The root node is the returned NoOp node. | |||
| if (ops_on_device.Count == 1) | |||
| { | |||
| return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name); | |||
| } | |||
| return _GroupControlDeps("", name); | |||
| // 2-level tree. The root node is the returned NoOp node. | |||
| // deps contains 1 NoOp node for each device. | |||
| return null; | |||
| } | |||
| } | |||
| private static Operation _GroupControlDeps(string dev, string name = "") | |||
| private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "") | |||
| { | |||
| if (string.IsNullOrEmpty(dev)) | |||
| { | |||
| return gen_control_flow_ops.no_op(name); | |||
| } | |||
| return null; | |||
| } | |||
| } | |||
| @@ -0,0 +1,18 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class gen_control_flow_ops | |||
| { | |||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| public static Operation no_op(string name = "") | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("NoOp", name); | |||
| return _op; | |||
| } | |||
| } | |||
| } | |||
| @@ -40,7 +40,12 @@ namespace Tensorflow | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| private NDArray _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||
| public virtual NDArray run(Operation fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| private NDArray _run<T>(T fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||
| { | |||
| var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | |||
| @@ -53,7 +58,7 @@ namespace Tensorflow | |||
| } | |||
| // Create a fetch handler to take care of the structure of fetches. | |||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
| var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor); | |||
| // Run request and get response. | |||
| // We need to keep the returned movers alive for the following _do_run(). | |||
| @@ -65,20 +70,34 @@ namespace Tensorflow | |||
| // We only want to really perform the run if fetches or targets are provided, | |||
| // or if the call is a partial run that specifies feeds. | |||
| var results = _do_run(final_fetches, feed_dict_tensor); | |||
| var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor); | |||
| return fetch_handler.build_results(null, results); | |||
| } | |||
| private NDArray[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||
| /// <summary> | |||
| /// Runs a step based on the given fetches and feeds. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||
| /// <param name="fetch_list"></param> | |||
| /// <param name="feed_dict"></param> | |||
| /// <returns> | |||
| /// A list of numpy ndarrays, corresponding to the elements of | |||
| /// `fetch_list`. If the ith element of `fetch_list` contains the | |||
| /// name of an operation, the first Tensor output of that operation | |||
| /// will be returned for that element. | |||
| /// </returns> | |||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||
| { | |||
| var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); | |||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
| var targets = target_list; | |||
| return _call_tf_sessionrun(feeds, fetches); | |||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||
| } | |||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list) | |||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||
| { | |||
| // Ensure any changes to the graph are reflected in the runtime. | |||
| _extend_graph(); | |||
| @@ -95,8 +114,8 @@ namespace Tensorflow | |||
| outputs: fetch_list, | |||
| output_values: output_values, | |||
| noutputs: fetch_list.Length, | |||
| target_opers: IntPtr.Zero, | |||
| ntargets: 0, | |||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||
| ntargets: target_list.Count, | |||
| run_metadata: IntPtr.Zero, | |||
| status: status); | |||
| @@ -8,26 +8,37 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Fetch mapper for singleton tensors and ops. | |||
| /// </summary> | |||
| public class _ElementFetchMapper : _FetchMapper | |||
| public class _ElementFetchMapper<T> : _FetchMapper<T> | |||
| { | |||
| private List<Object> _unique_fetches = new List<object>(); | |||
| private Action _contraction_fn; | |||
| private List<object> _unique_fetches = new List<object>(); | |||
| private Func<List<object>> _contraction_fn; | |||
| public _ElementFetchMapper(List<Tensor> fetches, Action contraction_fn) | |||
| public _ElementFetchMapper(List<T> fetches, Func<List<object>> contraction_fn) | |||
| { | |||
| foreach(var tensor in fetches) | |||
| foreach(var fetch in fetches) | |||
| { | |||
| var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true); | |||
| _unique_fetches.Add(fetch); | |||
| var g = ops.get_default_graph(); | |||
| var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); | |||
| _unique_fetches.Add(el); | |||
| } | |||
| _contraction_fn = contraction_fn; | |||
| } | |||
| public NDArray build_results(NDArray[] values) | |||
| /// <summary> | |||
| /// Build results matching the original fetch shape. | |||
| /// </summary> | |||
| /// <param name="values"></param> | |||
| /// <returns></returns> | |||
| public NDArray build_results(List<object> values) | |||
| { | |||
| return values[0]; | |||
| if (values.Count == 0) | |||
| return null; | |||
| else | |||
| return _contraction_fn(values); | |||
| } | |||
| public List<Object> unique_fetches() | |||
| public List<object> unique_fetches() | |||
| { | |||
| return _unique_fetches; | |||
| } | |||
| @@ -8,21 +8,26 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Handler for structured fetches. | |||
| /// </summary> | |||
| public class _FetchHandler | |||
| public class _FetchHandler<T> | |||
| { | |||
| private _ElementFetchMapper _fetch_mapper; | |||
| private _ElementFetchMapper<T> _fetch_mapper; | |||
| private List<Tensor> _fetches = new List<Tensor>(); | |||
| private List<bool> _ops = new List<bool>(); | |||
| private List<Tensor> _final_fetches = new List<Tensor>(); | |||
| private List<object> _targets = new List<object>(); | |||
| private List<T> _targets = new List<T>(); | |||
| public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||
| public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||
| { | |||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||
| _fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | |||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | |||
| { | |||
| switch (fetch) | |||
| { | |||
| case Operation val: | |||
| _assert_fetchable(graph, val); | |||
| _targets.Add((T)(object)val); | |||
| _ops.Add(true); | |||
| break; | |||
| case Tensor val: | |||
| _assert_fetchable(graph, val.op); | |||
| _fetches.Add(val); | |||
| @@ -35,9 +40,19 @@ namespace Tensorflow | |||
| _final_fetches = _fetches; | |||
| } | |||
| public NDArray build_results(Session session, NDArray[] results) | |||
| public NDArray build_results(Session session, NDArray[] tensor_values) | |||
| { | |||
| return _fetch_mapper.build_results(results); | |||
| var full_values = new List<object>(); | |||
| foreach(var is_op in _ops) | |||
| { | |||
| if (is_op) | |||
| { | |||
| full_values.Add(null); | |||
| } | |||
| } | |||
| return _fetch_mapper.build_results(full_values); | |||
| } | |||
| private void _assert_fetchable(Graph graph, Operation op) | |||
| @@ -53,7 +68,7 @@ namespace Tensorflow | |||
| return _final_fetches; | |||
| } | |||
| public List<Object> targets() | |||
| public List<T> targets() | |||
| { | |||
| return _targets; | |||
| } | |||
| @@ -4,13 +4,13 @@ using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class _FetchMapper | |||
| public class _FetchMapper<T> | |||
| { | |||
| public _ElementFetchMapper for_fetch(Tensor fetch) | |||
| public _ElementFetchMapper<T> for_fetch(T fetch) | |||
| { | |||
| var fetches = new List<Tensor> { fetch }; | |||
| var fetches = new List<T> { fetch }; | |||
| return new _ElementFetchMapper(fetches, null); | |||
| return new _ElementFetchMapper<T>(fetches, null); | |||
| } | |||
| } | |||
| } | |||
| @@ -87,7 +87,7 @@ namespace Tensorflow | |||
| public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | |||
| TF_Output[] inputs, IntPtr[] input_values, int ninputs, | |||
| TF_Output[] outputs, IntPtr[] output_values, int noutputs, | |||
| IntPtr target_opers, int ntargets, | |||
| IntPtr[] target_opers, int ntargets, | |||
| IntPtr run_metadata, | |||
| IntPtr status); | |||
| } | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||
| /// <returns>An Op that run the initializers of all the specified variables.</returns> | |||
| public static Operation variables_initializer(RefVariable[] var_list, string name = "init") | |||
| { | |||
| return control_flow_ops.group(var_list.Select(x => x.initializer).ToList()); | |||
| return control_flow_ops.group(var_list.Select(x => x.initializer).ToList(), name); | |||
| } | |||
| } | |||
| } | |||
| @@ -76,7 +76,7 @@ namespace TensorFlowNET.UnitTest | |||
| var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); | |||
| var outputs_ptr = outputs_.ToArray(); | |||
| var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray(); | |||
| IntPtr targets_ptr = IntPtr.Zero; | |||
| IntPtr[] targets_ptr = new IntPtr[0]; | |||
| c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | |||
| outputs_ptr, output_values_ptr, outputs_.Count, | |||
| @@ -35,13 +35,13 @@ namespace TensorFlowNET.UnitTest | |||
| using (var session = tf.Session()) | |||
| { | |||
| /*session.run(model); | |||
| session.run(model); | |||
| for(int i = 0; i < 5; i++) | |||
| { | |||
| x = x + 1; | |||
| //x = x + 1; | |||
| var result = session.run(x); | |||
| print(result); | |||
| }*/ | |||
| } | |||
| } | |||
| } | |||