diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 512e08c2..dca54511 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -77,9 +77,9 @@ namespace Tensorflow var temp_obj = _as_graph_element(obj); - if(obj is Tensor && allow_tensor) + if (obj is Tensor tensor && allow_tensor) { - if ((obj as Tensor).Graph.Equals(this)) + if (tensor.Graph.Equals(this)) { return obj; } @@ -88,6 +88,17 @@ namespace Tensorflow throw new Exception($"Tensor {obj} is not an element of this graph."); } } + else if (obj is Operation op && allow_operation) + { + if (op.Graph.Equals(this)) + { + return obj; + } + else + { + throw new Exception($"Operation {obj} is not an element of this graph."); + } + } throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index b6fe4d78..37b91448 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow @@ -10,14 +11,35 @@ namespace Tensorflow { using(var namescope = new ops.name_scope(name, "group_deps", inputs)) { + name = namescope; + + var ops_on_device = new Dictionary(); + // Sorts *inputs according to their devices. + foreach (var inp in inputs) + { + ops_on_device[inp.Device] = new Operation[] { inp }; + } + + // 1-level tree. The root node is the returned NoOp node. + if (ops_on_device.Count == 1) + { + return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name); + } - return _GroupControlDeps("", name); + // 2-level tree. The root node is the returned NoOp node. + // deps contains 1 NoOp node for each device. + return null; } } - private static Operation _GroupControlDeps(string dev, string name = "") + private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "") { + if (string.IsNullOrEmpty(dev)) + { + return gen_control_flow_ops.no_op(name); + } + return null; } } diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs new file mode 100644 index 00000000..c17753dd --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class gen_control_flow_ops + { + public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + + public static Operation no_op(string name = "") + { + var _op = _op_def_lib._apply_op_helper("NoOp", name); + + return _op; + } + } +} diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 3208c349..7460dea1 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -40,7 +40,12 @@ namespace Tensorflow return _run(fetches, feed_dict); } - private NDArray _run(Tensor fetches, Dictionary feed_dict = null) + public virtual NDArray run(Operation fetches, Dictionary feed_dict = null) + { + return _run(fetches, feed_dict); + } + + private NDArray _run(T fetches, Dictionary feed_dict = null) { var feed_dict_tensor = new Dictionary(); @@ -53,7 +58,7 @@ namespace Tensorflow } // Create a fetch handler to take care of the structure of fetches. - var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); + var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); // Run request and get response. // We need to keep the returned movers alive for the following _do_run(). @@ -65,20 +70,34 @@ namespace Tensorflow // We only want to really perform the run if fetches or targets are provided, // or if the call is a partial run that specifies feeds. - var results = _do_run(final_fetches, feed_dict_tensor); + var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor); return fetch_handler.build_results(null, results); } - private NDArray[] _do_run(List fetch_list, Dictionary feed_dict) + /// + /// Runs a step based on the given fetches and feeds. + /// + /// + /// A list of operations to be run, but not fetched. + /// + /// + /// + /// A list of numpy ndarrays, corresponding to the elements of + /// `fetch_list`. If the ith element of `fetch_list` contains the + /// name of an operation, the first Tensor output of that operation + /// will be returned for that element. + /// + private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) { var feeds = feed_dict.Select(x => new KeyValuePair(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); + var targets = target_list; - return _call_tf_sessionrun(feeds, fetches); + return _call_tf_sessionrun(feeds, fetches, target_list); } - private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list) + private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) { // Ensure any changes to the graph are reflected in the runtime. _extend_graph(); @@ -95,8 +114,8 @@ namespace Tensorflow outputs: fetch_list, output_values: output_values, noutputs: fetch_list.Length, - target_opers: IntPtr.Zero, - ntargets: 0, + target_opers: target_list.Select(f => (IntPtr)f).ToArray(), + ntargets: target_list.Count, run_metadata: IntPtr.Zero, status: status); diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index 83068972..41f67417 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -8,26 +8,37 @@ namespace Tensorflow /// /// Fetch mapper for singleton tensors and ops. /// - public class _ElementFetchMapper : _FetchMapper + public class _ElementFetchMapper : _FetchMapper { - private List _unique_fetches = new List(); - private Action _contraction_fn; + private List _unique_fetches = new List(); + private Func> _contraction_fn; - public _ElementFetchMapper(List fetches, Action contraction_fn) + public _ElementFetchMapper(List fetches, Func> contraction_fn) { - foreach(var tensor in fetches) + foreach(var fetch in fetches) { - var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true); - _unique_fetches.Add(fetch); + var g = ops.get_default_graph(); + var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); + _unique_fetches.Add(el); } + + _contraction_fn = contraction_fn; } - public NDArray build_results(NDArray[] values) + /// + /// Build results matching the original fetch shape. + /// + /// + /// + public NDArray build_results(List values) { - return values[0]; + if (values.Count == 0) + return null; + else + return _contraction_fn(values); } - public List unique_fetches() + public List unique_fetches() { return _unique_fetches; } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 1da14507..4e709f76 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -8,21 +8,26 @@ namespace Tensorflow /// /// Handler for structured fetches. /// - public class _FetchHandler + public class _FetchHandler { - private _ElementFetchMapper _fetch_mapper; + private _ElementFetchMapper _fetch_mapper; private List _fetches = new List(); private List _ops = new List(); private List _final_fetches = new List(); - private List _targets = new List(); + private List _targets = new List(); - public _FetchHandler(Graph graph, Tensor fetches, Dictionary feeds = null, object feed_handles = null) + public _FetchHandler(Graph graph, T fetches, Dictionary feeds = null, object feed_handles = null) { - _fetch_mapper = new _FetchMapper().for_fetch(fetches); + _fetch_mapper = new _FetchMapper().for_fetch(fetches); foreach(var fetch in _fetch_mapper.unique_fetches()) { switch (fetch) { + case Operation val: + _assert_fetchable(graph, val); + _targets.Add((T)(object)val); + _ops.Add(true); + break; case Tensor val: _assert_fetchable(graph, val.op); _fetches.Add(val); @@ -35,9 +40,19 @@ namespace Tensorflow _final_fetches = _fetches; } - public NDArray build_results(Session session, NDArray[] results) + public NDArray build_results(Session session, NDArray[] tensor_values) { - return _fetch_mapper.build_results(results); + var full_values = new List(); + + foreach(var is_op in _ops) + { + if (is_op) + { + full_values.Add(null); + } + } + + return _fetch_mapper.build_results(full_values); } private void _assert_fetchable(Graph graph, Operation op) @@ -53,7 +68,7 @@ namespace Tensorflow return _final_fetches; } - public List targets() + public List targets() { return _targets; } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index 763b67a0..fbad8db6 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -4,13 +4,13 @@ using System.Text; namespace Tensorflow { - public class _FetchMapper + public class _FetchMapper { - public _ElementFetchMapper for_fetch(Tensor fetch) + public _ElementFetchMapper for_fetch(T fetch) { - var fetches = new List { fetch }; + var fetches = new List { fetch }; - return new _ElementFetchMapper(fetches, null); + return new _ElementFetchMapper(fetches, null); } } } diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 3fc20365..b1860cd2 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -87,7 +87,7 @@ namespace Tensorflow public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, TF_Output[] inputs, IntPtr[] input_values, int ninputs, TF_Output[] outputs, IntPtr[] output_values, int noutputs, - IntPtr target_opers, int ntargets, + IntPtr[] target_opers, int ntargets, IntPtr run_metadata, IntPtr status); } diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs index ecb635ce..ea317a89 100644 --- a/src/TensorFlowNET.Core/Variables/variables.py.cs +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -42,7 +42,7 @@ namespace Tensorflow /// An Op that run the initializers of all the specified variables. public static Operation variables_initializer(RefVariable[] var_list, string name = "init") { - return control_flow_ops.group(var_list.Select(x => x.initializer).ToList()); + return control_flow_ops.group(var_list.Select(x => x.initializer).ToList(), name); } } } diff --git a/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll b/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll index d667cffa..4c62c8ce 100644 Binary files a/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll and b/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll differ diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs index 70e549f3..77de6b6b 100644 --- a/test/TensorFlowNET.UnitTest/CSession.cs +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -76,7 +76,7 @@ namespace TensorFlowNET.UnitTest var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); var outputs_ptr = outputs_.ToArray(); var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray(); - IntPtr targets_ptr = IntPtr.Zero; + IntPtr[] targets_ptr = new IntPtr[0]; c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, outputs_ptr, output_values_ptr, outputs_.Count, diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index f316c3f0..c5d40b75 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -35,13 +35,13 @@ namespace TensorFlowNET.UnitTest using (var session = tf.Session()) { - /*session.run(model); + session.run(model); for(int i = 0; i < 5; i++) { - x = x + 1; + //x = x + 1; var result = session.run(x); print(result); - }*/ + } } }