Browse Source

feed dict

tags/v0.1.0-Tensor
Oceania2018 7 years ago
parent
commit
290a694e3c
7 changed files with 102 additions and 29 deletions
  1. +14
    -13
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +59
    -0
      src/TensorFlowNET.Core/Sessions/FeedDict.cs
  3. +5
    -5
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Sessions/c_api.session.cs
  5. +1
    -0
      src/TensorFlowNET.Core/c_api_util.cs
  6. +3
    -9
      test/TensorFlowNET.UnitTest/OperationsTest.cs
  7. +18
    -0
      test/TensorFlowNET.UnitTest/PlaceholderTest.cs

+ 14
- 13
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -40,35 +40,35 @@ namespace Tensorflow
}

public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
public virtual object run(Tensor fetches, FeedDict feed_dict = null)
{
var result = _run(fetches, feed_dict);

return result;
}

private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
private unsafe object _run(Tensor fetches, FeedDict feed_dict = null)
{
var feed_dict_tensor = new Dictionary<Tensor, NDArray>();
var feed_dict_tensor = new FeedDict();

if (feed_dict != null)
{
NDArray np_val = null;
foreach (var feed in feed_dict)
foreach (FeedValue feed in feed_dict)
{
switch (feed.Value)
switch (feed.feed_val)
{
case float value:
np_val = np.asarray(value);
break;
}

feed_dict_tensor[feed.Key] = np_val;
feed_dict_tensor[feed.feed] = np_val;
}
}

// Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler(_graph, fetches);
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

// Run request and get response.
// We need to keep the returned movers alive for the following _do_run().
@@ -80,19 +80,20 @@ namespace Tensorflow

// We only want to really perform the run if fetches or targets are provided,
// or if the call is a partial run that specifies feeds.
var results = _do_run(final_fetches);
var results = _do_run(final_fetches, feed_dict_tensor);

return fetch_handler.build_results(null, results);
}

private object[] _do_run(List<object> fetch_list)
private object[] _do_run(List<Tensor> fetch_list, FeedDict feed_dict)
{
var fetches = fetch_list.Select(x => (x as Tensor)._as_tf_output()).ToArray();
var feeds = feed_dict.items().Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();

return _call_tf_sessionrun(fetches);
return _call_tf_sessionrun(feeds, fetches);
}

private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list)
private unsafe object[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list)
{
// Ensure any changes to the graph are reflected in the runtime.
_extend_graph();
@@ -103,7 +104,7 @@ namespace Tensorflow

c_api.TF_SessionRun(_session,
run_options: IntPtr.Zero,
inputs: new TF_Output[] { },
inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: new IntPtr[] { },
ninputs: 0,
outputs: fetch_list,


+ 59
- 0
src/TensorFlowNET.Core/Sessions/FeedDict.cs View File

@@ -0,0 +1,59 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class FeedDict : IEnumerable
{
private Dictionary<Tensor, object> feed_dict;

public FeedDict()
{
feed_dict = new Dictionary<Tensor, object>();
}

public object this[Tensor feed]
{
get
{
return feed_dict[feed];
}

set
{
feed_dict[feed] = value;
}
}

public FeedDict Add(Tensor feed, object value)
{
feed_dict.Add(feed, value);
return this;
}

public IEnumerator GetEnumerator()
{
foreach (KeyValuePair<Tensor, object> feed in feed_dict)
{
yield return new FeedValue
{
feed = feed.Key,
feed_val = feed.Value
};
}
}

public Dictionary<Tensor, object> items()
{
return feed_dict;
}
}

public struct FeedValue
{
public Tensor feed { get; set; }
public object feed_val { get; set; }
}
}

+ 5
- 5
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -10,12 +10,12 @@ namespace Tensorflow
public class _FetchHandler
{
private _ElementFetchMapper _fetch_mapper;
private List<object> _fetches = new List<object>();
private List<Tensor> _fetches = new List<Tensor>();
private List<bool> _ops = new List<bool>();
private List<object> _final_fetches = new List<object>();
private List<Tensor> _final_fetches = new List<Tensor>();
private List<object> _targets = new List<object>();

public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null)
public _FetchHandler(Graph graph, Tensor fetches, FeedDict feeds = null, object feed_handles = null)
{
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches())
@@ -24,7 +24,7 @@ namespace Tensorflow
{
case Tensor val:
_assert_fetchable(graph, val.op);
_fetches.Add(fetch);
_fetches.Add(val);
_ops.Add(false);
break;
}
@@ -47,7 +47,7 @@ namespace Tensorflow
}
}

public List<Object> fetches()
public List<Tensor> fetches()
{
return _final_fetches;
}


+ 2
- 2
src/TensorFlowNET.Core/Sessions/c_api.session.cs View File

@@ -38,8 +38,8 @@ namespace Tensorflow
/// </summary>
/// <param name="session"></param>
/// <param name="run_options"></param>
/// <param name="inputs"></param>
/// <param name="input_values"></param>
/// <param name="inputs">TF_Output</param>
/// <param name="input_values">TF_Tensor</param>
/// <param name="ninputs"></param>
/// <param name="outputs"></param>
/// <param name="output_values"></param>


+ 1
- 0
src/TensorFlowNET.Core/c_api_util.cs View File

@@ -8,6 +8,7 @@ namespace Tensorflow
{
public static TF_Output tf_output(IntPtr c_op, int index)
{
var ret = new TF_Output();
ret.oper = c_op;
ret.index = index;


+ 3
- 9
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -9,12 +9,6 @@ namespace TensorFlowNET.UnitTest
[TestClass]
public class OperationsTest
{
[TestMethod]
public void placeholder()
{
var x = tf.placeholder(tf.float32);
}

[TestMethod]
public void addInPlaceholder()
{
@@ -24,9 +18,9 @@ namespace TensorFlowNET.UnitTest

using(var sess = tf.Session())
{
var feed_dict = new Dictionary<Tensor, object>();
feed_dict.Add(a, 3.0f);
feed_dict.Add(b, 2.0f);
var feed_dict = new FeedDict()
.Add(a, 3.0f)
.Add(b, 2.0f);

var o = sess.run(c, feed_dict);
}


+ 18
- 0
test/TensorFlowNET.UnitTest/PlaceholderTest.cs View File

@@ -0,0 +1,18 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class PlaceholderTest
{
[TestMethod]
public void placeholder()
{
var x = tf.placeholder(tf.float32);
}
}
}

Loading…
Cancel
Save