Browse Source

fix default session

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
ce59939ac1
6 changed files with 36 additions and 14 deletions
  1. +24
    -3
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +5
    -5
      src/TensorFlowNET.Core/Sessions/FeedItem.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  5. +1
    -1
      src/TensorFlowNET.Core/ops.py.cs
  6. +3
    -2
      src/TensorFlowNET.Core/tf.cs

+ 24
- 3
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow


private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null) private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null)
{ {
var feed_dict_tensor = new Dictionary<Tensor, NDArray>();
var feed_dict_tensor = new Dictionary<object, object>();


if (feed_dict != null) if (feed_dict != null)
feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value));
@@ -79,9 +79,30 @@ namespace Tensorflow
/// name of an operation, the first Tensor output of that operation /// name of an operation, the first Tensor output of that operation
/// will be returned for that element. /// will be returned for that element.
/// </returns> /// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{ {
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray();
var feeds = feed_dict.Select(x =>
{
if(x.Key is Tensor tensor)
{
switch (x.Value)
{
case Tensor t1:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1);
case NDArray nd:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd));
case int intVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal));
case float floatVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal));
case double doubleVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal));
default:
break;
}
}
throw new NotImplementedException("_do_run.feed_dict");
}).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
var targets = target_list; var targets = target_list;




+ 5
- 5
src/TensorFlowNET.Core/Sessions/FeedItem.cs View File

@@ -10,13 +10,13 @@ namespace Tensorflow
/// </summary> /// </summary>
public class FeedItem public class FeedItem
{ {
public Tensor Key { get; }
public NDArray Value { get; }
public object Key { get; }
public object Value { get; }


public FeedItem(Tensor tensor, NDArray nd)
public FeedItem(object key, object val)
{ {
Key = tensor;
Value = nd;
Key = key;
Value = val;
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow
private List<Tensor> _final_fetches = new List<Tensor>(); private List<Tensor> _final_fetches = new List<Tensor>();
private List<T> _targets = new List<T>(); private List<T> _targets = new List<T>();


public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, Action feed_handles = null)
public _FetchHandler(Graph graph, T fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
{ {
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches); _fetch_mapper = new _FetchMapper<T>().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches()) foreach(var fetch in _fetch_mapper.unique_fetches())


+ 2
- 2
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -160,9 +160,9 @@ namespace Tensorflow


if (!_is_empty) if (!_is_empty)
{ {
/*model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
var model_checkpoint_path1 = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) new FeedItem(_saver_def.FilenameTensorName, checkpoint_file)
});*/
});
} }


throw new NotImplementedException(""); throw new NotImplementedException("");


+ 1
- 1
src/TensorFlowNET.Core/ops.py.cs View File

@@ -289,7 +289,7 @@ namespace Tensorflow
/// <returns>The default `Session` being used in the current thread.</returns> /// <returns>The default `Session` being used in the current thread.</returns>
public static Session get_default_session() public static Session get_default_session()
{ {
return tf.Session();
return tf.defaultSession;
} }


public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session) public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session)


+ 3
- 2
src/TensorFlowNET.Core/tf.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow
public static Context context = new Context(new ContextOptions(), new Status()); public static Context context = new Context(new ContextOptions(), new Status());


public static Graph g = new Graph(); public static Graph g = new Graph();
public static Session session = new Session();
public static Session defaultSession;


public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid) public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid)
{ {
@@ -49,7 +49,8 @@ namespace Tensorflow


public static Session Session() public static Session Session()
{ {
return session;
defaultSession = new Session();
return defaultSession;
} }
} }
} }

Loading…
Cancel
Save