Browse Source

fix default session

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
ce59939ac1
6 changed files with 36 additions and 14 deletions
  1. +24
    -3
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +5
    -5
      src/TensorFlowNET.Core/Sessions/FeedItem.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  5. +1
    -1
      src/TensorFlowNET.Core/ops.py.cs
  6. +3
    -2
      src/TensorFlowNET.Core/tf.cs

+ 24
- 3
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow

private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null)
{
var feed_dict_tensor = new Dictionary<Tensor, NDArray>();
var feed_dict_tensor = new Dictionary<object, object>();

if (feed_dict != null)
feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value));
@@ -79,9 +79,30 @@ namespace Tensorflow
/// name of an operation, the first Tensor output of that operation
/// will be returned for that element.
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray();
var feeds = feed_dict.Select(x =>
{
if(x.Key is Tensor tensor)
{
switch (x.Value)
{
case Tensor t1:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1);
case NDArray nd:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd));
case int intVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal));
case float floatVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal));
case double doubleVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal));
default:
break;
}
}
throw new NotImplementedException("_do_run.feed_dict");
}).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
var targets = target_list;



+ 5
- 5
src/TensorFlowNET.Core/Sessions/FeedItem.cs View File

@@ -10,13 +10,13 @@ namespace Tensorflow
/// </summary>
public class FeedItem
{
public Tensor Key { get; }
public NDArray Value { get; }
public object Key { get; }
public object Value { get; }

public FeedItem(Tensor tensor, NDArray nd)
public FeedItem(object key, object val)
{
Key = tensor;
Value = nd;
Key = key;
Value = val;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow
private List<Tensor> _final_fetches = new List<Tensor>();
private List<T> _targets = new List<T>();

public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, Action feed_handles = null)
public _FetchHandler(Graph graph, T fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
{
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches())


+ 2
- 2
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -160,9 +160,9 @@ namespace Tensorflow

if (!_is_empty)
{
/*model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
var model_checkpoint_path1 = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file)
});*/
});
}

throw new NotImplementedException("");


+ 1
- 1
src/TensorFlowNET.Core/ops.py.cs View File

@@ -289,7 +289,7 @@ namespace Tensorflow
/// <returns>The default `Session` being used in the current thread.</returns>
public static Session get_default_session()
{
return tf.Session();
return tf.defaultSession;
}

public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session)


+ 3
- 2
src/TensorFlowNET.Core/tf.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow
public static Context context = new Context(new ContextOptions(), new Status());

public static Graph g = new Graph();
public static Session session = new Session();
public static Session defaultSession;

public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid)
{
@@ -49,7 +49,8 @@ namespace Tensorflow

public static Session Session()
{
return session;
defaultSession = new Session();
return defaultSession;
}
}
}

Loading…
Cancel
Save