@@ -77,9 +77,9 @@ namespace Tensorflow
                 var temp_obj = _as_graph_element(obj);
 
-                if(obj is Tensor && allow_tensor)
+                if (obj is Tensor tensor && allow_tensor)
                 {
-                    if ((obj as Tensor).Graph.Equals(this))
+                    if (tensor.Graph.Equals(this))
                     {
                         return obj;
                     }
@@ -88,6 +88,17 @@ namespace Tensorflow
                         throw new Exception($"Tensor {obj} is not an element of this graph.");
                     }
                 }
+                else if (obj is Operation op && allow_operation)
+                {
+                    if (op.Graph.Equals(this))
+                    {
+                        return obj;
+                    }
+                    else
+                    {
+                        throw new Exception($"Operation {obj} is not an element of this graph.");
+                    }
+                }
 
                 throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}.");
             }
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 
 namespace Tensorflow
@@ -10,14 +11,35 @@ namespace Tensorflow
         {
             using(var namescope = new ops.name_scope<Operation>(name, "group_deps", inputs))
             {
+                name = namescope;
+
+                var ops_on_device = new Dictionary<string, Operation[]>();
                 // Sorts *inputs according to their devices.
+                foreach (var inp in inputs)
+                {
+                    // Group ops by device so each device gets a single NoOp; append rather than overwrite.
+                    if (ops_on_device.ContainsKey(inp.Device))
+                        ops_on_device[inp.Device] = ops_on_device[inp.Device].Concat(new[] { inp }).ToArray();
+                    else
+                        ops_on_device[inp.Device] = new Operation[] { inp };
+                }
+
+                // 1-level tree. The root node is the returned NoOp node.
+                if (ops_on_device.Count == 1)
+                {
+                    return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name);
+                }
-                return _GroupControlDeps("", name);
+
+                // 2-level tree. The root node is the returned NoOp node.
+                // deps contains 1 NoOp node for each device.
+                return null;
             }
         }
 
-        private static Operation _GroupControlDeps(string dev, string name = "")
+        private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "")
         {
+            if (string.IsNullOrEmpty(dev))
+            {
+                return gen_control_flow_ops.no_op(name);
+            }
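+
+            // TODO: attach `deps` as control dependencies of the NoOp (and honor `dev`),
+            // as TensorFlow's Python _GroupControlDeps does.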
             return null;
         }
     }
@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    public class gen_control_flow_ops
+    {
+        public static OpDefLibrary _op_def_lib = new OpDefLibrary();
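+
+        /// <summary>
+        /// Does nothing. Only useful as a placeholder for control edges.
+        /// </summary>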
+        public static Operation no_op(string name = "")
+        {
+            var _op = _op_def_lib._apply_op_helper("NoOp", name);
+
+            return _op;
+        }
+    }
+}
@@ -40,7 +40,12 @@ namespace Tensorflow
             return _run(fetches, feed_dict);
         }
 
-        private NDArray _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
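+        /// <summary>
+        /// Runs an operation for its side effects (e.g. a variable initializer); no tensor value is fetched.
+        /// </summary>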
+        public virtual NDArray run(Operation fetches, Dictionary<Tensor, NDArray> feed_dict = null)
+        {
+            return _run(fetches, feed_dict);
+        }
+
+        private NDArray _run<T>(T fetches, Dictionary<Tensor, NDArray> feed_dict = null)
         {
             var feed_dict_tensor = new Dictionary<Tensor, NDArray>();
@@ -53,7 +58,7 @@ namespace Tensorflow
             }
 
             // Create a fetch handler to take care of the structure of fetches.
-            var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);
+            var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor);
 
             // Run request and get response.
             // We need to keep the returned movers alive for the following _do_run().
@@ -65,20 +70,34 @@ namespace Tensorflow
             // We only want to really perform the run if fetches or targets are provided,
             // or if the call is a partial run that specifies feeds.
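+            // final_targets is typed by the generic fetch type T; cast each element back to Operation for the native call.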
-            var results = _do_run(final_fetches, feed_dict_tensor);
+            var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor);
 
             return fetch_handler.build_results(null, results);
         }
 
-        private NDArray[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
+        /// <summary>
+        /// Runs a step based on the given fetches and feeds.
+        /// </summary>
+        /// <param name="target_list">A list of operations to be run, but not fetched.</param>
+        /// <param name="fetch_list">A list of tensors to be fetched.</param>
+        /// <param name="feed_dict">A dictionary that maps tensors to the values to feed.</param>
+        /// <returns>
+        /// A list of numpy ndarrays, corresponding to the elements of
+        /// `fetch_list`. If the ith element of `fetch_list` contains the
+        /// name of an operation, the first Tensor output of that operation
+        /// will be returned for that element.
+        /// </returns>
+        private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
         {
             var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray();
             var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
+            var targets = target_list;
 
-            return _call_tf_sessionrun(feeds, fetches);
+            return _call_tf_sessionrun(feeds, fetches, target_list);
         }
 
-        private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list)
+        private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
         {
             // Ensure any changes to the graph are reflected in the runtime.
             _extend_graph();
@@ -95,8 +114,8 @@ namespace Tensorflow
                 outputs: fetch_list,
                 output_values: output_values,
                 noutputs: fetch_list.Length,
-                target_opers: IntPtr.Zero,
-                ntargets: 0,
+                target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
+                ntargets: target_list.Count,
                 run_metadata: IntPtr.Zero,
                 status: status);
@@ -8,26 +8,37 @@ namespace Tensorflow
     /// <summary>
     /// Fetch mapper for singleton tensors and ops.
     /// </summary>
-    public class _ElementFetchMapper : _FetchMapper
+    public class _ElementFetchMapper<T> : _FetchMapper<T>
     {
-        private List<Object> _unique_fetches = new List<object>();
-        private Action _contraction_fn;
+        private List<object> _unique_fetches = new List<object>();
+        private Func<List<object>, NDArray> _contraction_fn;
 
-        public _ElementFetchMapper(List<Tensor> fetches, Action contraction_fn)
+        public _ElementFetchMapper(List<T> fetches, Func<List<object>, NDArray> contraction_fn)
         {
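+            // Resolve each fetch to an element (Tensor or Operation) of the default graph.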
-            foreach(var tensor in fetches)
+            foreach(var fetch in fetches)
             {
-                var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true);
-                _unique_fetches.Add(fetch);
+                var g = ops.get_default_graph();
+                var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true);
+                _unique_fetches.Add(el);
             }
+
+            _contraction_fn = contraction_fn;
         }
 
-        public NDArray build_results(NDArray[] values)
+        /// <summary>
+        /// Build results matching the original fetch shape.
+        /// </summary>
+        /// <param name="values">The flat list of values produced for each unique fetch.</param>
+        /// <returns>The result for this singleton fetch, or null if nothing was fetched.</returns>
+        public NDArray build_results(List<object> values)
         {
-            return values[0];
+            if (values.Count == 0)
+                return null;
+
+            // For a singleton fetch the default contraction is simply the first (only) value.
+            if (_contraction_fn == null)
+                return (NDArray)values[0];
+
+            return _contraction_fn(values);
         }
 
-        public List<Object> unique_fetches()
+        public List<object> unique_fetches()
         {
             return _unique_fetches;
         }
@@ -8,21 +8,26 @@ namespace Tensorflow
     /// <summary>
     /// Handler for structured fetches.
     /// </summary>
-    public class _FetchHandler
+    public class _FetchHandler<T>
     {
-        private _ElementFetchMapper _fetch_mapper;
+        private _ElementFetchMapper<T> _fetch_mapper;
         private List<Tensor> _fetches = new List<Tensor>();
         private List<bool> _ops = new List<bool>();
         private List<Tensor> _final_fetches = new List<Tensor>();
-        private List<object> _targets = new List<object>();
+        private List<T> _targets = new List<T>();
 
-        public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null)
+        public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null)
         {
-            _fetch_mapper = new _FetchMapper().for_fetch(fetches);
+            _fetch_mapper = new _FetchMapper<T>().for_fetch(fetches);
             foreach(var fetch in _fetch_mapper.unique_fetches())
             {
                 switch (fetch)
                 {
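+                    // Operation fetches become run targets; they produce no fetched value.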
+                    case Operation val:
+                        _assert_fetchable(graph, val);
+                        _targets.Add((T)(object)val);
+                        _ops.Add(true);
+                        break;
                     case Tensor val:
                         _assert_fetchable(graph, val.op);
                         _fetches.Add(val);
@@ -35,9 +40,19 @@ namespace Tensorflow
             _final_fetches = _fetches;
         }
 
-        public NDArray build_results(Session session, NDArray[] results)
+        public NDArray build_results(Session session, NDArray[] tensor_values)
         {
-            return _fetch_mapper.build_results(results);
+            var full_values = new List<object>();
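+
+            // Walk the fetches in order: operations contribute a null placeholder,
+            // tensors consume the next fetched value from tensor_values.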
+            int i = 0;
+            foreach (var is_op in _ops)
+            {
+                if (is_op)
+                {
+                    full_values.Add(null);
+                }
+                else
+                {
+                    full_values.Add(tensor_values[i++]);
+                }
+            }
+
+            return _fetch_mapper.build_results(full_values);
         }
 
         private void _assert_fetchable(Graph graph, Operation op)
@@ -53,7 +68,7 @@ namespace Tensorflow
             return _final_fetches;
         }
 
-        public List<Object> targets()
+        public List<T> targets()
         {
             return _targets;
         }
@@ -4,13 +4,13 @@ using System.Text;
 
 namespace Tensorflow
 {
-    public class _FetchMapper
+    public class _FetchMapper<T>
     {
-        public _ElementFetchMapper for_fetch(Tensor fetch)
+        public _ElementFetchMapper<T> for_fetch(T fetch)
         {
-            var fetches = new List<Tensor> { fetch };
+            var fetches = new List<T> { fetch };
 
-            return new _ElementFetchMapper(fetches, null);
+            return new _ElementFetchMapper<T>(fetches, null);
         }
     }
 }
@@ -87,7 +87,7 @@ namespace Tensorflow
         public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options,
             TF_Output[] inputs, IntPtr[] input_values, int ninputs,
             TF_Output[] outputs, IntPtr[] output_values, int noutputs,
-            IntPtr target_opers, int ntargets,
+            IntPtr[] target_opers, int ntargets,
             IntPtr run_metadata,
             IntPtr status);
     }
@@ -42,7 +42,7 @@ namespace Tensorflow
         /// <returns>An Op that run the initializers of all the specified variables.</returns>
         public static Operation variables_initializer(RefVariable[] var_list, string name = "init")
         {
-            return control_flow_ops.group(var_list.Select(x => x.initializer).ToList());
+            return control_flow_ops.group(var_list.Select(x => x.initializer).ToList(), name);
         }
     }
 }
@@ -76,7 +76,7 @@ namespace TensorFlowNET.UnitTest
             var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
             var outputs_ptr = outputs_.ToArray();
             var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray();
-            IntPtr targets_ptr = IntPtr.Zero;
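+            // No target operations in this test; pass an empty array now that the binding takes IntPtr[].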
+            IntPtr[] targets_ptr = new IntPtr[0];
 
             c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
                 outputs_ptr, output_values_ptr, outputs_.Count,
@@ -35,13 +35,13 @@ namespace TensorFlowNET.UnitTest
             using (var session = tf.Session())
             {
-                /*session.run(model);
+                session.run(model);
 
                 for(int i = 0; i < 5; i++)
                 {
-                    x = x + 1;
+                    //x = x + 1;
                     var result = session.run(x);
                     print(result);
-                }*/
+                }
             }
         }
     }