| @@ -41,7 +41,7 @@ namespace Tensorflow | |||||
| _graph_key = $"grap-key-{ops.uid()}/"; | _graph_key = $"grap-key-{ops.uid()}/"; | ||||
| } | } | ||||
| public ITensorOrOperation as_graph_element(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) | |||||
| public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | |||||
| { | { | ||||
| return _as_graph_element_locked(obj, allow_tensor, allow_operation); | return _as_graph_element_locked(obj, allow_tensor, allow_operation); | ||||
| } | } | ||||
| @@ -54,7 +54,7 @@ namespace Tensorflow | |||||
| return null; | return null; | ||||
| } | } | ||||
| private ITensorOrOperation _as_graph_element_locked(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true) | |||||
| private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) | |||||
| { | { | ||||
| string types_str = ""; | string types_str = ""; | ||||
| @@ -75,6 +75,19 @@ namespace Tensorflow | |||||
| if (temp_obj != null) | if (temp_obj != null) | ||||
| obj = temp_obj; | obj = temp_obj; | ||||
| // If obj appears to be a name... | |||||
| if (obj is String str) | |||||
| { | |||||
| if(str.Contains(":") && allow_tensor) | |||||
| { | |||||
| string op_name = str.Split(':')[0]; | |||||
| int out_n = int.Parse(str.Split(':')[1]); | |||||
| if (_nodes_by_name.ContainsKey(op_name)) | |||||
| return _nodes_by_name[op_name].outputs[out_n]; | |||||
| } | |||||
| } | |||||
| if (obj is Tensor tensor && allow_tensor) | if (obj is Tensor tensor && allow_tensor) | ||||
| { | { | ||||
| if (tensor.Graph.Equals(this)) | if (tensor.Graph.Equals(this)) | ||||
| @@ -166,7 +179,7 @@ namespace Tensorflow | |||||
| public void _add_op(Operation op) | public void _add_op(Operation op) | ||||
| { | { | ||||
| _nodes_by_id[op._id] = op; | _nodes_by_id[op._id] = op; | ||||
| //_nodes_by_name[op.name] = op; | |||||
| _nodes_by_name[op.name] = op; | |||||
| _version = Math.Max(_version, op._id); | _version = Math.Max(_version, op._id); | ||||
| } | } | ||||
| @@ -13,5 +13,7 @@ namespace Tensorflow | |||||
| string Device { get; } | string Device { get; } | ||||
| Operation op { get; } | Operation op { get; } | ||||
| string name { get; } | string name { get; } | ||||
| TF_DataType dtype { get; } | |||||
| Tensor[] outputs { get; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -17,7 +17,7 @@ namespace Tensorflow | |||||
| public string type => OpType; | public string type => OpType; | ||||
| public Operation op => this; | public Operation op => this; | ||||
| public TF_DataType dtype => TF_DataType.DtInvalid; | |||||
| private Status status = new Status(); | private Status status = new Status(); | ||||
| public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | ||||
| @@ -44,9 +44,24 @@ namespace Tensorflow | |||||
| private NDArray _run(object fetches, FeedItem[] feed_dict = null) | private NDArray _run(object fetches, FeedItem[] feed_dict = null) | ||||
| { | { | ||||
| var feed_dict_tensor = new Dictionary<object, object>(); | var feed_dict_tensor = new Dictionary<object, object>(); | ||||
| var feed_map = new Dictionary<object, object>(); | |||||
| // Validate and process feed_dict. | |||||
| if (feed_dict != null) | if (feed_dict != null) | ||||
| feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); | |||||
| { | |||||
| foreach(var subfeed in feed_dict) | |||||
| { | |||||
| var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | |||||
| var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); | |||||
| switch(subfeed.Value) | |||||
| { | |||||
| case string str: | |||||
| feed_dict_tensor[subfeed_t] = np.array(str); | |||||
| feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| // Create a fetch handler to take care of the structure of fetches. | // Create a fetch handler to take care of the structure of fetches. | ||||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | ||||
| @@ -20,21 +20,7 @@ namespace Tensorflow | |||||
| foreach(var fetch in fetches) | foreach(var fetch in fetches) | ||||
| { | { | ||||
| switch(fetch) | |||||
| { | |||||
| case Tensor tensor: | |||||
| el = g.as_graph_element(tensor, allow_tensor: true, allow_operation: true); | |||||
| break; | |||||
| case Operation op: | |||||
| el = g.as_graph_element(op, allow_tensor: true, allow_operation: true); | |||||
| break; | |||||
| case String str: | |||||
| // Looks like a Tensor name and can be a Tensor. | |||||
| el = g._nodes_by_name[str]; | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("_ElementFetchMapper"); | |||||
| } | |||||
| el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); | |||||
| } | } | ||||
| _unique_fetches.Add(el); | _unique_fetches.Add(el); | ||||
| @@ -21,6 +21,7 @@ namespace Tensorflow | |||||
| public Graph Graph => op?.Graph; | public Graph Graph => op?.Graph; | ||||
| public Operation op { get; } | public Operation op { get; } | ||||
| public Tensor[] outputs => op.outputs; | |||||
| /// <summary> | /// <summary> | ||||
| /// The string name of this tensor. | /// The string name of this tensor. | ||||
| @@ -18,7 +18,7 @@ namespace TensorFlowNET.Examples | |||||
| The value returned by the constructor represents the output | The value returned by the constructor represents the output | ||||
| of the Constant op. */ | of the Constant op. */ | ||||
| var str = "Hello, TensorFlow!"; | |||||
| var str = "Hello, TensorFlow.NET!"; | |||||
| var hello = tf.constant(str); | var hello = tf.constant(str); | ||||
| // Start tf session | // Start tf session | ||||