| @@ -72,7 +72,7 @@ namespace Tensorflow | |||||
| // or if the call is a partial run that specifies feeds. | // or if the call is a partial run that specifies feeds. | ||||
| var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor); | var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor); | ||||
| return fetch_handler.build_results(null, results); | |||||
| return fetch_handler.build_results(this, results); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -11,9 +11,9 @@ namespace Tensorflow | |||||
| public class _ElementFetchMapper<T> : _FetchMapper<T> | public class _ElementFetchMapper<T> : _FetchMapper<T> | ||||
| { | { | ||||
| private List<object> _unique_fetches = new List<object>(); | private List<object> _unique_fetches = new List<object>(); | ||||
| private Func<List<object>, NDArray> _contraction_fn; | |||||
| private Func<List<object>, object> _contraction_fn; | |||||
| public _ElementFetchMapper(List<T> fetches, Func<List<object>, NDArray> contraction_fn) | |||||
| public _ElementFetchMapper(List<T> fetches, Func<List<object>, object> contraction_fn) | |||||
| { | { | ||||
| foreach(var fetch in fetches) | foreach(var fetch in fetches) | ||||
| { | { | ||||
| @@ -32,10 +32,22 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public NDArray build_results(List<object> values) | public NDArray build_results(List<object> values) | ||||
| { | { | ||||
| if (values.Count == 0) | |||||
| return null; | |||||
| else | |||||
| return _contraction_fn(values); | |||||
| NDArray result = null; | |||||
| if (values.Count > 0) | |||||
| { | |||||
| var ret = _contraction_fn(values); | |||||
| switch (ret) | |||||
| { | |||||
| case NDArray value: | |||||
| result = value; | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| return result; | |||||
| } | } | ||||
| public List<object> unique_fetches() | public List<object> unique_fetches() | ||||
| @@ -16,7 +16,7 @@ namespace Tensorflow | |||||
| private List<Tensor> _final_fetches = new List<Tensor>(); | private List<Tensor> _final_fetches = new List<Tensor>(); | ||||
| private List<T> _targets = new List<T>(); | private List<T> _targets = new List<T>(); | ||||
| public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||||
| public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, Action feed_handles = null) | |||||
| { | { | ||||
| _fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | _fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | ||||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
| @@ -40,18 +40,32 @@ namespace Tensorflow | |||||
| _final_fetches = _fetches; | _final_fetches = _fetches; | ||||
| } | } | ||||
| public NDArray build_results(Session session, NDArray[] tensor_values) | |||||
| public NDArray build_results(BaseSession session, NDArray[] tensor_values) | |||||
| { | { | ||||
| var full_values = new List<object>(); | var full_values = new List<object>(); | ||||
| if (_final_fetches.Count != tensor_values.Length) | |||||
| throw new InvalidOperationException("_final_fetches mismatch tensor_values"); | |||||
| int i = 0; | |||||
| int j = 0; | |||||
| foreach(var is_op in _ops) | foreach(var is_op in _ops) | ||||
| { | { | ||||
| if (is_op) | if (is_op) | ||||
| { | { | ||||
| full_values.Add(null); | full_values.Add(null); | ||||
| } | } | ||||
| else | |||||
| { | |||||
| var value = tensor_values[j]; | |||||
| j += 1; | |||||
| full_values.Add(value); | |||||
| } | |||||
| i += 1; | |||||
| } | } | ||||
| if (j != tensor_values.Length) | |||||
| throw new InvalidOperationException("j mismatch tensor_values"); | |||||
| return _fetch_mapper.build_results(full_values); | return _fetch_mapper.build_results(full_values); | ||||
| } | } | ||||
| @@ -10,7 +10,10 @@ namespace Tensorflow | |||||
| { | { | ||||
| var fetches = new List<T> { fetch }; | var fetches = new List<T> { fetch }; | ||||
| return new _ElementFetchMapper<T>(fetches, null); | |||||
| return new _ElementFetchMapper<T>(fetches, (List<object> fetched_vals) => | |||||
| { | |||||
| return fetched_vals[0]; | |||||
| }); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -25,10 +25,5 @@ namespace Tensorflow | |||||
| { | { | ||||
| return new Tensor(handle); | return new Tensor(handle); | ||||
| } | } | ||||
| public static implicit operator Tensor(RefVariable var) | |||||
| { | |||||
| return var._initial_value; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,24 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class RefVariable | |||||
| { | |||||
| public static implicit operator _VariableScopeStore(RefVariable variable) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| public static implicit operator RefVariable(_VariableScopeStore store) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| public static implicit operator Tensor(RefVariable var) | |||||
| { | |||||
| return var._AsTensor(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,16 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class RefVariable | |||||
| { | |||||
| public static Tensor operator +(RefVariable t1, int t2) | |||||
| { | |||||
| var tensor1 = t1._AsTensor(); | |||||
| var tensor2 = ops.convert_to_tensor(t2, tensor1.dtype, "y"); | |||||
| return gen_math_ops.add(tensor1, tensor2); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -4,7 +4,7 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class RefVariable : VariableV1 | |||||
| public partial class RefVariable : VariableV1 | |||||
| { | { | ||||
| public bool _in_graph_mode = true; | public bool _in_graph_mode = true; | ||||
| public Tensor _initial_value; | public Tensor _initial_value; | ||||
| @@ -106,14 +106,9 @@ namespace Tensorflow | |||||
| return _variable; | return _variable; | ||||
| } | } | ||||
| public static implicit operator _VariableScopeStore(RefVariable variable) | |||||
| public Tensor _AsTensor() | |||||
| { | { | ||||
| return null; | |||||
| } | |||||
| public static implicit operator RefVariable(_VariableScopeStore store) | |||||
| { | |||||
| return null; | |||||
| return _snapshot; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -59,7 +59,14 @@ namespace Tensorflow | |||||
| return get_default_graph(); | return get_default_graph(); | ||||
| } | } | ||||
| public static Tensor convert_to_tensor(object value, string name = "") | |||||
| /// <summary> | |||||
| /// Converts the given `value` to a `Tensor`. | |||||
| /// </summary> | |||||
| /// <param name="value"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | |||||
| { | { | ||||
| switch (value) | switch (value) | ||||
| { | { | ||||
| @@ -67,7 +74,7 @@ namespace Tensorflow | |||||
| return val; | return val; | ||||
| default: | default: | ||||
| var nd = tensor_util.convert_to_numpy_ndarray(value); | var nd = tensor_util.convert_to_numpy_ndarray(value); | ||||
| return tf.constant(nd, name); | |||||
| return constant_op.Constant(nd, name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -38,8 +38,8 @@ namespace TensorFlowNET.UnitTest | |||||
| session.run(model); | session.run(model); | ||||
| for(int i = 0; i < 5; i++) | for(int i = 0; i < 5; i++) | ||||
| { | { | ||||
| // x = x + 1; | |||||
| var result = session.run(x); | |||||
| var x1 = x + 1; | |||||
| var result = session.run(x1); | |||||
| print(result); | print(result); | ||||
| } | } | ||||
| } | } | ||||