From 596afe2059c4d012d4dede25ddd57b560529fab9 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 10 Feb 2019 21:23:51 -0600 Subject: [PATCH] #173 --- src/TensorFlowNET.Core/Graphs/Graph.cs | 19 ++++++++++++++++--- src/TensorFlowNET.Core/ITensorOrOperation.cs | 2 ++ .../Operations/Operation.cs | 2 +- .../Sessions/BaseSession.cs | 17 ++++++++++++++++- .../Sessions/_ElementFetchMapper.cs | 16 +--------------- src/TensorFlowNET.Core/Tensors/Tensor.cs | 1 + test/TensorFlowNET.Examples/HelloWorld.cs | 2 +- 7 files changed, 38 insertions(+), 21 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 14ec0f00..42cf5111 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -41,7 +41,7 @@ namespace Tensorflow _graph_key = $"grap-key-{ops.uid()}/"; } - public ITensorOrOperation as_graph_element(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) + public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) { return _as_graph_element_locked(obj, allow_tensor, allow_operation); } @@ -54,7 +54,7 @@ namespace Tensorflow return null; } - private ITensorOrOperation _as_graph_element_locked(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) + private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) { string types_str = ""; @@ -75,6 +75,19 @@ namespace Tensorflow if (temp_obj != null) obj = temp_obj; + // If obj appears to be a name... + if (obj is String str) + { + if(str.Contains(":") && allow_tensor) + { + string op_name = str.Split(':')[0]; + int out_n = int.Parse(str.Split(':')[1]); + + if (_nodes_by_name.ContainsKey(op_name)) + return _nodes_by_name[op_name].outputs[out_n]; + } + } + if (obj is Tensor tensor && allow_tensor) { if (tensor.Graph.Equals(this)) @@ -166,7 +179,7 @@ namespace Tensorflow public void _add_op(Operation op) { _nodes_by_id[op._id] = op; - //_nodes_by_name[op.name] = op; + _nodes_by_name[op.name] = op; _version = Math.Max(_version, op._id); } diff --git a/src/TensorFlowNET.Core/ITensorOrOperation.cs b/src/TensorFlowNET.Core/ITensorOrOperation.cs index ed714f02..511fd116 100644 --- a/src/TensorFlowNET.Core/ITensorOrOperation.cs +++ b/src/TensorFlowNET.Core/ITensorOrOperation.cs @@ -13,5 +13,7 @@ namespace Tensorflow string Device { get; } Operation op { get; } string name { get; } + TF_DataType dtype { get; } + Tensor[] outputs { get; } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 1822de98..00a9241b 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -17,7 +17,7 @@ namespace Tensorflow public string type => OpType; public Operation op => this; - + public TF_DataType dtype => TF_DataType.DtInvalid; private Status status = new Status(); public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 21949aaa..5a828029 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -44,9 +44,24 @@ namespace Tensorflow private NDArray _run(object fetches, FeedItem[] feed_dict = null) { var feed_dict_tensor = new Dictionary(); + var feed_map = new Dictionary(); + // Validate and process feed_dict. if (feed_dict != null) - feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); + { + foreach(var subfeed in feed_dict) + { + var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); + var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); + switch(subfeed.Value) + { + case string str: + feed_dict_tensor[subfeed_t] = np.array(str); + feed_map[subfeed_t.name] = new Tuple(subfeed_t, subfeed.Value); + break; + } + } + } // Create a fetch handler to take care of the structure of fetches. var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index 3221285f..cec214a4 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -20,21 +20,7 @@ namespace Tensorflow foreach(var fetch in fetches) { - switch(fetch) - { - case Tensor tensor: - el = g.as_graph_element(tensor, allow_tensor: true, allow_operation: true); - break; - case Operation op: - el = g.as_graph_element(op, allow_tensor: true, allow_operation: true); - break; - case String str: - // Looks like a Tensor name and can be a Tensor. - el = g._nodes_by_name[str]; - break; - default: - throw new NotImplementedException("_ElementFetchMapper"); - } + el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); } _unique_fetches.Add(el); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index d4ae2549..d64086c1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -21,6 +21,7 @@ namespace Tensorflow public Graph Graph => op?.Graph; public Operation op { get; } + public Tensor[] outputs => op.outputs; /// /// The string name of this tensor. diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs index 26dab3be..f726bc06 100644 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ b/test/TensorFlowNET.Examples/HelloWorld.cs @@ -18,7 +18,7 @@ namespace TensorFlowNET.Examples The value returned by the constructor represents the output of the Constant op. */ - var str = "Hello, TensorFlow!"; + var str = "Hello, TensorFlow.NET!"; var hello = tf.constant(str); // Start tf session