From ce59939ac114afe0260da2eba9b143debbf22cb1 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 10 Feb 2019 08:14:44 -0600 Subject: [PATCH] fix default session --- .../Sessions/BaseSession.cs | 27 ++++++++++++++++--- src/TensorFlowNET.Core/Sessions/FeedItem.cs | 10 +++---- .../Sessions/_FetchHandler.cs | 2 +- src/TensorFlowNET.Core/Train/Saving/Saver.cs | 4 +-- src/TensorFlowNET.Core/ops.py.cs | 2 +- src/TensorFlowNET.Core/tf.cs | 5 ++-- 6 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 82eee051..54ba3759 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -43,7 +43,7 @@ namespace Tensorflow private NDArray _run(T fetches, FeedItem[] feed_dict = null) { - var feed_dict_tensor = new Dictionary(); + var feed_dict_tensor = new Dictionary(); if (feed_dict != null) feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); @@ -79,9 +79,30 @@ namespace Tensorflow /// name of an operation, the first Tensor output of that operation /// will be returned for that element. /// - private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) + private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) { - var feeds = feed_dict.Select(x => new KeyValuePair(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); + var feeds = feed_dict.Select(x => + { + if(x.Key is Tensor tensor) + { + switch (x.Value) + { + case Tensor t1: + return new KeyValuePair(tensor._as_tf_output(), t1); + case NDArray nd: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(nd)); + case int intVal: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(intVal)); + case float floatVal: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(floatVal)); + case double doubleVal: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(doubleVal)); + default: + break; + } + } + throw new NotImplementedException("_do_run.feed_dict"); + }).ToArray(); var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); var targets = target_list; diff --git a/src/TensorFlowNET.Core/Sessions/FeedItem.cs b/src/TensorFlowNET.Core/Sessions/FeedItem.cs index 06060d03..ad798ef8 100644 --- a/src/TensorFlowNET.Core/Sessions/FeedItem.cs +++ b/src/TensorFlowNET.Core/Sessions/FeedItem.cs @@ -10,13 +10,13 @@ namespace Tensorflow /// public class FeedItem { - public Tensor Key { get; } - public NDArray Value { get; } + public object Key { get; } + public object Value { get; } - public FeedItem(Tensor tensor, NDArray nd) + public FeedItem(object key, object val) { - Key = tensor; - Value = nd; + Key = key; + Value = val; } } } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index a6d9c711..f4e699cb 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -16,7 +16,7 @@ namespace Tensorflow private List _final_fetches = new List(); private List _targets = new List(); - public _FetchHandler(Graph graph, T fetches, Dictionary feeds = null, Action feed_handles = null) + public _FetchHandler(Graph graph, T fetches, Dictionary feeds = null, Action feed_handles = null) { _fetch_mapper = new _FetchMapper().for_fetch(fetches); foreach(var fetch in _fetch_mapper.unique_fetches()) diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs index 577fba58..90f0fcd0 100644 --- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -160,9 +160,9 @@ namespace Tensorflow if (!_is_empty) { - /*model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { + var model_checkpoint_path1 = sess.run(_saver_def.SaveTensorName, new FeedItem[] { new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) - });*/ + }); } throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index fbbbebd7..e43881c1 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -289,7 +289,7 @@ namespace Tensorflow /// The default `Session` being used in the current thread. public static Session get_default_session() { - return tf.Session(); + return tf.defaultSession; } public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session) diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index fed4ea54..ed5b428c 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -17,7 +17,7 @@ namespace Tensorflow public static Context context = new Context(new ContextOptions(), new Status()); public static Graph g = new Graph(); - public static Session session = new Session(); + public static Session defaultSession; public static RefVariable Variable(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid) { @@ -49,7 +49,8 @@ namespace Tensorflow public static Session Session() { - return session; + defaultSession = new Session(); + return defaultSession; } } }