diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 7460dea1..2bd71e49 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -72,7 +72,7 @@ namespace Tensorflow // or if the call is a partial run that specifies feeds. var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor); - return fetch_handler.build_results(null, results); + return fetch_handler.build_results(this, results); } /// diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index d1c09533..7960d200 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -11,9 +11,9 @@ namespace Tensorflow public class _ElementFetchMapper : _FetchMapper { private List _unique_fetches = new List(); - private Func, NDArray> _contraction_fn; + private Func, object> _contraction_fn; - public _ElementFetchMapper(List fetches, Func, NDArray> contraction_fn) + public _ElementFetchMapper(List fetches, Func, object> contraction_fn) { foreach(var fetch in fetches) { @@ -32,10 +32,22 @@ namespace Tensorflow /// public NDArray build_results(List values) { - if (values.Count == 0) - return null; - else - return _contraction_fn(values); + NDArray result = null; + + if (values.Count > 0) + { + var ret = _contraction_fn(values); + switch (ret) + { + case NDArray value: + result = value; + break; + default: + break; + } + } + + return result; } public List unique_fetches() diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 4e709f76..9aced985 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -16,7 +16,7 @@ namespace Tensorflow private List _final_fetches = new List(); private List _targets = new List(); - public _FetchHandler(Graph graph, T fetches, Dictionary feeds = null, object feed_handles = null) + public _FetchHandler(Graph graph, T fetches, Dictionary feeds = null, Action feed_handles = null) { _fetch_mapper = new _FetchMapper().for_fetch(fetches); foreach(var fetch in _fetch_mapper.unique_fetches()) @@ -40,18 +40,32 @@ namespace Tensorflow _final_fetches = _fetches; } - public NDArray build_results(Session session, NDArray[] tensor_values) + public NDArray build_results(BaseSession session, NDArray[] tensor_values) { var full_values = new List(); + if (_final_fetches.Count != tensor_values.Length) + throw new InvalidOperationException("_final_fetches mismatch tensor_values"); + int i = 0; + int j = 0; foreach(var is_op in _ops) { if (is_op) { full_values.Add(null); } + else + { + var value = tensor_values[j]; + j += 1; + full_values.Add(value); + } + i += 1; } + if (j != tensor_values.Length) + throw new InvalidOperationException("j mismatch tensor_values"); + return _fetch_mapper.build_results(full_values); } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index fbad8db6..b5eff215 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -10,7 +10,10 @@ namespace Tensorflow { var fetches = new List { fetch }; - return new _ElementFetchMapper(fetches, null); + return new _ElementFetchMapper(fetches, (List fetched_vals) => + { + return fetched_vals[0]; + }); } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs index 30bdf488..5f1860c6 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs @@ -25,10 +25,5 @@ namespace Tensorflow { return new Tensor(handle); } - - public static implicit operator Tensor(RefVariable var) - { - return var._initial_value; - } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs new file mode 100644 index 00000000..6e4d28f5 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class RefVariable + { + public static implicit operator _VariableScopeStore(RefVariable variable) + { + return null; + } + + public static implicit operator RefVariable(_VariableScopeStore store) + { + return null; + } + + public static implicit operator Tensor(RefVariable var) + { + return var._AsTensor(); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs new file mode 100644 index 00000000..85fb19f8 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class RefVariable + { + public static Tensor operator +(RefVariable t1, int t2) + { + var tensor1 = t1._AsTensor(); + var tensor2 = ops.convert_to_tensor(t2, tensor1.dtype, "y"); + return gen_math_ops.add(tensor1, tensor2); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 96fc6fde..d539cfa6 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow { - public class RefVariable : VariableV1 + public partial class RefVariable : VariableV1 { public bool _in_graph_mode = true; public Tensor _initial_value; @@ -106,14 +106,9 @@ namespace Tensorflow return _variable; } - public static implicit operator _VariableScopeStore(RefVariable variable) + public Tensor _AsTensor() { - return null; - } - - public static implicit operator RefVariable(_VariableScopeStore store) - { - return null; + return _snapshot; } } } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 6fbc4699..5fe914f6 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -59,7 +59,14 @@ namespace Tensorflow return get_default_graph(); } - public static Tensor convert_to_tensor(object value, string name = "") + /// + /// Converts the given `value` to a `Tensor`. + /// + /// + /// + /// + /// + public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") { switch (value) { @@ -67,7 +74,7 @@ namespace Tensorflow return val; default: var nd = tensor_util.convert_to_numpy_ndarray(value); - return tf.constant(nd, name); + return constant_op.Constant(nd, name); } } diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 52023ac6..979a84cf 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -38,8 +38,8 @@ namespace TensorFlowNET.UnitTest session.run(model); for(int i = 0; i < 5; i++) { - // x = x + 1; - var result = session.run(x); + var x1 = x + 1; + var result = session.run(x1); print(result); } }