| @@ -43,7 +43,7 @@ namespace Tensorflow | |||||
| private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null) | private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null) | ||||
| { | { | ||||
| var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | |||||
| var feed_dict_tensor = new Dictionary<object, object>(); | |||||
| if (feed_dict != null) | if (feed_dict != null) | ||||
| feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); | feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); | ||||
| @@ -79,9 +79,30 @@ namespace Tensorflow | |||||
| /// name of an operation, the first Tensor output of that operation | /// name of an operation, the first Tensor output of that operation | ||||
| /// will be returned for that element. | /// will be returned for that element. | ||||
| /// </returns> | /// </returns> | ||||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||||
| { | { | ||||
| var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); | |||||
| var feeds = feed_dict.Select(x => | |||||
| { | |||||
| if(x.Key is Tensor tensor) | |||||
| { | |||||
| switch (x.Value) | |||||
| { | |||||
| case Tensor t1: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1); | |||||
| case NDArray nd: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd)); | |||||
| case int intVal: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal)); | |||||
| case float floatVal: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal)); | |||||
| case double doubleVal: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal)); | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| throw new NotImplementedException("_do_run.feed_dict"); | |||||
| }).ToArray(); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | ||||
| var targets = target_list; | var targets = target_list; | ||||
| @@ -10,13 +10,13 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class FeedItem | public class FeedItem | ||||
| { | { | ||||
| public Tensor Key { get; } | |||||
| public NDArray Value { get; } | |||||
| public object Key { get; } | |||||
| public object Value { get; } | |||||
| public FeedItem(Tensor tensor, NDArray nd) | |||||
| public FeedItem(object key, object val) | |||||
| { | { | ||||
| Key = tensor; | |||||
| Value = nd; | |||||
| Key = key; | |||||
| Value = val; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -16,7 +16,7 @@ namespace Tensorflow | |||||
| private List<Tensor> _final_fetches = new List<Tensor>(); | private List<Tensor> _final_fetches = new List<Tensor>(); | ||||
| private List<T> _targets = new List<T>(); | private List<T> _targets = new List<T>(); | ||||
| public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, Action feed_handles = null) | |||||
| public _FetchHandler(Graph graph, T fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||||
| { | { | ||||
| _fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | _fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); | ||||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
| @@ -160,9 +160,9 @@ namespace Tensorflow | |||||
| if (!_is_empty) | if (!_is_empty) | ||||
| { | { | ||||
| /*model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { | |||||
| var model_checkpoint_path1 = sess.run(_saver_def.SaveTensorName, new FeedItem[] { | |||||
| new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) | new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) | ||||
| });*/ | |||||
| }); | |||||
| } | } | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| @@ -289,7 +289,7 @@ namespace Tensorflow | |||||
| /// <returns>The default `Session` being used in the current thread.</returns> | /// <returns>The default `Session` being used in the current thread.</returns> | ||||
| public static Session get_default_session() | public static Session get_default_session() | ||||
| { | { | ||||
| return tf.Session(); | |||||
| return tf.defaultSession; | |||||
| } | } | ||||
| public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session) | public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session) | ||||
| @@ -17,7 +17,7 @@ namespace Tensorflow | |||||
| public static Context context = new Context(new ContextOptions(), new Status()); | public static Context context = new Context(new ContextOptions(), new Status()); | ||||
| public static Graph g = new Graph(); | public static Graph g = new Graph(); | ||||
| public static Session session = new Session(); | |||||
| public static Session defaultSession; | |||||
| public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid) | public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid) | ||||
| { | { | ||||
| @@ -49,7 +49,8 @@ namespace Tensorflow | |||||
| public static Session Session() | public static Session Session() | ||||
| { | { | ||||
| return session; | |||||
| defaultSession = new Session(); | |||||
| return defaultSession; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||